How to port NiLang to Zygote
using NiLang, NiLang.AD, Zygote
Let's start from a native Julia implementation of the norm2 function.
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)
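For reference, the gradient that any differentiation tool should reproduce here has a simple closed form. Writing the function above as $f$ (a symbol introduced only for this note; it is the squared 2-norm, since no square root is taken):

```math
f(x) = \sum_{i=1}^{n} x_i^2,
\qquad
\frac{\partial f}{\partial x_i} = 2 x_i,
\qquad
\nabla f(x) = 2x .
```

So for any input x, a correct gradient should equal 2 .* x up to floating-point rounding.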
Zygote generates correct gradients, but it is much slower than the original program (milliseconds rather than microseconds, with about 16 MiB allocated per call):
using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial:
memory estimate: 16.04 MiB
allocs estimate: 21093
--------------
minimum time: 3.392 ms (0.00% GC)
median time: 3.788 ms (0.00% GC)
mean time: 4.873 ms (23.05% GC)
maximum time: 9.339 ms (41.15% GC)
--------------
samples: 205
evals/sample: 1
For comparison, the original program takes about a microsecond and allocates nothing:
@benchmark norm2($x) seconds=1
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 1.152 μs (0.00% GC)
median time: 1.153 μs (0.00% GC)
mean time: 1.163 μs (0.00% GC)
maximum time: 2.759 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 10
Next, we write the reversible implementation with NiLang's @i macro:
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
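To make it clearer what running this loop backward has to compute, here is a hand-written sketch in plain Julia. It is only an illustration under stated assumptions: `manual_norm2_backward` is a hypothetical helper, not NiLang's generated code, and it tracks the gradient in a separate array rather than in GVar fields.

```julia
# Hypothetical helper for illustration only (not part of NiLang).
# Running `out += x[i]^2` backward undoes the update on `out` and, at the
# same time, accumulates the gradient of x[i] from the gradient of `out`.
function manual_norm2_backward(out::T, x::AbstractVector{T}, gout::T) where T
    gx = zeros(T, length(x))          # gradients of x, initially zero
    for i in length(x):-1:1           # visit the loop iterations in reverse
        out -= x[i]^2                 # uncompute: inverse of `out += x[i]^2`
        gx[i] += 2 * x[i] * gout      # chain rule: d(x[i]^2)/dx[i] = 2x[i]
    end
    return out, gx
end

x_demo = [1.0, 2.0, 3.0]
out_demo = sum(abs2, x_demo)          # forward result
out0, gx_demo = manual_norm2_backward(out_demo, x_demo, 1.0)
```

The backward pass is a single loop with one preallocated output, so it needs no record of intermediate values: `out` is restored to its initial value while the gradients are filled in.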
The gradient generated by NiLang is much faster than Zygote's, with a run time comparable to that of the forward program:
@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial:
memory estimate: 32 bytes
allocs estimate: 1
--------------
minimum time: 1.182 μs (0.00% GC)
median time: 1.184 μs (0.00% GC)
mean time: 1.231 μs (0.00% GC)
maximum time: 4.909 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 10
To enjoy the speed of NiLang in Zygote, just bind the adjoint rule:
Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end
@assert norm2'(x) ≈ original_grad
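The assertion above compares the new rule against Zygote's own earlier result. An independent check that needs neither package is a central finite difference. The sketch below is self-contained plain Julia; `norm2_ref` (a copy of the function under a different name) and `fd_gradient` are hypothetical names introduced only for this example.

```julia
# Copy of the native implementation, renamed so this block stands alone.
function norm2_ref(x::AbstractArray{T}) where T
    out = zero(T)
    for i in 1:length(x)
        out += x[i]^2
    end
    return out
end

# Central finite-difference gradient (hypothetical helper, illustration only).
function fd_gradient(f, x::AbstractVector{Float64}; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h      # perturb coordinate i upward
        xm = copy(x); xm[i] -= h      # perturb coordinate i downward
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x_check = [0.5, -1.5, 2.0]
g_fd = fd_gradient(norm2_ref, x_check)
# The analytic gradient of a sum of squares is 2x, so g_fd should be
# close to 2 .* x_check.
```

Finite differences cost two function evaluations per input element, so this is a correctness check, not a replacement for the adjoint rule.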
With the adjoint rule bound, Zygote's gradient now runs in microseconds rather than milliseconds:
@benchmark norm2'(x) seconds=1
BenchmarkTools.Trial:
memory estimate: 23.69 KiB
allocs estimate: 2
--------------
minimum time: 3.193 μs (0.00% GC)
median time: 4.045 μs (0.00% GC)
mean time: 5.292 μs (21.68% GC)
maximum time: 380.125 μs (95.96% GC)
--------------
samples: 10000
evals/sample: 8
This page was generated using Literate.jl.