How to port NiLang to ChainRules

In How to port NiLang to Zygote we showed how to insert a NiLang-based gradient as Zygote's pullback/adjoint. ChainRules is now the core of many AD packages, including Zygote. Extending ChainRules.rrule with NiLang therefore does the same job, except that it affects all ChainRules-based AD packages rather than just Zygote.
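For reference, here is the ChainRules contract in brief. rrule(f, x) returns the primal result together with a pullback. The pullback maps the output cotangent ȳ to a tuple holding one tangent for f itself (here NoTangent()) and one tangent per argument. The sketch below uses a hypothetical scalar function mysquare, which is not part of this tutorial. It assumes ChainRulesCore (the package ChainRules builds on) is installed.

```julia
using ChainRulesCore  # provides rrule and NoTangent

# Hypothetical example function, used only to illustrate the rrule contract.
mysquare(x::Real) = x^2

function ChainRulesCore.rrule(::typeof(mysquare), x::Real)
    y = mysquare(x)
    # The pullback maps the output cotangent ȳ to (tangent of mysquare, tangent of x).
    mysquare_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return y, mysquare_pullback
end

y, back = rrule(mysquare, 3.0)
dself, dx = back(1.0)
```

Here y is 9.0 and dx is 6.0. dself is NoTangent() because mysquare has no fields to differentiate.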

We use the same example as in How to port NiLang to Zygote, so you may need to restart Julia to get a fresh environment.

using NiLang, NiLang.AD, Zygote, ChainRules

Let's start from a native Julia implementation of the norm2 function.

# Sum of squares of x, i.e. the squared 2-norm
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)
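Since norm2 returns the sum of squares of x (the squared 2-norm, despite its name), its gradient has a closed form. A pullback only has to scale that gradient by the incoming cotangent ȳ:

$$
y = \mathrm{norm2}(x) = \sum_{i=1}^{n} x_i^2,
\qquad
\frac{\partial y}{\partial x_i} = 2x_i,
\qquad
\bar{x}_i = 2x_i\,\bar{y}.
$$

With ȳ = 1, every gradient computed below should therefore equal 2 .* x.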

Zygote can compute the correct gradient of norm2, but it is much slower than the primal function:

using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 345 samples with 1 evaluation.
 Range (min … max):  1.799 ms … 11.740 ms  ┊ GC (min … max):  0.00% … 44.66%
 Time  (median):     2.193 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.895 ms ±  1.997 ms  ┊ GC (mean ± σ):  20.30% ± 20.55%

  ▁▅█▆                                                        
  ████▁▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▅▆▇▇▇▆▅▄▄▁▁▁▁▁▁▁▁▁▁▆▁▁▁▄▅▄▆ ▆
  1.8 ms       Histogram: log(frequency) by time     10.5 ms <

 Memory estimate: 8.36 MiB, allocs estimate: 19059.

For comparison, here is the primal function:

@benchmark norm2($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.090 μs …   4.530 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.560 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.531 μs ± 203.644 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

               ▃   █                                           
  ▂▁▁▁▁▁▄▁▁▁▁▂▁█▂▁▁█▂▂▂▁▁▂▂▂▁▁▁▁▂▂▂▂▂▁▂▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
  1.09 μs         Histogram: frequency by time        2.67 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
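The f'(x) syntax used above is Zygote's shorthand for Zygote.gradient(f, x)[1] on single-argument functions. The sketch below checks both forms against the closed-form gradient 2 .* x. So that it runs on its own, it uses sumsq_loop, a hypothetical copy of norm2 under another name.

```julia
using Zygote

# Hypothetical copy of norm2 under another name, so this block is self-contained.
function sumsq_loop(x::AbstractArray{T}) where T
    out = zero(T)
    for i in eachindex(x)
        @inbounds out += x[i]^2
    end
    return out
end

x = randn(5)
g_adjoint = sumsq_loop'(x)                      # Zygote's adjoint shorthand
g_explicit = Zygote.gradient(sumsq_loop, x)[1]  # the explicit form it stands for
```

Both results should match each other and 2 .* x. Only the speed differs from the NiLang version.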

Next, we write the reversible implementation with NiLang's @i macro:

# Reversible counterpart of norm2: accumulates the sum of squares of x into out
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
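Because r_norm2 is written with @i, NiLang also generates its exact inverse ~r_norm2, and calling an @i function returns all of its arguments. Below is a minimal round-trip sketch. It uses r_sumsq, a hypothetical renamed copy of r_norm2, so that it runs on its own. The forward call accumulates the sum of squares into out, and the inverse call subtracts it again.

```julia
using NiLang

# Hypothetical renamed copy of r_norm2, so this block is self-contained.
@i function r_sumsq(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(10)
out, x1 = r_sumsq(0.0, x)       # forward: out accumulates the sum of squares of x
out0, x2 = (~r_sumsq)(out, x1)  # inverse: each += is undone by the matching -=
```

Up to floating-point rounding, out0 returns to 0.0 and x is left unchanged. This uncomputation is what NiLang's reverse-mode AD builds on.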

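The gradient generated by NiLang is much faster. Running the inverse program ~r_norm2 on GVar-wrapped arguments, with the output's gradient seeded to 1.0, back-propagates the gradient into x. The median time (about 46 μs) is roughly 30 times that of the primal function, versus more than 1000 times for Zygote, and no memory is allocated.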

@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  33.603 μs … 91.407 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     46.403 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   46.368 μs ±  5.655 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁▄▄       ▂▅▇▆▁      ▃▆██▃▂ ▂▅▇▇▃▂▁          ▁▁             ▃
  ████▃▁▁▁▃▄█████▄▅▁▁▁▅███████████████▇▆██████████████▇▆▆█▇█▇ █
  33.6 μs      Histogram: log(frequency) by time      64.8 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
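The same mechanism can be checked against the closed-form gradient 2 .* x on a small input. This sketch again uses r_sumsq, a hypothetical renamed copy of r_norm2. Element [2] of the tuple returned by the inverse call holds x wrapped in GVars, and grad extracts the accumulated gradients.

```julia
using NiLang, NiLang.AD

# Hypothetical renamed copy of r_norm2, so this block is self-contained.
@i function r_sumsq(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(10)
out = sum(abs2, x)  # primal output, computed with ordinary Julia
# Seed the output's gradient with 1.0 and run the inverse program on GVar values;
# uncomputing out propagates the gradient back into every x[i].
result = (~r_sumsq)(GVar(out, 1.0), GVar(x))
gx = grad(result[2])
```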

By defining a custom rrule that uses NiLang's gradient implementation, we speed up Zygote automatically, because Zygote internally uses the available ChainRules rules. We need to define a function under a new name here. Otherwise Zygote would keep using the slow implementation it has already generated for norm2.

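# New name, so Zygote has no previously generated (slow) gradient for it
norm2_faster(x) = norm2(x)
function ChainRules.rrule(::typeof(norm2_faster), x::AbstractArray{T}) where T
    out = norm2_faster(x)  # primal result
    function pullback(ȳ)
        # Run the inverse program on GVar values, seeding the output gradient with ȳ;
        # element [2] of the returned tuple is x carrying the accumulated gradients.
        ChainRules.NoTangent(), grad((~r_norm2)(GVar(out, ȳ), GVar(x))[2])
    end
    out, pullback
end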
@assert norm2_faster'(x) ≈ original_grad

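As expected, it is much faster. The median time drops from about 2.2 ms to about 51 μs, and the allocation count drops from 19059 to 2.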

@benchmark norm2_faster'($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  36.802 μs …  4.902 ms  ┊ GC (min … max): 0.00% … 97.52%
 Time  (median):     51.204 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   53.093 μs ± 67.636 μs  ┊ GC (mean ± σ):  1.76% ±  1.38%

                      ▁█▄▁   ▄▄                                
  ▂▃▄▃▃▂▂▁▂▃▄██▅▃▂▂▂▂▄████▅▄███▇▄▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂ ▃
  36.8 μs         Histogram: frequency by time        73.6 μs <

 Memory estimate: 23.69 KiB, allocs estimate: 2.
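The rule is attached to ChainRules.rrule rather than to Zygote, so any ChainRules-aware code can call it directly. Below is a self-contained sketch of the same pattern. It uses the hypothetical names r_sumsq and sumsq_fast (not from the tutorial), calls the rule by hand, and then calls it through Zygote. It assumes NiLang, ChainRulesCore, and Zygote are installed.

```julia
using NiLang, NiLang.AD, Zygote
import ChainRulesCore  # qualified use avoids export clashes with NiLang

# Hypothetical renamed copy of r_norm2.
@i function r_sumsq(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
sumsq_fast(x) = sum(abs2, x)  # hypothetical primal function

function ChainRulesCore.rrule(::typeof(sumsq_fast), x::AbstractArray{T}) where T
    out = sumsq_fast(x)
    function sumsq_fast_pullback(ȳ)
        # Same NiLang-backed pullback as above, with the output gradient seeded by ȳ.
        gx = grad((~r_sumsq)(GVar(out, ȳ), GVar(x))[2])
        return ChainRulesCore.NoTangent(), gx
    end
    return out, sumsq_fast_pullback
end

x = randn(10)
y, back = ChainRulesCore.rrule(sumsq_fast, x)  # call the rule directly, no Zygote involved
_, dx = back(1.0)
dx_zygote = Zygote.gradient(sumsq_fast, x)[1]  # Zygote picks up the same rule
```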

This page was generated using Literate.jl.