How to port NiLang to Zygote

In this demo we'll show how to plug NiLang's gradient implementation into Zygote to speed up Zygote's gradients. A similar demo for ChainRules can be found in How to port NiLang to ChainRules.

using NiLang, NiLang.AD, Zygote

Let's start with a native Julia implementation of the norm2 function, which returns the sum of squares of the array elements.

function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)
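For reference, norm2 computes the squared Euclidean norm, so its gradient has a simple closed form. Every gradient computed in this demo, whether by Zygote or by NiLang, should agree with it:

```latex
\mathrm{norm2}(x) = \sum_{i=1}^{n} x_i^2,
\qquad
\frac{\partial\,\mathrm{norm2}}{\partial x_i} = 2x_i,
\qquad
\nabla\,\mathrm{norm2}(x) = 2x .
```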

Zygote generates the correct gradient function, but it is much slower than the primal function norm2:

using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 346 samples with 1 evaluation.
 Range (min … max):  1.813 ms … 11.029 ms  ┊ GC (min … max):  0.00% … 48.18%
 Time  (median):     2.200 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   2.886 ms ±  1.782 ms  ┊ GC (mean ± σ):  20.43% ± 21.55%

  ▂▃▇█▆▁                                                      
  ██████▁▄▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▁▄▁▆▁▁▄▁▁▁▁▁▁▁▁▁▄▅▄▅▄▇▅▅▇▇▇▆▅▆▄▄▄▁▅ ▆
  1.81 ms      Histogram: log(frequency) by time      8.3 ms <

 Memory estimate: 8.36 MiB, allocs estimate: 19059.

For comparison, the primal function norm2 itself takes:

@benchmark norm2($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.080 μs …   4.700 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.450 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.426 μs ± 229.754 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                █                                              
  ▄▁▁▁▁▁▇▆▁▁▁▂▁▁█▂▁▁█▂▂▂▂▂▂▁▁▁▁▁▁▂▂▂▂▁▂▂▁▂▂▁▂▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
  1.08 μs         Histogram: frequency by time        2.62 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

Next, we write a reversible implementation of the same function with NiLang's @i macro:

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
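A quick way to see what "reversible" means here: running r_norm2 forward accumulates the squares into out, and running its inverse ~r_norm2 subtracts them again, restoring out to (approximately, up to floating-point rounding) its starting value. The sketch below assumes, as in NiLang's design, that an @i function returns all of its arguments as a tuple; the small input size and the tolerance are arbitrary choices for illustration.

```julia
using NiLang

# Reversible sum of squares, identical to r_norm2 above
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(10)

# Forward run: returns the updated arguments (out, x)
out, x1 = r_norm2(0.0, x)
@assert out ≈ sum(abs2, x)

# Inverse run: each `out += x[i]^2` becomes `out -= x[i]^2`, in reverse order,
# so out returns to ≈ 0 and x is left untouched
out0, x2 = (~r_norm2)(out, x1)
@assert abs(out0) < 1e-10
@assert x2 == x
```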

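The gradient generated by NiLang is much faster than Zygote's (median about 46 μs versus about 2.2 ms), although still slower than the primal function. It is computed by running the inverse program ~r_norm2 on GVar-wrapped arguments, with the output's gradient seeded to 1.0: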

@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  34.202 μs … 85.807 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     46.203 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   46.356 μs ±  5.330 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▄▂         ▆▅         ▅█▅     ▇▆▁ ▁                         ▂
  ██▁▁▁▁▁▃▅▃▄██▇▅▄▁▁▃▄▅▄████▇▇▅▇██████▆▇▆▇▇██▇▇▇▇█▇███▇▇▆▇▇▇▆ █
  34.2 μs      Histogram: log(frequency) by time        64 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
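To make the benchmarked call concrete, here is a minimal sketch of extracting the gradient from it. GVar(y, 1.0) pairs the primal output with a gradient seed of 1.0, GVar(x) wraps each input with a zero gradient, and running the inverse program propagates the seed into the gradient fields of x. The names y, gx and g are illustrative; the result is checked against the closed form 2x.

```julia
using NiLang, NiLang.AD

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(5)
y = sum(abs2, x)   # primal output, computed with ordinary Julia

# Seed the output gradient with 1.0; inputs start with zero gradient.
# The inverse run back-propagates the seed into the inputs.
_, gx = (~r_norm2)(GVar(y, 1.0), GVar(x))

g = grad.(gx)      # read the gradient field of every GVar
@assert g ≈ 2 .* x
```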

To get NiLang's speed inside Zygote, bind the NiLang gradient to norm2 as a custom adjoint rule:

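Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)  # primal output
    # Pullback: seed the output gradient with δy, run the inverse program,
    # and read the gradients of x from the returned GVar array
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end
@assert norm2'(x) ≈ original_grad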
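Because the pullback passes δy into the GVar seed instead of a hard-coded 1.0, the custom rule should compose with the rest of a Zygote computation. The sketch below, which repeats the definitions so it runs on its own, scales norm2 by an arbitrary factor of 3.0: Zygote then calls the pullback with δy = 3.0, and the expected gradient is 6x.

```julia
using NiLang, NiLang.AD, Zygote

function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end

Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end

x = randn(100)
# The outer factor 3.0 reaches the NiLang pullback as δy = 3.0
g = Zygote.gradient(x -> 3.0 * norm2(x), x)[1]
@assert g ≈ 6 .* x
```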

The gradient is now much faster (median about 51 μs, versus about 2.2 ms with Zygote alone):

@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  37.403 μs …  4.801 ms  ┊ GC (min … max): 0.00% … 96.85%
 Time  (median):     51.204 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   52.626 μs ± 67.351 μs  ┊ GC (mean ± σ):  1.76% ±  1.37%

  ▂▅▅▃▁     ▃▆▇▅▃     ▁▆██▇▆▃▃▆█▇▆▄▂▁▁  ▁▁▂▂▃▃▂▂▁▂▁▂▁▁▁▁      ▃
  █████▆▅▃▁▆███████▇▇▆███████████████████████████████████▇██▇ █
  37.4 μs      Histogram: log(frequency) by time      71.1 μs <

 Memory estimate: 23.69 KiB, allocs estimate: 2.
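The memory estimate is also far smaller than Zygote's 8.36 MiB. One plausible accounting for the two allocations, which we have not confirmed with a profiler, is the GVar array built by GVar(x) (two Float64 fields per element) and the plain Float64 gradient array returned by grad. For 1000 elements that is 24 000 bytes, about 23.4 KiB of payload, close to the reported 23.69 KiB:

```julia
using NiLang, NiLang.AD

x = randn(1000)

gx = GVar(x)                      # assumed allocation 1: (value, gradient) pairs
@assert sizeof(eltype(gx)) == 16  # two Float64 fields per GVar
@assert sizeof(gx) == 16_000

g = grad.(gx)                     # assumed allocation 2: Float64 gradient array
@assert sizeof(g) == 8_000

# Payload in KiB, excluding array headers
println((sizeof(gx) + sizeof(g)) / 1024)
```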

This page was generated using Literate.jl.