How to port NiLang to ChainRules
In How to port NiLang to Zygote we showed how to insert a NiLang-based gradient as Zygote's pullback/adjoint. Since ChainRules is now the core of many AD packages, including Zygote, extending ChainRules.rrule with NiLang does the same job, except that it affects every ChainRules-based AD package rather than Zygote alone.
We'll use the same example as in How to port NiLang to Zygote, so you may need to restart Julia to get a fresh environment.
using NiLang, NiLang.AD, Zygote, ChainRules
Let's start from the native Julia implementation of the norm2 function.
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)
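As a reference point for the gradients computed on this page, note that the derivative of the sum of x[i]^2 with respect to x[i] is 2x[i], so any correct AD tool should return 2x. The sketch below checks this with central finite differences in plain Julia. The helper `fd_gradient` is hypothetical (it is not part of NiLang, Zygote, or ChainRules) and is only meant as a sanity check, not as a replacement for AD.

```julia
# Same native implementation as above, repeated so the snippet runs on its own.
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end

# Hypothetical helper: central finite-difference gradient of a scalar function f.
function fd_gradient(f, x::AbstractVector{<:AbstractFloat}; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h   # perturb the i-th entry upwards
        xm = copy(x); xm[i] -= h   # perturb the i-th entry downwards
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(10)
g_fd = fd_gradient(norm2, x)

# The analytic gradient of norm2 is 2x.
println(isapprox(g_fd, 2 .* x; atol=1e-6))
```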
Zygote generates the correct dual function, i.e., the gradient, but it is much slower than the primal function norm2:
using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 345 samples with 1 evaluation.
Range (min … max): 1.799 ms … 11.740 ms ┊ GC (min … max): 0.00% … 44.66%
Time (median): 2.193 ms ┊ GC (median): 0.00%
Time (mean ± σ): 2.895 ms ± 1.997 ms ┊ GC (mean ± σ): 20.30% ± 20.55%
▁▅█▆
████▁▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▅▆▇▇▇▆▅▄▄▁▁▁▁▁▁▁▁▁▁▆▁▁▁▄▅▄▆ ▆
1.8 ms Histogram: log(frequency) by time 10.5 ms <
Memory estimate: 8.36 MiB, allocs estimate: 19059.
For comparison, here is the benchmark of the primal function:
@benchmark norm2($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
Range (min … max): 1.090 μs … 4.530 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.560 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.531 μs ± 203.644 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▃ █
▂▁▁▁▁▁▄▁▁▁▁▂▁█▂▁▁█▂▂▂▁▁▂▂▂▁▁▁▁▂▂▂▂▂▁▂▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
1.09 μs Histogram: frequency by time 2.67 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Then we have the reversible implementation
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
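To make the role of the reversible function concrete, here is a minimal sketch of the three ways it is used on this page: running it forward, running its inverse `~r_norm2` to uncompute the result, and running the inverse on `GVar`s to obtain the gradient. It only relies on calls that appear elsewhere on this page (`GVar`, `grad`, `~`); the small input size and the tolerance in the check are illustrative choices.

```julia
using NiLang, NiLang.AD

# Same reversible implementation as above, repeated so the snippet runs on its own.
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(5)

# Forward pass: a reversible function returns all of its arguments,
# so the accumulated result is the first element of the returned tuple.
out, _ = r_norm2(0.0, x)

# Inverse pass: subtracts the same terms again, bringing `out` back to (approximately) zero.
out0, _ = (~r_norm2)(out, x)

# Gradient: run the inverse on GVars, seeding the gradient of `out` with 1.0,
# then read the gradient field of the second returned argument.
_, gx = (~r_norm2)(GVar(out, 1.0), GVar(x))
g = grad(gx)

println(g ≈ 2 .* x)
```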
The gradient generated by NiLang is much faster than Zygote's (a median of about 46 μs versus 2.2 ms, with no allocations), which brings it far closer to the cost of the forward program:
@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 33.603 μs … 91.407 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 46.403 μs ┊ GC (median): 0.00%
Time (mean ± σ): 46.368 μs ± 5.655 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▁▄▄ ▂▅▇▆▁ ▃▆██▃▂ ▂▅▇▇▃▂▁ ▁▁ ▃
████▃▁▁▁▃▄█████▄▅▁▁▁▅███████████████▇▆██████████████▇▆▆█▇█▇ █
33.6 μs Histogram: log(frequency) by time 64.8 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Once we define a custom rrule using NiLang's gradient implementation, Zygote automatically gets faster, because it internally uses the available ChainRules rule set. We need to create a new symbol here, because otherwise Zygote would keep using the slow implementation it generated earlier.
norm2_faster(x) = norm2(x)
function ChainRules.rrule(::typeof(norm2_faster), x::AbstractArray{T}) where T
    out = norm2_faster(x)
    function pullback(ȳ)
        # No tangent for the function itself; run r_norm2 backwards on GVars to get the gradient of x
        ChainRules.NoTangent(), grad((~r_norm2)(GVar(out, ȳ), GVar(x))[2])
    end
    out, pullback
end
@assert norm2_faster'(x) ≈ original_grad
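Because the rule lives in ChainRules rather than in Zygote, it can also be exercised directly, without any AD package in the loop. The sketch below calls `ChainRules.rrule` by hand and applies the returned pullback. To keep the snippet self-contained, `norm2_faster` is defined here as `sum(abs2, x)`, which is an assumption made for brevity and stands in for the `norm2` wrapper used on this page.

```julia
using NiLang, NiLang.AD, ChainRules

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
end

# Stand-in for the wrapper above (assumption: same value as norm2(x)).
norm2_faster(x) = sum(abs2, x)

function ChainRules.rrule(::typeof(norm2_faster), x::AbstractArray{T}) where T
    out = norm2_faster(x)
    function pullback(ȳ)
        ChainRules.NoTangent(), grad((~r_norm2)(GVar(out, ȳ), GVar(x))[2])
    end
    out, pullback
end

x = randn(5)

# Call the rule directly: it returns the primal value and the pullback.
y, pb = ChainRules.rrule(norm2_faster, x)

# Seed the pullback with 1.0 to get the gradient of the output with respect to x.
df, dx = pb(1.0)

println(dx ≈ 2 .* x)
```

Any package that consults ChainRules rules would pick up the same pullback, which is what the Zygote call above demonstrates.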
See, it is much faster:
@benchmark norm2_faster'(x) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 36.802 μs … 4.902 ms ┊ GC (min … max): 0.00% … 97.52%
Time (median): 51.204 μs ┊ GC (median): 0.00%
Time (mean ± σ): 53.093 μs ± 67.636 μs ┊ GC (mean ± σ): 1.76% ± 1.38%
▁█▄▁ ▄▄
▂▃▄▃▃▂▂▁▂▃▄██▅▃▂▂▂▂▄████▅▄███▇▄▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂ ▃
36.8 μs Histogram: frequency by time 73.6 μs <
Memory estimate: 23.69 KiB, allocs estimate: 2.
This page was generated using Literate.jl.