Sparse matrices
Source-to-source automatic differentiation is useful for differentiating sparse matrix operations. It is a well-known problem that sparse matrix operations cannot benefit directly from the generic backward rules written for dense matrices, because those rules do not preserve the sparse structure. In the following, we show that reversible AD can differentiate the Frobenius dot product between two sparse matrices with state-of-the-art performance. Here, the Frobenius dot product is defined as \texttt{trace(A'B)}; its native (irreversible) Julia implementation is \texttt{SparseArrays.dot}.
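As a quick check of that definition (a small sketch using only the standard \texttt{SparseArrays} and \texttt{LinearAlgebra} libraries), the Frobenius dot product computed by \texttt{SparseArrays.dot} agrees with the trace formula and with the elementwise sum:

using SparseArrays, LinearAlgebra

A = sprand(5, 5, 0.3)
B = sprand(5, 5, 0.3)

# All three expressions compute the Frobenius dot product ⟨A, B⟩.
SparseArrays.dot(A, B) ≈ tr(Matrix(A)' * Matrix(B)) ≈ sum(A .* B)  # true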
The following is its reversible counterpart:
using NiLang, NiLang.AD
using SparseArrays

@i function idot(r::T, A::SparseMatrixCSC{T}, B::SparseMatrixCSC{T}) where {T}
    @routine begin
        # allocate ancilla variables; they are deallocated by the matching ~@routine
        m, n ← size(A)
        branch_keeper ← zeros(Bool, 2*m)
    end
    @safe size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
    @invcheckoff @inbounds for j = 1:n
        @routine begin
            ia1 ← A.colptr[j]
            ib1 ← B.colptr[j]
            ia2 ← A.colptr[j+1]
            ib2 ← B.colptr[j+1]
            ia ← ia1
            ib ← ib1
        end
        # two-pointer merge over the nonzeros of column j of A and B
        @inbounds for i=1:ia2-ia1+ib2-ib1-1
            ra ← A.rowval[ia]
            rb ← B.rowval[ib]
            if (ra == rb, ~)
                r += A.nzval[ia]' * B.nzval[ib]
            end
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= @const ia == ia2-1 || ra > rb
            ra → A.rowval[ia]
            rb → B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
        # run the merge loop in reverse: rewind the pointers and clear branch_keeper
        # so that the ancillas can be deallocated
        ~@inbounds for i=1:ia2-ia1+ib2-ib1-1
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= @const ia == ia2-1 || A.rowval[ia] > B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
        ~@routine
    end
    ~@routine
end
Here, the key point is the \texttt{branch_keeper} vector, which caches the branch decisions of the two-pointer merge (whether \texttt{ia} or \texttt{ib} was advanced) so that the loop can later be uncomputed and the ancilla variables deallocated.
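Before benchmarking, it is worth checking that the reversible program agrees with the native one. The snippet below is a minimal usage sketch; it relies on NiLang's convention that an \texttt{@i} function returns its (possibly updated) arguments as a tuple:

using NiLang, SparseArrays

a = sprand(100, 100, 0.05);
b = sprand(100, 100, 0.05);

# the reversible forward pass returns the updated (r, A, B)
out, _, _ = idot(0.0, a, b)
out ≈ SparseArrays.dot(a, b)  # true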
The time taken by the native implementation is:
using BenchmarkTools
a = sprand(1000, 1000, 0.01);
b = sprand(1000, 1000, 0.01);
@benchmark SparseArrays.dot($a, $b)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  67.705 μs … 161.613 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     92.207 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   90.284 μs ±  10.666 μs  ┊ GC (mean ± σ):  0.00% ±  0.00%
 Memory estimate: 0 bytes, allocs estimate: 0.
To compute the gradients, we wrap the output with a \texttt{GVar} carrying a unit gradient seed, wrap each matrix element with \texttt{GVar}, and send them through the reversible backward pass:
out! = SparseArrays.dot(a, b)
@benchmark (~idot)($(GVar(out!, 1.0)), $(GVar.(a)), $(GVar.(b)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  122.309 μs …   8.374 ms  ┊ GC (min … max): 0.00% … 88.00%
 Time  (median):     164.012 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   164.241 μs ±  84.423 μs  ┊ GC (mean ± σ):  0.45% ±  0.88%
 Memory estimate: 8.67 KiB, allocs estimate: 5.
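The benchmark above only times the backward pass; to actually read off the gradients one can capture the returned arguments and unwrap them with NiLang.AD's \texttt{grad} accessor. The following is a sketch of that pattern; mathematically, the gradient of \texttt{trace(A'B)} with respect to \texttt{A} is \texttt{B}, so the unwrapped adjoints agree with \texttt{B} on the stored entries of \texttt{A}:

using NiLang.AD: GVar, grad

out! = SparseArrays.dot(a, b)
# the inverse program returns the updated (r, A, B); the adjoints accumulate
# inside the GVar-wrapped matrix elements
_, gA, gB = (~idot)(GVar(out!, 1.0), GVar.(a), GVar.(b))

# unwrap the adjoints; ∂ trace(A'B)/∂A = B on A's stored entries
∇A = grad.(gA)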
The time used for computing the backward pass is roughly 1.8 times that of Julia's native forward pass (164 μs vs. 92 μs median). Here, we have turned off the reversibility check (via \texttt{@invcheckoff}) to achieve better performance. By writing sparse matrix multiplication and other sparse matrix operations reversibly, we can obtain a differentiable sparse matrix library with competitive performance.
See my other blog post for reversible sparse matrix multiplication.
This page was generated using Literate.jl.