# How to extend

## Extend `+=`, `-=` and `⊻=` for irreversible one-out functions
It works directly:
```julia
julia> using SpecialFunctions, NiLang

julia> x, y = 2.1, 1.0
(2.1, 1.0)

julia> @instr y += besselj0(x)
2.1

julia> x, y
(2.1, 1.7492472503018073)

julia> @instr ~(y += besselj0(x))
2.1

julia> x, y
(2.1, 1.0)
```
Here, the statement

```julia
@instr y += besselj0(x)
```

is mapped to a call of the callable wrapper `PlusEq(besselj0)` (the same object benchmarked below), and its inverse `~(y += besselj0(x))` to the dual wrapper `MinusEq(besselj0)`.
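As a rough sketch of what such a wrapper computes (a hypothetical `MyPlusEq`, not NiLang's actual implementation): the output slot is incremented by `f(args...)` while all inputs pass through unchanged, so the operation can be undone by subtracting `f(args...)` again.

```julia
# Hypothetical minimal sketch of the "+=" wrapper semantics:
# the first argument accumulates f(args...), the rest are untouched,
# which makes the operation trivially invertible.
struct MyPlusEq{FT}
    f::FT
end
(pe::MyPlusEq)(out, args...) = (out + pe.f(args...), args...)

myplus = MyPlusEq(abs2)
myplus(1.0, 2.0)  # (5.0, 2.0): 1.0 + abs2(2.0), input 2.0 unchanged
```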
However, doing this does not give you correct gradients. For `y += scalar_out_function(x)`, one can bind the backward rule like this:
```julia
julia> using ChainRules, NiLang.AD

julia> besselj0_back(x) = ChainRules.rrule(besselj0, x)[2](1.0)[2]
besselj0_back (generic function with 1 method)

julia> primitive_grad(::typeof(besselj0), x::Real) = besselj0_back(x)
primitive_grad (generic function with 1 method)

julia> xg, yg = GVar(x), GVar(y, 1.0)
(GVar(2.1, 0.0), GVar(1.0, 1.0))

julia> @instr yg -= besselj0(xg)
GVar(2.1, -0.5682921357570385)

julia> xg, yg
(GVar(2.1, -0.5682921357570385), GVar(0.8333930196680097, 1.0))

julia> @instr yg += besselj0(xg)
GVar(2.1, 0.0)

julia> xg, yg
(GVar(2.1, 0.0), GVar(1.0, 1.0))

julia> NiLang.AD.check_grad(PlusEq(besselj0), (1.0, 2.1); iloss=1)
true
```
```julia
julia> using BenchmarkTools

julia> @benchmark PlusEq(besselj0)($yg, $xg)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     451.523 ns (0.00% GC)
  median time:      459.431 ns (0.00% GC)
  mean time:        477.419 ns (0.00% GC)
  maximum time:     857.036 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     197
```
Good!
## Reversible multi-in, multi-out functions
This is easy to do: define two normal Julia functions that invert each other, and use the macro `@dual` to tell the compiler they are reversible to each other. For example, the pair of dual functions `ROT` (2D rotation) and `IROT` (inverse rotation) is already defined in NiLang:
"""
ROT(a!, b!, θ) -> a!', b!', θ
"""
@inline function ROT(i::Real, j::Real, θ::Real)
a, b = rot(i, j, θ)
a, b, θ
end
"""
IROT(a!, b!, θ) -> ROT(a!, b!, -θ)
"""
@inline function IROT(i::Real, j::Real, θ::Real)
i, j, _ = ROT(i, j, -θ)
i, j, θ
end
@dual ROT IROT
One can easily check the reversibility by typing

```julia
julia> check_inv(ROT, (1.0, 2.0, 3.0))
true
```
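To see why these two functions invert each other, here is a plain-Julia sketch (using hypothetical names `rot2`/`irot2`, independent of NiLang's internal `rot`): rotating by `θ` and then by `-θ` recovers the original point.

```julia
# Hypothetical plain-Julia rotation pair, independent of NiLang:
# rotating by θ and then by -θ recovers the original coordinates.
rot2(i, j, θ)  = (i * cos(θ) - j * sin(θ), i * sin(θ) + j * cos(θ))
irot2(i, j, θ) = rot2(i, j, -θ)

a, b = rot2(1.0, 2.0, 0.5)
a0, b0 = irot2(a, b, 0.5)
(a0 ≈ 1.0, b0 ≈ 2.0)  # (true, true)
```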
For a self-reversible function, one can declare its reversibility like this:
"""
SWAP(a!, b!) -> b!, a!
"""
@inline function SWAP(a!::Real, b!::Real)
b!, a!
end
@selfdual SWAP
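Self-reversibility just means that composing the function with itself gives the identity, which a one-line plain-Julia check (with a hypothetical `mySWAP`) makes concrete:

```julia
# A self-inverse function composed with itself is the identity.
mySWAP(a, b) = (b, a)
mySWAP(mySWAP(1.0, 2.0)...) == (1.0, 2.0)  # true
```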
To bind gradients to this multi-in, multi-out function, the general approach is to bind the backward rule to its inverse:
```julia
@i @inline function IROT(a!::GVar, b!::GVar, θ::GVar)
    IROT(a!.x, b!.x, θ.x)
    NEG(θ.x)
    θ.x -= π/2
    ROT(a!.g, b!.g, θ.x)
    θ.g += a!.x * a!.g
    θ.g += b!.x * b!.g
    θ.x += π/2
    NEG(θ.x)
    ROT(a!.g, b!.g, π/2)
end

@i @inline function IROT(a!::GVar, b!::GVar, θ::Real)
    IROT(a!.x, b!.x, θ)
    NEG(θ)
    θ -= π/2
    ROT(a!.g, b!.g, θ)
    θ += π/2
    NEG(θ)
    ROT(a!.g, b!.g, π/2)
end

@nograd IROT(a!::Real, b!::Real, θ::GVar)
```
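This rule can be understood from the adjoint of the forward rotation. The backward pass of a program containing `ROT` executes `IROT` on `GVar`s, so the method above first restores the pre-`ROT` values and then accumulates the adjoints of ``y = R(\theta)\,x``, where ``R(\theta)`` is the ``2\times 2`` rotation matrix with ``\partial R/\partial\theta = R(\theta + \pi/2)``:

```math
\bar{x} = R(\theta)^\top \bar{y} = R(-\theta)\,\bar{y},
\qquad
\bar{\theta} \mathrel{+}= \bar{y}^\top \frac{\partial R(\theta)}{\partial \theta}\, x
            = \bar{y}^\top R\!\left(\theta + \frac{\pi}{2}\right) x .
```

In the code, the gradient pair is rotated by ``-\theta - \pi/2`` and later by ``\pi/2`` (net ``-\theta``), matching ``\bar{x} = R(-\theta)\,\bar{y}``; the accumulation `θ.g += a!.x * a!.g + b!.x * b!.g` evaluates ``\bar{y}^\top R(\theta + \pi/2)\, x`` at the intermediate rotation angle, after `a!.x` and `b!.x` have already been reverted to the pre-`ROT` values.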
When this inverse function is called, the backward rules are automatically applied.
Good! This method can also be extended to linear algebra functions; however, the memory allocation overhead is high, because one needs to wrap each element in a `GVar`.