How to extend
Extend +=, -= and ⊻= for irreversible one-out functions
It works directly:
julia> using SpecialFunctions, NiLang
julia> x, y = 2.1, 1.0
(2.1, 1.0)
julia> @instr y += besselj0(x)
2.1
julia> x, y
(2.1, 1.7492472503018073)
julia> @instr ~(y += besselj0(x))
2.1
julia> x, y
(2.1, 1.0)
Here the statement
@instr y += besselj0(x)
is mapped to
@instr PlusEq(besselj0)(y, x)
However, this alone does not give you correct gradients. For y += scalar_out_function(x), one can bind the backward rule like this:
julia> using ChainRules, NiLang.AD
julia> besselj0_back(x) = ChainRules.rrule(besselj0, x)[2](1.0)[2]
besselj0_back (generic function with 1 method)
julia> primitive_grad(::typeof(besselj0), x::Real) = besselj0_back(x)
primitive_grad (generic function with 1 method)
julia> xg, yg = GVar(x), GVar(y, 1.0)
(GVar(2.1, 0.0), GVar(1.0, 1.0))
julia> @instr yg -= besselj0(xg)
GVar(2.1, -0.5682921357570385)
julia> xg, yg
(GVar(2.1, -0.5682921357570385), GVar(0.8333930196680097, 1.0))
julia> @instr yg += besselj0(xg)
GVar(2.1, 0.0)
julia> xg, yg
(GVar(2.1, 0.0), GVar(1.0, 1.0))
julia> NiLang.AD.check_grad(PlusEq(besselj0), (1.0, 2.1); iloss=1)
true
julia> using BenchmarkTools
julia> @benchmark PlusEq(besselj0)($yg, $xg)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 451.523 ns (0.00% GC)
median time: 459.431 ns (0.00% GC)
mean time: 477.419 ns (0.00% GC)
maximum time: 857.036 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 197
Good!
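Why does running the inverse instruction on GVar values produce the gradient? The following pure-Julia sketch shows the idea. It is not NiLang's implementation. MiniGVar, plus_eq, minus_eq and mini_grad are hypothetical stand-ins for GVar, PlusEq and primitive_grad. The sketch also uses sin in place of besselj0, so no extra package is needed.

```julia
# MiniGVar is a hypothetical stand-in for NiLang's GVar: a value x and its adjoint g.
struct MiniGVar
    x::Float64
    g::Float64
end

# Hypothetical derivative table playing the role of primitive_grad.
mini_grad(::typeof(sin), x::Real) = cos(x)

# On plain numbers: y += f(x) and its inverse y -= f(x); x passes through unchanged.
plus_eq(f, y::Real, x::Real) = (y + f(x), x)
minus_eq(f, y::Real, x::Real) = (y - f(x), x)

# On gradient-carrying values the inverse instruction is the backward pass:
# it uncomputes y and accumulates x.g += y.g * f'(x).
function minus_eq(f, y::MiniGVar, x::MiniGVar)
    ynew = MiniGVar(y.x - f(x.x), y.g)
    xnew = MiniGVar(x.x, x.g + y.g * mini_grad(f, x.x))
    return ynew, xnew
end

# Applying y += f(x) to gradient-carrying values undoes that backward step.
function plus_eq(f, y::MiniGVar, x::MiniGVar)
    ynew = MiniGVar(y.x + f(x.x), y.g)
    xnew = MiniGVar(x.x, x.g - y.g * mini_grad(f, x.x))
    return ynew, xnew
end

# Forward pass on plain numbers.
y, x = plus_eq(sin, 1.0, 2.1)

# Backward pass: seed the adjoint of y with 1.0 and run the inverse instruction.
yg, xg = minus_eq(sin, MiniGVar(y, 1.0), MiniGVar(x, 0.0))
println(yg.x, " ", xg.g)   # y is back to (about) 1.0; xg.g holds cos(2.1)

# Running += again resets the adjoint of x to zero.
yg2, xg2 = plus_eq(sin, yg, xg)
```

This mirrors the REPL session above. There, yg -= besselj0(xg) wrote the derivative of besselj0 at 2.1 into the gradient field of xg, and yg += besselj0(xg) reset it to zero.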
Reversible multi-in multi-out functions
It is easy to do: define two ordinary Julia functions that are inverses of each other, and use the macro @dual to tell the compiler that they are reversible to each other.
For example, here is the pair of dual functions ROT (2D rotation) and IROT (inverse rotation) that is already defined in NiLang:
"""
ROT(a!, b!, θ) -> a!', b!', θ
"""
@inline function ROT(i::Real, j::Real, θ::Real)
a, b = rot(i, j, θ)
a, b, θ
end
"""
IROT(a!, b!, θ) -> ROT(a!, b!, -θ)
"""
@inline function IROT(i::Real, j::Real, θ::Real)
i, j, _ = ROT(i, j, -θ)
i, j, θ
end
@dual ROT IROT
One can easily check the reversibility by typing
julia> check_inv(ROT, (1.0, 2.0, 3.0))
true
For self-reversible functions, one can declare the reversibility like this:
"""
SWAP(a!, b!) -> b!, a!
"""
@inline function SWAP(a!::Real, b!::Real)
b!, a!
end
@selfdual SWAP
To bind gradients for a multi-in, multi-out function such as ROT, the general approach is to bind the backward rule on its inverse!
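@i @inline function IROT(a!::GVar, b!::GVar, θ::GVar)
# uncompute the values: rotate (a!.x, b!.x) back by θ.x
IROT(a!.x, b!.x, θ.x)
# temporarily set θ.x to -θ.x - π/2
NEG(θ.x)
θ.x -= π/2
# rotate the adjoints by -θ - π/2 (transpose of the derivative of the rotation)
ROT(a!.g, b!.g, θ.x)
# accumulate the adjoint of θ
θ.g += a!.x * a!.g
θ.g += b!.x * b!.g
# restore θ.x
θ.x += π/2
NEG(θ.x)
# rotate the adjoints by π/2 more, so their net rotation is -θ
ROT(a!.g, b!.g, π/2)
end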
@i @inline function IROT(a!::GVar, b!::GVar, θ::Real)
IROT(a!.x, b!.x, θ)
NEG(θ)
θ -= π/2
ROT(a!.g, b!.g, θ)
θ += π/2
NEG(θ)
ROT(a!.g, b!.g, π/2)
end
@nograd IROT(a!::Real, b!::Real, θ::GVar)
When this inverse function is called, the backward rules are applied automatically.
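The IROT backward rule can be cross-checked outside NiLang. The pure-Julia sketch below replays the same steps on plain numbers and compares the resulting adjoints with central finite differences. Several pieces are assumptions made for illustration: the rotation convention in rot (taken to match NiLang's rot), the helper irot_backward, and the test loss. The sketch does not call NiLang.

```julia
# Assumed rotation convention: (i, j) rotated counter-clockwise by θ.
rot(i, j, θ) = (cos(θ) * i - sin(θ) * j, sin(θ) * i + cos(θ) * j)

# Hypothetical helper replaying IROT(a!::GVar, b!::GVar, θ::GVar) on plain numbers.
# Inputs: outputs (a, b) of a rotation by θ, their adjoints (ga, gb), and θ's adjoint gθ.
function irot_backward(a, b, θ, ga, gb, gθ)
    a, b = rot(a, b, -θ)               # uncompute the values
    ga, gb = rot(ga, gb, -θ - π / 2)   # rotate the adjoints by -θ - π/2
    gθ += a * ga + b * gb              # accumulate the adjoint of θ
    ga, gb = rot(ga, gb, π / 2)        # net rotation of the adjoints: -θ
    return a, b, θ, ga, gb, gθ
end

# Test loss: a weighted sum of the two rotated outputs.
const WA = 0.3
const WB = -1.7
loss(x, y, θ) = (r = rot(x, y, θ); WA * r[1] + WB * r[2])

# Central finite difference of f with respect to its k-th argument.
function central_diff(f, args, k; h = 1e-6)
    up = collect(args); up[k] += h
    dn = collect(args); dn[k] -= h
    return (f(up...) - f(dn...)) / (2h)
end

x, y, θ = 1.0, 2.0, 0.7
a, b = rot(x, y, θ)   # forward pass, as ROT would compute it
# backward pass: the adjoints of the outputs are the loss weights WA and WB
x2, y2, _, gx, gy, gθ = irot_backward(a, b, θ, WA, WB, 0.0)

for (k, g) in enumerate((gx, gy, gθ))
    println(k, ": rule = ", g, ", finite difference = ", central_diff(loss, (x, y, θ), k))
end
```

The rotation is linear in (a, b), so the adjoints of a and b are simply rotated by -θ. The extra π/2 shift turns the rotation into its derivative with respect to θ, which is exactly what the two θ.g += lines in the rule need.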
Good! This method can also be extended to linear algebra functions; however, the memory allocation overhead is high because each element needs to be wrapped in a GVar.