Sparse matrices
Source-to-source automatic differentiation is useful for differentiating sparse matrix operations. It is a well-known problem that sparse matrix operations cannot benefit directly from the generic backward rules written for dense matrices, because those rules do not preserve the sparse structure. In the following, we show that reversible AD can differentiate the Frobenius dot product between two sparse matrices with state-of-the-art performance. Here, the Frobenius dot product is defined as \texttt{trace(A'B)}. Its native Julia (irreversible) implementation is SparseArrays.dot.
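As a quick sanity check of this definition, the sketch below compares three ways of evaluating the Frobenius dot product on small random matrices. The sizes and density are arbitrary choices for illustration, and the dense conversion is used only to keep the trace expression unambiguous.

```julia
using SparseArrays, LinearAlgebra

# Small random sparse matrices (size and density chosen arbitrarily).
A = sprand(20, 20, 0.2)
B = sprand(20, 20, 0.2)

# Three equivalent ways to evaluate the Frobenius dot product (real case).
r_dot   = dot(A, B)                      # sparse-aware implementation
r_trace = tr(Matrix(A)' * Matrix(B))     # the definition, trace(A'B), on dense copies
r_sum   = sum(A .* B)                    # sum of elementwise products

@assert r_dot ≈ r_trace
@assert r_dot ≈ r_sum
```

Only positions stored in both matrices contribute to the result, which is why a sparse-aware implementation can skip most of the work.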
The following is a reversible counterpart of SparseArrays.dot:
using NiLang, NiLang.AD
using SparseArrays
@i function idot(r::T, A::SparseMatrixCSC{T},B::SparseMatrixCSC{T}) where {T}
    @routine begin
        m, n ← size(A)
        branch_keeper ← zeros(Bool, 2*m)
    end
    @safe size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
    @invcheckoff @inbounds for j = 1:n
        @routine begin
            ia1 ← A.colptr[j]
            ib1 ← B.colptr[j]
            ia2 ← A.colptr[j+1]
            ib2 ← B.colptr[j+1]
            ia ← ia1
            ib ← ib1
        end
        @inbounds for i=1:ia2-ia1+ib2-ib1-1
            ra ← A.rowval[ia]
            rb ← B.rowval[ib]
            if (ra == rb, ~)
                r += A.nzval[ia]' * B.nzval[ib]
            end
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= @const ia == ia2-1 || ra > rb
            ra → A.rowval[ia]
            rb → B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
        ~@inbounds for i=1:ia2-ia1+ib2-ib1-1
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= @const ia == ia2-1 || A.rowval[ia] > B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
        ~@routine
    end
    ~@routine
end
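The key point here is the \texttt{branch_keeper} vector, which caches each branch decision (whether the pointer into \texttt{B} or the pointer into \texttt{A} advanced) so that the inner loop can later be run in reverse.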
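To illustrate why the decisions need to be stored, here is a minimal sketch in plain (irreversible) Julia, not NiLang. The helpers `walk_and_record` and `rewind` are hypothetical names introduced only for this illustration, and the loop uses a simple bounds check rather than the exact iteration count used in `idot`. Once a pointer has moved, the comparison that caused the move is no longer visible from the current state, so the recorded flags are what allow the walk to be undone step by step.

```julia
# Hypothetical helper: merge-style walk over two sorted row-index lists that
# records, at every step, which pointer advanced (true: b, false: a).
function walk_and_record(rows_a::Vector{Int}, rows_b::Vector{Int})
    ia, ib = 1, 1
    keeper = Bool[]                      # plays the role of branch_keeper
    matches = Tuple{Int,Int}[]           # pointer pairs with equal row index
    while ia <= length(rows_a) && ib <= length(rows_b)
        ra, rb = rows_a[ia], rows_b[ib]
        ra == rb && push!(matches, (ia, ib))
        move_b = ra > rb
        push!(keeper, move_b)
        move_b ? (ib += 1) : (ia += 1)
    end
    return ia, ib, keeper, matches
end

# Hypothetical helper: replay the recorded decisions backwards to recover
# the pointers the walk started from.
function rewind(ia::Int, ib::Int, keeper::Vector{Bool})
    for move_b in Iterators.reverse(keeper)
        move_b ? (ib -= 1) : (ia -= 1)
    end
    return ia, ib
end

rows_a = [1, 3, 4, 7]
rows_b = [2, 3, 7, 9]
ia, ib, keeper, matches = walk_and_record(rows_a, rows_b)
@assert rewind(ia, ib, keeper) == (1, 1)
```

In the reversible version, the same flags are cleared again by running the second loop in reverse, so the vector can be handed back in its all-false state.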
The time taken by the native implementation is:
using BenchmarkTools
a = sprand(1000, 1000, 0.01);
b = sprand(1000, 1000, 0.01);
@benchmark SparseArrays.dot($a, $b)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 80.201 μs … 113.301 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 82.201 μs ┊ GC (median): 0.00%
Time (mean ± σ): 82.370 μs ± 1.436 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▂▅ ▇ █▅ ▂
▂▂▂▃▄▃▆▅███████▇█▅▆▄▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▂▂▁▂▁▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂ ▃
80.2 μs Histogram: frequency by time 89.8 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
To compute the gradients, we wrap each matrix element with GVar and send them to the reversible backward pass:
out! = SparseArrays.dot(a, b)
@benchmark (~idot)($(GVar(out!, 1.0)),
$(GVar.(a)), $(GVar.(b)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 155.101 μs … 1.321 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 156.902 μs ┊ GC (median): 0.00%
Time (mean ± σ): 157.735 μs ± 15.073 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▂▅▇█▇▇▇
▁▂▃▆████████▆▅▄▃▃▂▂▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
155 μs Histogram: frequency by time 168 μs <
Memory estimate: 8.67 KiB, allocs estimate: 5.
The backward pass takes approximately 1.9 times as long as Julia's native forward pass (median 156.9 μs versus 82.2 μs). Here, we have turned off the reversibility check with @invcheckoff to achieve better performance. By writing sparse matrix multiplication and other sparse matrix operations reversibly, we can build a differentiable sparse matrix library with good performance.
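Independently of NiLang, the gradients expected from such a backward pass can be written in closed form for real matrices: since the result is the sum of A[i,j] * B[i,j], the derivative with respect to a stored entry A[i,j] is B[i,j] (and zero where B has no stored entry). The sketch below checks this with central finite differences, using only SparseArrays and LinearAlgebra; the small matrices are made up for illustration.

```julia
using SparseArrays, LinearAlgebra

# Two small sparse matrices with partially overlapping patterns (made-up data).
A = sparse([1, 3, 2, 3], [1, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], 3, 3)
B = sparse([1, 2, 2, 3], [1, 1, 2, 3], [5.0, 6.0, 7.0, 8.0], 3, 3)

# Analytic gradient with respect to each stored entry of A, in nzval order:
# the entry of B at the same (row, column) position.
rows, cols, _ = findnz(A)
grad_A = [B[i, j] for (i, j) in zip(rows, cols)]

# Central finite differences, perturbing one stored entry of A at a time.
h = 1e-6
fd = map(eachindex(A.nzval)) do k
    Ap = copy(A); Ap.nzval[k] += h
    Am = copy(A); Am.nzval[k] -= h
    (dot(Ap, B) - dot(Am, B)) / (2h)
end

@assert isapprox(fd, grad_A; atol=1e-6)
```

The gradient has the same sparsity pattern as A itself, which is the structure that generic dense backward rules fail to preserve.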
See my other blog post for reversible sparse matrix multiplication.
This page was generated using Literate.jl.