How to port NiLang to Zygote
In this demo, we show how to plug NiLang's gradient implementation into Zygote to speed up Zygote's gradients. A similar demo for ChainRules can be found in How to port NiLang to ChainRules.
using NiLang, NiLang.AD, Zygote
Let's start from the native Julia implementation of the norm2 function.
function norm2(x::AbstractArray{T}) where T
    # squared Euclidean norm: the sum of x[i]^2
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)
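Because norm2 sums the squares of the entries, its gradient is simply 2x. As a Zygote-free sanity check, the closed form can be compared against central finite differences. This is a sketch: it repeats the norm2 definition so it runs on its own, and the helpers norm2_grad and fd_grad are hypothetical names, not part of NiLang or Zygote.

```julia
# Same primal as above: the squared Euclidean norm, sum of x[i]^2
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end

# Hypothetical helper: closed-form gradient, d/dx[i] of sum_j x[j]^2 is 2x[i]
norm2_grad(x) = 2 .* x

# Hypothetical helper: central finite differences, one coordinate at a time
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(5)
isapprox(norm2_grad(x), fd_grad(norm2, x); rtol = 1e-6, atol = 1e-8)
```

Any correct automatic gradient of norm2, including the two versions benchmarked on this page, should agree with norm2_grad.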
Zygote can generate the correct adjoint function (i.e., the gradient), but it is much slower than the primal function norm2:
using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 346 samples with 1 evaluation.
Range (min … max): 1.813 ms … 11.029 ms ┊ GC (min … max): 0.00% … 48.18%
Time (median): 2.200 ms ┊ GC (median): 0.00%
Time (mean ± σ): 2.886 ms ± 1.782 ms ┊ GC (mean ± σ): 20.43% ± 21.55%
[Histogram: log(frequency) by time, 1.81 ms to 8.3 ms]
Memory estimate: 8.36 MiB, allocs estimate: 19059.
For comparison, the primal function takes:
@benchmark norm2($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
Range (min … max): 1.080 μs … 4.700 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.450 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.426 μs ± 229.754 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
[Histogram: frequency by time, 1.08 μs to 2.62 μs]
Memory estimate: 0 bytes, allocs estimate: 0.
Next, we write the reversible implementation with NiLang's @i macro:
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
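Every statement in an @i function must be reversible: out += x[i]^2 is undone by out -= x[i]^2, so running the loop backward recovers the original out (up to floating-point rounding in general). The plain-Julia sketch below mimics this idea; fwd_norm2 and inv_norm2 are hypothetical hand-written stand-ins for what r_norm2 and ~r_norm2 do, not the code NiLang actually generates.

```julia
# Hypothetical forward pass: accumulate into `out`, mirroring r_norm2
function fwd_norm2(out::T, x::AbstractVector{T}) where T
    for i = 1:length(x)
        out += x[i]^2
    end
    return out, x   # like NiLang, return all arguments
end

# Hypothetical inverse: undo each `+=` with `-=`, iterating in reverse order
function inv_norm2(out::T, x::AbstractVector{T}) where T
    for i = length(x):-1:1
        out -= x[i]^2
    end
    return out, x
end

x = [1.0, 2.0, 3.0]
y, _ = fwd_norm2(0.0, x)   # y == 14.0
z, _ = inv_norm2(y, x)     # z == 0.0, the original `out`
```

These inputs are exactly representable, so the round trip is exact here.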
The gradient generated by NiLang is much faster: about 46 μs (median) versus 2.2 ms for Zygote's own gradient above, although still roughly 30 times the cost of the primal program.
@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 34.202 μs … 85.807 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 46.203 μs ┊ GC (median): 0.00%
Time (mean ± σ): 46.356 μs ± 5.330 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
[Histogram: log(frequency) by time, 34.2 μs to 64 μs]
Memory estimate: 0 bytes, allocs estimate: 0.
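Why does running ~r_norm2 on GVars produce the gradient? A GVar pairs a value with its gradient. Reversing out += x[i]^2 subtracts x[i]^2 from out and also propagates the adjoint, x[i].g += 2 * x[i].x * out.g. The following sketch hand-writes that rule with a hypothetical MyGVar type. It illustrates the idea only and is not NiLang's actual GVar implementation.

```julia
# Hypothetical value/gradient pair, standing in for NiLang's GVar
struct MyGVar
    x::Float64   # value
    g::Float64   # accumulated gradient (adjoint)
end

# Hand-written reverse pass of `out += x[i]^2`, visiting i in reverse order:
# uncompute the value and accumulate d(out)/d(x[i]) * out.g = 2x[i] * out.g
function inv_norm2_grad(out::MyGVar, x::Vector{MyGVar})
    x = copy(x)
    for i = length(x):-1:1
        xi = x[i]
        out = MyGVar(out.x - xi.x^2, out.g)           # undo the forward update
        x[i] = MyGVar(xi.x, xi.g + 2 * xi.x * out.g)  # back-propagate the adjoint
    end
    return out, x
end

x = [1.0, 2.0, 3.0]
out0 = sum(abs2, x)   # forward result, 14.0
# Seed the output gradient with 1.0 and the input gradients with 0.0
out, gx = inv_norm2_grad(MyGVar(out0, 1.0), [MyGVar(v, 0.0) for v in x])
[v.g for v in gx]     # gradient of norm2 at x, i.e. 2x
```

Seeding the output gradient with 1.0 plays the same role as GVar(norm2(x), 1.0) in the benchmark above.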
To enjoy the speed of NiLang in Zygote, just bind the adjoint rule:
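Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)
    # Pullback: seed the output gradient with δy, run r_norm2 in reverse,
    # and extract the accumulated gradients of x (the second returned value)
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end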
@assert norm2'(x) ≈ original_grad
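The body of Zygote.@adjoint returns the primal output together with a pullback closure. The closure maps the output cotangent δy to a tuple with one gradient per argument, hence the trailing comma that makes a 1-tuple. The sketch below reproduces that contract in plain Julia, with the closed-form gradient 2x in place of the NiLang call; norm2_pullback is a hypothetical name used only for illustration.

```julia
# Hypothetical stand-in for the @adjoint body: returns (output, pullback)
function norm2_pullback(x::AbstractArray{T}) where T
    out = sum(abs2, x)
    # One entry per argument of norm2; the trailing comma makes a 1-tuple
    back = δy -> (2 .* δy .* x,)
    return out, back
end

x = [1.0, -2.0]
y, back = norm2_pullback(x)   # y == 5.0
back(1.0)                     # ([2.0, -4.0],)
back(0.5)[1]                  # [1.0, -2.0]
```

Scaling δy scales the gradient linearly, which is why the NiLang rule seeds GVar(out, δy) rather than a fixed 1.0.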
With this rule in place, Zygote's gradient is much faster:
@benchmark norm2'(x) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 37.403 μs … 4.801 ms ┊ GC (min … max): 0.00% … 96.85%
Time (median): 51.204 μs ┊ GC (median): 0.00%
Time (mean ± σ): 52.626 μs ± 67.351 μs ┊ GC (mean ± σ): 1.76% ± 1.37%
[Histogram: log(frequency) by time, 37.4 μs to 71.1 μs]
Memory estimate: 23.69 KiB, allocs estimate: 2.
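The median timings above can be compared with a quick calculation. The numbers are copied from the benchmark output on this page and will differ on other machines.

```julia
# Median times reported above, in microseconds
zygote_native = 2200.0   # norm2'(x) with Zygote's own pullback (2.200 ms)
primal        = 1.450    # norm2(x)
nilang_raw    = 46.203   # (~r_norm2)(GVar(...), GVar(x)) called directly
zygote_nilang = 51.204   # norm2'(x) with the NiLang-backed adjoint

speedup  = zygote_native / zygote_nilang   # gradient speedup from the new rule
overhead = zygote_nilang / primal          # gradient cost relative to the primal
wrapper  = zygote_nilang - nilang_raw      # extra time spent going through Zygote
```

The new rule makes the gradient roughly 43 times faster. Memory use also drops, from 8.36 MiB to 23.69 KiB per call.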
This page was generated using Literate.jl.