How to port NiLang to Zygote

using NiLang, NiLang.AD, Zygote

Let's start from a native Julia implementation of the norm2 function, which returns the sum of the squared elements of an array.

function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end
norm2 (generic function with 1 method)

Zygote generates the correct gradient, but computing it is much slower than running the original program: the minimum time is about 3.4 ms, and the run allocates about 16 MiB.

using BenchmarkTools
x = randn(1000);
original_grad = norm2'(x)
@benchmark norm2'($x) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  16.04 MiB
  allocs estimate:  21093
  --------------
  minimum time:     3.392 ms (0.00% GC)
  median time:      3.788 ms (0.00% GC)
  mean time:        4.873 ms (23.05% GC)
  maximum time:     9.339 ms (41.15% GC)
  --------------
  samples:          205
  evals/sample:     1
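
As a sanity check on the claim that the gradient is correct, the result can be compared against the analytic gradient of a sum of squares, which is 2x. The snippet below is a minimal sketch that repeats the norm2 definition so that it runs on its own; the name analytic_grad is introduced here for illustration, and Zygote.gradient(norm2, x)[1] is assumed to give the same result as norm2'(x).

```julia
using Zygote

# Same native implementation as above, repeated so the snippet is self-contained.
function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end

x = randn(1000)

# d/dx_i of sum_j x_j^2 is 2 * x_i (analytic_grad is an illustrative name).
analytic_grad = 2 .* x

# Gradient computed by Zygote's source-to-source AD.
zygote_grad = Zygote.gradient(norm2, x)[1]

@assert zygote_grad ≈ analytic_grad
```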

For comparison, the original program itself takes only about 1.15 μs and does not allocate:

@benchmark norm2($x) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.152 μs (0.00% GC)
  median time:      1.153 μs (0.00% GC)
  mean time:        1.163 μs (0.00% GC)
  maximum time:     2.759 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

Next, we write a reversible implementation with NiLang's @i macro. The accumulator out is now an argument rather than a local variable:

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i=1:length(x)
        @inbounds out += x[i]^2
    end
end
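
To illustrate what "reversible" means here, the sketch below runs r_norm2 forward and then applies its inverse ~r_norm2 to recover the initial state. It assumes the usual NiLang convention that an @i function returns its (updated) arguments as a tuple; the variable names out_fwd, x_fwd, out_back, and x_back are introduced only for this example.

```julia
using NiLang

# Same reversible implementation as above, repeated so the snippet is self-contained.
@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
end

x = randn(1000)

# Forward run: start the accumulator at zero and collect the sum of squares.
out_fwd, x_fwd = r_norm2(0.0, x)

# Backward run: the inverse program subtracts the same terms in reverse order.
out_back, x_back = (~r_norm2)(out_fwd, x_fwd)

@assert out_fwd ≈ sum(abs2, x)
# The accumulator returns to zero only up to floating-point rounding.
@assert isapprox(out_back, 0.0; atol=1e-8)
@assert x_back == x
```

The tolerance in the second assertion is there because floating-point addition followed by subtraction is not guaranteed to restore the starting value exactly.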

The gradient generated by NiLang is much faster: at about 1.18 μs, it is comparable to the forward program.

@benchmark (~r_norm2)(GVar($(norm2(x)), 1.0), $(GVar(x))) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  32 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.182 μs (0.00% GC)
  median time:      1.184 μs (0.00% GC)
  mean time:        1.231 μs (0.00% GC)
  maximum time:     4.909 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

To enjoy the speed of NiLang in Zygote, bind the reversible program as the adjoint rule of norm2:

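Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)  # forward pass uses the native implementation
    # Pullback: seed the output gradient with δy, run r_norm2 backward,
    # and read the gradient off the second returned value (x).
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end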
@assert norm2'(x) ≈ original_grad
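
The bound rule can also be exercised with an incoming gradient other than 1, which is what happens when norm2 sits in the middle of a larger Zygote computation. The following is a rough sketch, assuming Zygote.pullback returns the value together with a pullback closure; the seed value 3.0 and the names y, back, and δx are chosen only for illustration. All definitions are repeated so that the snippet runs on its own.

```julia
using NiLang, NiLang.AD, Zygote

function norm2(x::AbstractArray{T}) where T
    out = zero(T)
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
    return out
end

@i function r_norm2(out::T, x::AbstractArray{T}) where T
    for i = 1:length(x)
        @inbounds out += x[i]^2
    end
end

# Same adjoint binding as above.
Zygote.@adjoint function norm2(x::AbstractArray{T}) where T
    out = norm2(x)
    out, δy -> (grad((~r_norm2)(GVar(out, δy), GVar(x))[2]),)
end

x = randn(10)

# y is the forward value; back maps an output gradient δy to the gradient of x.
y, back = Zygote.pullback(norm2, x)
δx = back(3.0)[1]

@assert y ≈ sum(abs2, x)
# With δy = 3, the expected gradient is 3 * 2x.
@assert δx ≈ 6 .* x
```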

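With the NiLang rule in place, the Zygote gradient is much faster: the minimum time drops from about 3.4 ms to about 3.2 μs, roughly a thousandfold speedup.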

@benchmark norm2'(x) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  23.69 KiB
  allocs estimate:  2
  --------------
  minimum time:     3.193 μs (0.00% GC)
  median time:      4.045 μs (0.00% GC)
  mean time:        5.292 μs (21.68% GC)
  maximum time:     380.125 μs (95.96% GC)
  --------------
  samples:          10000
  evals/sample:     8

This page was generated using Literate.jl.