Unitary matrix operations without allocation
A unitary matrix has eigenvalues of unit modulus and is reversible. Unitary matrices are widely used to ease the exploding and vanishing gradient problem and the memory wall problem. One of the simplest ways to parametrize a unitary matrix is to represent it as a product of two-level unitary operations. A real unitary matrix of size $N$ can be parametrized compactly by $N(N-1)/2$ rotation operations
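$$
{\rm ROT}(a!, b!, \theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} \begin{pmatrix} a! \\ b! \end{pmatrix},
$$
where $\theta$ is the rotation angle, and a! and b! are target registers.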
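To make the parametrization concrete, the following plain-Julia sketch (it does not use NiLang) applies nearest-neighbour rotations to every column of the identity and checks that the result is orthogonal. The helper names `rot` and `rotation_product` are illustrative only and belong to no library, and a counterclockwise rotation convention is assumed.

```julia
using LinearAlgebra

# Rotate the pair (a, b) by angle θ (assumed counterclockwise convention).
rot(a, b, θ) = (a * cos(θ) - b * sin(θ), a * sin(θ) + b * cos(θ))

# Apply n(n-1)/2 nearest-neighbour rotations to each column of the identity.
# Returns the resulting matrix and the number of rotations that were used.
function rotation_product(θ::AbstractVector, n::Integer)
    Q = Matrix{Float64}(I, n, n)
    k = 0
    for j in 1:n, i in n-1:-1:j
        k += 1
        for col in 1:n
            Q[i, col], Q[i+1, col] = rot(Q[i, col], Q[i+1, col], θ[k])
        end
    end
    return Q, k
end

n = 4
θ = randn(n * (n - 1) ÷ 2)
Q, nrot = rotation_product(θ, n)
@assert nrot == n * (n - 1) ÷ 2                 # 6 rotations for n = 4
@assert Q' * Q ≈ Matrix{Float64}(I, n, n)       # Q is orthogonal (real unitary)
```

Each rotation touches only two rows, so the product stays orthogonal no matter which angles are drawn.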
using NiLang, NiLang.AD

@i function umm!(x!, θ)
    @safe @assert length(θ) == length(x!)*(length(x!)-1)/2
    # allocate the ancilla k, which indexes into θ
    k ← 0
    for j=1:length(x!)
        for i=length(x!)-1:-1:j
            k += identity(1)
            ROT(x![i], x![i+1], θ[k])
        end
    end
    # deallocate k by stating its known final value
    k → length(θ)
end
Here, the ancilla k is deallocated manually by specifying its final value, which we know because the loop body runs exactly $N(N-1)/2$ times. We define the following test functions in order to check gradients.
@i function isum(out!, x::AbstractArray)
    for i=1:length(x)
        out! += identity(x[i])
    end
end

@i function test!(out!, x!::Vector, θ::Vector)
    umm!(x!, θ)
    isum(out!, x!)
end
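Independently of NiLang, the gradient of such a loss can be cross-checked with finite differences. The sketch below is a plain-Julia reference under stated assumptions: `umm_ref`, `loss_ref` and `fd_grad_x` are hypothetical helpers that mirror the loop order of `umm!` and the sum in `test!` with an assumed counterclockwise rotation convention. They are not NiLang functions.

```julia
# Out-of-place reference for the rotation sequence (assumed convention).
function umm_ref(x::AbstractVector, θ::AbstractVector)
    y = float.(x)              # fresh copy, the input is left untouched
    n = length(y)
    k = 0
    for j in 1:n, i in n-1:-1:j
        k += 1
        s, c = sincos(θ[k])
        y[i], y[i+1] = y[i] * c - y[i+1] * s, y[i] * s + y[i+1] * c
    end
    return y
end

loss_ref(x, θ) = sum(umm_ref(x, θ))

# Central finite differences of the loss with respect to x.
function fd_grad_x(x, θ; h = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (loss_ref(xp, θ) - loss_ref(xm, θ)) / (2h)
    end
    return g
end

x, θ = randn(4), randn(6)
g_fd = fd_grad_x(x, θ)

# The map x -> umm_ref(x, θ) is linear, so the exact gradient of the sum
# can be read off by feeding in the unit vectors one at a time.
g_exact = [loss_ref([i == j ? 1.0 : 0.0 for j in 1:4], θ) for i in 1:4]
@assert isapprox(g_fd, g_exact; atol = 1e-6)
```

If the rotation convention assumed here matches that of `ROT`, these numbers should agree with the gradient fields that `Grad(test!)` attaches to `x` for the same inputs.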
Let's print the program output:
out, x, θ = 0.0, randn(4), randn(6);
@instr Grad(test!)(Val(1), out, x, θ)
x
4-element Array{NiLang.AD.GVar{Float64,Float64},1}:
GVar(1.1985625454998947, -0.9367857737423818)
GVar(-0.780748216029033, -1.5539960175638785)
GVar(-0.8355437769889602, 0.41668939872760785)
GVar(0.0586892744993282, 0.7306837458829359)
We can erase the gradient fields by uncomputing the gradient function. If you want, you can differentiate the program twice to obtain Hessians. However, we suggest using forward-mode differentiation over our NiLang program instead, since this is more efficient.
@instr (~Grad(test!))(Val(1), out, x, θ)
x
4-element Array{Float64,1}:
1.198562545499895
-0.780748216029033
-0.8355437769889604
0.058689274499328314
In the above testing code, Grad(test!) attaches a gradient field to each element of x. ~Grad(test!) is the inverse program that erases the gradient fields. Notably, this reversible implementation costs zero memory allocation, although it changes the target variables in place.
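The idea behind uncomputing can be illustrated without NiLang: every rotation is invertible, so running the same rotations in reverse order with negated angles restores the input in place. The following is a minimal plain-Julia sketch, assuming a counterclockwise rotation convention. `rot!`, `rotate_all!` and `unrotate_all!` are hypothetical names, not NiLang functions.

```julia
# In-place rotation of entries i and i+1 of x by angle θ.
function rot!(x, i, θ)
    s, c = sincos(θ)
    x[i], x[i+1] = x[i] * c - x[i+1] * s, x[i] * s + x[i+1] * c
    return x
end

# Forward pass: same loop order as umm!.
function rotate_all!(x, θ)
    n = length(x)
    k = 0
    for j in 1:n, i in n-1:-1:j
        k += 1
        rot!(x, i, θ[k])
    end
    return x
end

# Inverse pass: reversed loop order, negated angles, counter running down.
function unrotate_all!(x, θ)
    n = length(x)
    k = length(θ)
    for j in n:-1:1, i in j:n-1
        rot!(x, i, -θ[k])
        k -= 1
    end
    return x
end

x0 = randn(4)
θ = randn(6)
x = copy(x0)
rotate_all!(x, θ)
unrotate_all!(x, θ)
@assert x ≈ x0    # the input is recovered up to floating-point rounding
```

The recovery is exact only up to rounding, which is also visible in the last digits of x in the printed outputs before and after uncomputing.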
This page was generated using Literate.jl.