Sparse matrices

Source-to-source automatic differentiation is useful for differentiating sparse matrix operations. It is a well-known problem that sparse matrices cannot benefit directly from the generic backward rules written for dense matrices, because those rules do not preserve the sparse structure. In the following, we show that reversible AD can differentiate the Frobenius dot product between two sparse matrices with state-of-the-art performance. Here, the Frobenius dot product is defined as \texttt{trace(A'B)}; its native (irreversible) Julia implementation is \texttt{SparseArrays.dot}.
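
As a quick illustration (a small sketch, not part of the original post), one can check this definition against the dense formula:

using SparseArrays, LinearAlgebra

A = sprand(5, 5, 0.5)
B = sprand(5, 5, 0.5)
# the Frobenius dot product is trace(A'B); the sparse routine computes it without densifying
SparseArrays.dot(A, B) ≈ tr(Matrix(A)' * Matrix(B))  # -> true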

The following is a reversible counterpart written with NiLang:

using NiLang, NiLang.AD
using SparseArrays

@i function idot(r::T, A::SparseMatrixCSC{T}, B::SparseMatrixCSC{T}) where {T}
    m ← size(A, 1)
    n ← size(A, 2)
    # branch_keeper records which pointer advanced at each merge step
    @invcheckoff branch_keeper ← zeros(Bool, 2*m)
    @safe size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
    @invcheckoff @inbounds for j = 1:n
        # ranges of the stored entries of column j in A and B
        ia1 ← A.colptr[j]
        ib1 ← B.colptr[j]
        ia2 ← A.colptr[j+1]
        ib2 ← B.colptr[j+1]
        ia ← ia1
        ib ← ib1
        # two-pointer sweep over the stored entries of column j
        @inbounds for i=1:ia2-ia1+ib2-ib1-1
            ra ← A.rowval[ia]
            rb ← B.rowval[ib]
            if (ra == rb, ~)
                # both matrices store an entry in this row: accumulate the product
                r += A.nzval[ia]' * B.nzval[ib]
            end
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= ia == ia2-1 || ra > rb
            ra → A.rowval[ia]
            rb → B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
        # run the pointer bookkeeping in reverse: this rewinds ia and ib and
        # clears branch_keeper so that it can be reused for the next column
        ~@inbounds for i=1:ia2-ia1+ib2-ib1-1
            # b move -> true, a move -> false
            branch_keeper[i] ⊻= ia == ia2-1 || A.rowval[ia] > B.rowval[ib]
            if (branch_keeper[i], ~)
                INC(ib)
            else
                INC(ia)
            end
        end
    end
    @invcheckoff branch_keeper → zeros(Bool, 2*m)
end

Here, the key point is using the \texttt{branch_keeper} vector to cache the branch decisions, i.e., which of the two column pointers advanced at each step, so that the branches can be undone exactly in the reverse pass.
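
Before benchmarking, a quick sanity check (not part of the original post) that the reversible forward pass reproduces the native result:

a0 = sprand(100, 100, 0.05)
b0 = sprand(100, 100, 0.05)
# a NiLang function returns all of its arguments; the first slot is the updated accumulator r
idot(0.0, a0, b0)[1] ≈ SparseArrays.dot(a0, b0)  # -> true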

The time taken by the native implementation is

using BenchmarkTools
a = sprand(1000, 1000, 0.01);
b = sprand(1000, 1000, 0.01);
@benchmark SparseArrays.dot($a, $b)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     69.822 μs (0.00% GC)
  median time:      71.635 μs (0.00% GC)
  mean time:        72.251 μs (0.00% GC)
  maximum time:     132.857 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

To compute the gradients, we wrap the output and each matrix element in a \texttt{GVar} (seeding the adjoint of the output with 1.0) and feed them to the reversible backward pass

out! = SparseArrays.dot(a, b)
@benchmark (~idot)($(GVar(out!, 1.0)),
        $(GVar.(a)), $(GVar.(b)))
BenchmarkTools.Trial: 
  memory estimate:  2.17 KiB
  allocs estimate:  2
  --------------
  minimum time:     99.850 μs (0.00% GC)
  median time:      100.983 μs (0.00% GC)
  mean time:        101.900 μs (0.00% GC)
  maximum time:     725.241 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
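
The benchmark above only times the reverse pass; to actually read out the gradients, keep the returned \texttt{GVar}-wrapped matrices and extract their gradient fields with \texttt{grad}. A minimal sketch, assuming the definitions above:

# seed the adjoint of the output with 1.0 and run the program backwards
_, gA, gB = (~idot)(GVar(out!, 1.0), GVar.(a), GVar.(b))
# the gradient of trace(a'b) with respect to a is b (restricted to the shared sparsity pattern)
∇a = grad.(gA)
∇b = grad.(gB)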

The time for computing the backward pass is approximately 1.4 times that of Julia's native forward pass (comparing the median times above). Here, we have turned off the reversibility check, via \texttt{@invcheckoff}, to achieve better performance. By writing sparse matrix multiplication and other sparse matrix operations reversibly, one can build a differentiable sparse matrix library with reasonable performance.

See another blog post of mine for reversible sparse matrix multiplication.


This page was generated using Literate.jl.