Linear Algebra Autodiff (complex valued)

You can find the Julia implementations in BackwardsLinalg.jl and OMEinsum.jl.


Definition of Einsum

einsum is defined as

$$O_{\vec o} = \sum\limits_{(\vec a \cup \vec b \cup \vec c \ldots) \backslash \vec o }A_{\vec a}B_{\vec b}C_{\vec c} \ldots,$$

where $\vec a = a_1, a_2\dots$ are the labels that appear in tensor $A$, $\vec a\cup \vec b$ means the union of two sets of labels, and $\vec a\backslash \vec b$ means the set difference between two sets of labels. The above summation runs over all indices that do not appear in the output tensor $O$.


Given $\overline O$, in order to obtain $\overline B \equiv \partial \mathcal{L}/\partial B$, consider the differentiation rule

$$\begin{align} \delta \mathcal{L} &= \sum\limits_{\vec o} \overline O_{\vec o} \delta O_{\vec o} \\ &=\sum\limits_{\vec o\cup\vec a \cup \vec b\cup \vec c \ldots} \overline O_{\vec o}A_{\vec a}\delta B_{\vec b}C_{\vec c} \ldots \end{align}$$

Here, we have used the (partial) differential equation

$$\delta O_{\vec o} = \sum\limits_{(\vec a \cup \vec b \cup \vec c \ldots) \backslash \vec o }A_{\vec a}\delta B_{\vec b}C_{\vec c} \ldots$$

Then we define

$$\overline B_{\vec b} = \sum\limits_{(\vec a \cup \vec b \cup \vec c \ldots) \backslash \vec b }A_{\vec a}\overline O_{\vec o}C_{\vec c} \ldots,$$

We can readily verify

$$\delta \mathcal{L} = \sum\limits_{\vec b} \overline B_{\vec b} \delta B_{\vec b}$$

This backward rule is exactly an einsum that exchanges the output tensor $O$ and the input tensor $B$.

In conclusion, the index magic of exchanging indices as backward rule holds for einsum.
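As a sanity check of the rule above, here is a minimal Julia sketch for the special case $O_{il} = \sum_{jk} A_{ij}B_{jk}C_{kl}$ with real tensors. The tensor sizes, the random seed and the names `forward`, `loss` and `Bbar` are assumptions made for illustration; nothing here is taken from OMEinsum.jl.

```julia
using LinearAlgebra, Random

Random.seed!(1)
A = randn(3, 4); B = randn(4, 5); C = randn(5, 2)
Obar = randn(3, 2)                    # hypothetical adjoint of the output O

# Forward einsum: O[i, l] = sum over j, k of A[i, j] * B[j, k] * C[k, l]
forward(B) = A * B * C
# A loss whose adjoint with respect to O is exactly Obar
loss(B) = sum(Obar .* forward(B))

# Backward einsum: exchange the roles of O and B,
# Bbar[j, k] = sum over i, l of A[i, j] * Obar[i, l] * C[k, l]
Bbar = [sum(A[i, j] * Obar[i, l] * C[k, l] for i in 1:3, l in 1:2)
        for j in 1:4, k in 1:5]

# The loss is linear in B, so the first-order change is exact
dB = randn(4, 5)
lhs = loss(B + dB) - loss(B)
rhs = sum(Bbar .* dB)
println("loss change: ", lhs, "   sum(Bbar .* dB): ", rhs)
```

For this contraction the backward einsum reduces to `A' * Obar * C'`, which is a convenient cross-check.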

Thanks to Andreas Peter for helpful discussion.

Symmetric Eigenvalue Decomposition (ED)



$$A = UEU^\dagger$$

We have

$$\overline{A} = U\left[\overline{E} + \frac{1}{2}\left(\overline{U}^\dagger U \circ F + h.c.\right)\right]U^\dagger$$

where $F_{ij}=(E_i- E_j)^{-1}$ for $i\neq j$ and $F_{ii}=0$.
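The rule can be checked numerically. The following is a minimal sketch for a real symmetric matrix, not the BackwardsLinalg.jl implementation: the function name `symeigen_back_sketch`, the test loss and the matrices `H`, `w` are assumptions made for illustration. The sketch writes the off-diagonal term as $F\circ(U^T\overline U)$ with $F_{ij} = 1/(E_j - E_i)$ and zero diagonal, then symmetrizes it. Since the sign of this term is easy to get wrong, the result is compared against a central finite difference.

```julia
using LinearAlgebra, Random

# Sketch of the backward rule for A = U * Diagonal(E) * U' (real symmetric A).
function symeigen_back_sketch(E, U, Ebar, Ubar)
    n = length(E)
    # F[i, j] = 1 / (E[j] - E[i]) off the diagonal, 0 on the diagonal
    F = [i == j ? 0.0 : 1 / (E[j] - E[i]) for i in 1:n, j in 1:n]
    G = F .* (U' * Ubar)
    return U * (Diagonal(Ebar) + (G + G') / 2) * U'
end

Random.seed!(2)
n = 4
A = randn(n, n); A = (A + A') / 2     # random real symmetric input
H = randn(n, n); H = (H + H') / 2     # fixed symmetric matrix in the loss
w = randn(n)                          # fixed weights for the eigenvalues

# Gauge-insensitive loss: unchanged when an eigenvector flips sign.
function loss(A)
    E, U = eigen(Symmetric(A))
    u = U[:, 1]
    return dot(u, H * u) + dot(w, E)
end

E, U = eigen(Symmetric(A))
Ubar = zeros(n, n)
Ubar[:, 1] = 2 * H * U[:, 1]          # adjoint of the first eigenvector
Abar = symeigen_back_sketch(E, U, w, Ubar)

# Central finite difference along a random symmetric direction.
dA = randn(n, n); dA = (dA + dA') / 2
h = 1e-6
fd = (loss(A + h * dA) - loss(A - h * dA)) / (2h)
an = sum(Abar .* dA)
println("finite difference: ", fd, "   backward rule: ", an)
```

The rule breaks down when two eigenvalues are (nearly) degenerate, because $F$ diverges there.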

If $E$ is continuous, we define the density $\rho(E) = \sum\limits_k \delta(E-E_k)=-\frac{1}{\pi}\int_k \Im[G^r(E, k)]$ (check sign!), where $G^r(E, k) = \frac{1}{E-E_k+i\delta}$.

We have

$$\overline{A} = U\left[\overline{E} + \frac{1}{2}\left(\overline{U}^\dagger U \circ \Re [G(E_i, E_j)] + h.c.\right)\right]U^\dagger$$

Singular Value Decomposition (SVD)





Complex valued SVD is defined as $A = USV^\dagger$. For simplicity, we consider a full rank square matrix $A$. Differentiation gives

$$dA = dUSV^\dagger + U dS V^\dagger + USdV^\dagger$$

$$U^\dagger dA V = U^\dagger dU S + dS + SdV^\dagger V$$

Defining the matrices $dC=U^\dagger dU$, $dD = dV^\dagger V$ and $dP = U^\dagger dA V$, we have

$$\begin{cases}dC^\dagger+dC=0,\\dD^\dagger +dD=0\end{cases}$$

We have

$$dP = dC S + dS + SdD$$

where $dCS$ and $SdD$ have zero real part in their diagonal elements, so that $dS = \Re[{\rm diag}(dP)]$.
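The relation $dS = \Re[{\rm diag}(dP)]$ is easy to check numerically. Below is a minimal sketch, assuming a random complex $4\times 4$ matrix and a random perturbation direction `dA` (both are illustrative choices); it compares the prediction with a central finite difference of the singular values.

```julia
using LinearAlgebra, Random

Random.seed!(4)
n = 4
A = randn(ComplexF64, n, n)
dA = randn(ComplexF64, n, n)          # perturbation direction

U, S, V = svd(A)                      # A = U * Diagonal(S) * V'
dP = U' * dA * V
dS_pred = real.(diag(dP))             # predicted first-order change of S

# Central finite difference of the singular values
h = 1e-6
dS_fd = (svdvals(A + h * dA) - svdvals(A - h * dA)) / (2h)
err = maximum(abs.(dS_fd - dS_pred))
println("max deviation: ", err)
```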

$$\begin{aligned} d\mathcal{L} &= {\rm Tr}\left[\overline{A}^TdA+\overline{A^*}^TdA^*\right]\\ &= {\rm Tr}\left[\overline{A}^TdA+dA^\dagger\overline{A}^*\right] ~~~~~~~\#rule~3 \end{aligned}$$

It is easy to show that $\overline A_S = U^*\overline SV^T$. Notice that here $\overline A$ is the derivative rather than the gradient; they differ by a complex conjugate, which is why we have a transpose rather than a Hermitian conjugate here. See my complex valued autodiff blog for details.

Using the relations $dC^\dagger+dC=0$ and $dD^\dagger+dD=0$, we have

$$\begin{cases} dPS + SdP^\dagger &= dC S^2-S^2dC\\ SdP + dP^\dagger S &= S^2dD-dD S^2 \end{cases}$$

$$\begin{cases} dC = F\circ(dPS+SdP^\dagger)\\ dD = -F\circ (SdP+dP^\dagger S) \end{cases}$$

where $F_{ij} = \frac{1}{s_j^2-s_i^2}$; it is easy to verify that $F^T = -F$. Notice that here the relation between the imaginary diagonal parts is lost:

$$\color{red}{\Im[I\circ dP] = S\,\Im[I\circ(dC+dD)]}$$

The missing diagonal imaginary part is definitely not trivial, but it was ignored for a long time until @refraction-ray (Shixin Zhang) pointed it out and solved it. Let's first focus on the off-diagonal contributions from $dU$:

$$\begin{align} {\rm Tr}\overline U^TdU &= {\rm Tr} \overline U ^TU dC + \overline U^T (I-UU^\dagger) dAVS^{-1}\\ &= {\rm Tr}\overline U^T U (F\circ(dPS+SdP^\dagger))\\ &= {\rm Tr}(dPS+SdP^\dagger)(-F\circ (\overline U^T U)) \# rule~1,2\\ &= {\rm Tr}(dPS+SdP^\dagger)J^T \end{align}$$

Here, we defined $J=F\circ(U^T\overline U)$.

$$\begin{align*} d\mathcal L &= {\rm Tr} (dPS+SdP^\dagger)(J+J^\dagger)^T\\ &= {\rm Tr} dPS(J+J^\dagger)^T+h.c.\\ &= {\rm Tr} U^\dagger dA V S(J+J^\dagger)^T+h.c.\\ &= {\rm Tr}\left[ VS(J+J^\dagger)^TU^\dagger\right] dA+h.c. \end{align*}$$

By comparing with $d\mathcal L = {\rm Tr}\left[\overline{A}^TdA+h.c. \right]$, we have

$$\begin{align} \bar A_U^{(\rm real)} &= \left[VS(J+J^\dagger)^TU^\dagger\right]^T\\ &=U^*(J+J^\dagger)SV^T \end{align}$$

Update: The missing diagonal imaginary part

Now let's inspect the diagonal imaginary parts of $dC$ and $dD$ in Eq. 16. At first glance, $dP$ alone is not sufficient to determine them, but there is still one piece of information that has not been used: the loss must be gauge invariant, which means that

$$\mathcal{L}(U\Lambda, S, V\Lambda)$$

should be independent of the choice of gauge $\Lambda$, which is defined as ${\rm diag}(e^{i\phi}, ...)$.

$$\begin{aligned} d\mathcal{L} &={\rm Tr}[ \overline{U\Lambda}^T d(U\Lambda) +\overline SdS+\overline{V\Lambda}^Td(V\Lambda)] + h.c.\\ &={\rm Tr}[ \overline {U\Lambda}^T (dU\Lambda+Ud\Lambda) +\overline{S}dS+ \overline{V\Lambda}^T(Vd\Lambda +dV\Lambda)] + h.c.\\ &= {\rm Tr}[(\overline{U\Lambda}^TU+\overline{V\Lambda}^TV )d\Lambda ] + \ldots + h.c. \end{aligned}$$

Gauge invariance refers to

$$\overline{\Lambda} = I\circ(\overline{U\Lambda}^TU+\overline{V\Lambda}^TV) = 0$$

for any $\Lambda$, where $I$ refers to the diagonal mask matrix. It is of course valid when $\Lambda\rightarrow1$, which gives $I\circ(\overline{U}^TU+\overline V^TV) = 0$.

Considering the contribution from the diagonal imaginary part, we have

$$\begin{aligned} &{\rm Tr} [\overline U^T U (I \circ \Im [dC])+\overline V^T V (I \circ \Im [dD^\dagger])] + h.c.\\ &={\rm Tr} [ I \circ (\overline U^T U)\Im [dC]-I\circ (\overline V^T V) \Im [dD]] +h.c. ~~~~~~~~~~~~~~\# rule 1\\ &={\rm Tr} [ I \circ (\overline U^T U)(\Im [dC]+ \Im [dD])] \\ &={\rm Tr}[I\circ (\overline U^T U) \Im[dP]S^{-1}] \\ &={\rm Tr}[S^{-1}\Lambda_J U^{\dagger}dA V] \end{aligned}$$

where $\Lambda_J = \Im[I\circ(\overline U^TU)]= \frac 1 2I\circ(\overline U^TU)-h.c.$, with $I$ the mask for the diagonal part. Since only the real part contributes to $\delta \mathcal{L}$ (the imaginary part is canceled by its Hermitian conjugate counterpart), we can safely move $\Im$ from right to left.

$$\begin{aligned} \color{red}{\bar A_{U+V}^{(\rm imag)} = U^*\Lambda_J S^{-1}V^T} \end{aligned}$$

Thanks to @refraction-ray (Shixin Zhang) for promptly sharing his idea. This is the issue for discussion. His arXiv preprint is coming out soon.

When $U$ is not full rank, this formula should take an extra term (Ref. 2)

$$\begin{aligned} \bar A_U^{(\rm real)} &=U^*(J+J^\dagger)SV^T + (VS^{-1}\overline U^T(I-UU^\dagger))^T \end{aligned}$$

Similarly, for $V$ we have

$$\begin{aligned} \overline A_V^{(\rm real)} &=U^*S(K+K^\dagger)V^T + (U S^{-1} \overline V^T (I - VV^\dagger))^*, \end{aligned}$$

where $K=F\circ(V^T\overline V)$.

To wrap up

$$\overline A = \overline A_U^{\rm (real)} + \overline A_S + \overline A_V^{\rm (real)} + \overline A_{U+V}^{\rm (imag)}$$

This result can be directly used in autograd.

For the gradient used in training, one should change the convention

$$\mathcal{\overline A} = \overline A^*,\\ \mathcal{\overline U} = \overline U^*,\\ \mathcal{\overline V}= \overline V^*.$$

This convention is used in TensorFlow and Zygote.jl. In this convention, the result reads

$$\begin{aligned} \mathcal{\overline A} =& U(\mathcal{J}+\mathcal{J}^\dagger)SV^\dagger + (I-UU^\dagger)\mathcal{\overline U}S^{-1}V^\dagger\\ &+ U\overline SV^\dagger\\ &+US(\mathcal{K}+\mathcal{K}^\dagger)V^\dagger + U S^{-1} \mathcal{\overline V}^\dagger (I - VV^\dagger)\\ &\color{red}{+\frac 1 2 U (I\circ(U^\dagger\mathcal{\overline U})-h.c.)S^{-1}V^\dagger} \end{aligned}$$

where $\mathcal{J}=F\circ(U^\dagger\mathcal{\overline U})$ and $\mathcal{K}=F\circ(V^\dagger \mathcal{\overline V})$.
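The gradient-convention formula can be tested end to end. The sketch below is restricted to a full rank square matrix, so the projector terms with $I-UU^\dagger$ and $I-VV^\dagger$ vanish and are omitted. The name `svd_back_sketch`, the loss, and the matrices `H` and `w` are assumptions for illustration, not the BackwardsLinalg.jl code. The loss $\Re(u_1^\dagger H v_1) + w\cdot S$ is invariant under the joint gauge transformation of $U$ and $V$ but mixes the two, so the diagonal imaginary term is actually exercised.

```julia
using LinearAlgebra, Random

# Sketch of the SVD backward rule (gradient convention), square full-rank case.
function svd_back_sketch(U, S, V, Ubar, Sbar, Vbar)
    n = length(S)
    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on the diagonal
    F = [i == j ? 0.0 : 1 / (S[j]^2 - S[i]^2) for i in 1:n, j in 1:n]
    UdU = U' * Ubar
    J = F .* UdU
    K = F .* (V' * Vbar)
    Sm = Diagonal(S)
    # Sbar plus the diagonal imaginary correction (1/2)(I∘(U'Ubar) - h.c.) / S
    d = Sbar .+ im .* imag.(diag(UdU)) ./ S
    inner = ((J + J') * Sm + Sm * (K + K')) + Diagonal(d)
    return U * inner * V'
end

Random.seed!(3)
n = 4
A = randn(ComplexF64, n, n)
H = randn(ComplexF64, n, n)           # fixed matrix coupling U and V in the loss
w = randn(n)                          # fixed weights for the singular values

function loss(A)
    U, S, V = svd(A)
    return real(dot(U[:, 1], H * V[:, 1])) + dot(w, S)
end

U, S, V = svd(A)
Ubar = zeros(ComplexF64, n, n); Ubar[:, 1] = H * V[:, 1]
Vbar = zeros(ComplexF64, n, n); Vbar[:, 1] = H' * U[:, 1]
Abar = svd_back_sketch(U, S, V, Ubar, w, Vbar)

# In this convention the first-order change of the loss is Re <Abar, dA>
dA = randn(ComplexF64, n, n)
h = 1e-6
fd = (loss(A + h * dA) - loss(A - h * dA)) / (2h)
an = real(dot(Abar, dA))
println("finite difference: ", fd, "   backward rule: ", an)
```

Dropping the imaginary part of `d` in `svd_back_sketch` is a quick way to see how much the diagonal correction matters for such a loss.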


rule 1. ${\rm Tr} \left[A(C\circ B)\right] = \sum A^T\circ C\circ B = {\rm Tr} ((C\circ A^T)^TB)={\rm Tr}\left[(C^T\circ A)B\right]$

rule 2. $(C\circ A)^T = C^T \circ A^T$

rule 3. When $\mathcal L$ is real,

$$\frac{\partial \mathcal{L}}{\partial x^*} = \left(\frac{\partial \mathcal{L}}{\partial x}\right)^*$$
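Rules 1 and 2 are plain index identities, so a quick numerical check is enough to guard against transposition mistakes. A minimal sketch with random real $3\times 3$ matrices (the sizes and variable names are illustrative):

```julia
using LinearAlgebra, Random

Random.seed!(6)
A, B, C = randn(3, 3), randn(3, 3), randn(3, 3)

# rule 1: four equivalent ways of writing Tr[A (C ∘ B)]
t1 = tr(A * (C .* B))
t2 = sum(transpose(A) .* C .* B)
t3 = tr(transpose(C .* transpose(A)) * B)
t4 = tr((transpose(C) .* A) * B)
println((t1, t2, t3, t4))

# rule 2: the transpose distributes over the elementwise product
rule2_ok = transpose(C .* A) ≈ transpose(C) .* transpose(A)
println(rule2_ok)
```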

How to Test SVD

For example, to test the adjoint contribution from $U$, we can construct a gauge-insensitive test function

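# H is a random Hermitian matrix, defined outside these functions;
# svd_back is the SVD backward rule (see BackwardsLinalg.jl).
using LinearAlgebra

function loss(A)
    U, S, V = svd(A)
    psi = U[:,1]
    # expectation value <psi|H|psi>, invariant under a phase change of psi
    return real(psi' * H * psi)
end

function gradient(A)
    U, S, V = svd(A)
    # only U receives a nonzero adjoint
    dU = zero(U)
    dS = zero(S)
    dV = zero(V)
    dU[:,1] = U[:,1]'*H
    dA = svd_back(U, S, V, dU, dS, dV)
    return dA
end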

QR decomposition





$$A = QR$$

with $Q^\dagger Q = \mathbb{I}$, so that $dQ^\dagger Q+Q^\dagger dQ=0$. $R$ is a complex upper triangular matrix with a real diagonal part.

$$dA = dQR+QdR$$

$$dQ = dAR^{-1}-QdRR^{-1}$$

$$\begin{cases} Q^\dagger dQ = dC - dRR^{-1}\\ dQ^\dagger Q =dC^\dagger - R^{-\dagger}dR^\dagger \end{cases}$$

where $dC=Q^\dagger dAR^{-1}$.


Summing the two equations and using $dQ^\dagger Q+Q^\dagger dQ=0$ gives

$$dC+dC^\dagger = dRR^{-1} +(dRR^{-1})^\dagger$$

Notice that $dR$ is upper triangular and its Hermitian conjugate is lower triangular (the same holds for $dRR^{-1}$); this restriction gives

$$U\circ(dC+dC^\dagger) = dRR^{-1}$$

where $U$ is a mask operator whose elements are $1$ in the upper triangular part, $0.5$ on the diagonal and $0$ in the lower triangular part. One should also notice that both $R$ and $dR$ have real diagonal parts, as does the product $dRR^{-1}$.

Now let's wrap up using the Zygote convention of gradient

$$\begin{align} d\mathcal L &= {\rm Tr}\left[\overline{\mathcal{Q}}^\dagger dQ+\overline{\mathcal{R}}^\dagger dR +h.c. \right]\\ &={\rm Tr}\left[\overline{\mathcal{Q}}^\dagger dA R^{-1}-\overline{\mathcal{Q}}^\dagger QdR R^{-1}+\overline{\mathcal{R}}^\dagger dR +h.c. \right]\\ &={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA+ R^{-1}(-\overline{\mathcal{Q}}^\dagger Q +R\overline{\mathcal{R}}^\dagger) dR +h.c. \right]\\ &={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA+ R^{-1}M dR +h.c. \right] \end{align}$$

Here, $M=R\overline{\mathcal{R}}^\dagger-\overline{\mathcal{Q}}^\dagger Q$. Plugging in $dR$, we have

$$\begin{align} d\mathcal{L}&={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA + M \left[U\circ(dC+dC^\dagger)\right] +h.c. \right]\\ &={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA + (M\circ L)(dC+dC^\dagger) +h.c. \right] \;\;\# rule\; 1\\ &={\rm Tr}\left[ (R^{-1}\overline{\mathcal{Q}}^\dagger dA+h.c.) + (M\circ L)(dC + dC^\dagger)+ (M\circ L)^\dagger (dC + dC^\dagger)\right]\\ &={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA + (M\circ L+h.c.)dC + h.c.\right]\\ &={\rm Tr}\left[ R^{-1}\overline{\mathcal{Q}}^\dagger dA + (M\circ L+h.c.)Q^\dagger dAR^{-1}\right]+h.c. \end{align}$$

where $L =U^\dagger = 1-U$ is the mask of the lower triangular part of a matrix.

$$\begin{align} \mathcal{\overline A}^\dagger &= R^{-1}\left[\overline{\mathcal{Q}}^\dagger + (M\circ L+h.c.)Q^\dagger\right]\\ \mathcal{\overline A} &= \left[\overline{\mathcal{Q}} + Q(M\circ L+h.c.)\right]R^{-\dagger}\\ &=\left[\overline{\mathcal{Q}} + Q \texttt{copyltu}(M)\right]R^{-\dagger} \end{align}$$

Here, $\texttt{copyltu}$ takes the conjugate when copying elements to the upper triangular part.
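The QR rule can be sketched and checked in the same way. The names `qr_pos`, `copyltu_sketch` and `qr_back_sketch`, the loss, and the matrices `G` and `W` are assumptions for illustration, not the BackwardsLinalg.jl code. Two choices are made explicit: the diagonal of $R$ is rotated to be real and positive (the derivation requires a real diagonal), and the diagonal of $M\circ L + h.c.$ is taken as $\Re[{\rm diag}(M)]$, which follows from the $0.5$ weight of the mask on the diagonal.

```julia
using LinearAlgebra, Random

# QR with the diagonal of R made real and positive (fixes the phase freedom).
function qr_pos(A)
    F = qr(A)
    Q, R = Matrix(F.Q), F.R
    ph = sign.(diag(R))               # unit-modulus phases of diag(R)
    return Q * Diagonal(ph), Diagonal(conj.(ph)) * R
end

# M∘L + h.c.: strictly lower triangle, its conjugate transpose, real diagonal.
function copyltu_sketch(M)
    Lo = tril(M, -1)
    return (Lo + Lo') + Diagonal(real.(diag(M)))
end

function qr_back_sketch(Q, R, Qbar, Rbar)
    M = R * Rbar' - Qbar' * Q
    return (Qbar + Q * copyltu_sketch(M)) / R'   # right-multiply by inv(R')
end

Random.seed!(5)
n = 4
A = randn(ComplexF64, n, n)
G = randn(ComplexF64, n, n)           # fixed weights on Q in the loss
W = triu(randn(ComplexF64, n, n))     # fixed weights on R in the loss

function loss(A)
    Q, R = qr_pos(A)
    return real(dot(G, Q)) + real(dot(W, R))
end

Q, R = qr_pos(A)
Abar = qr_back_sketch(Q, R, G, W)     # Qbar = G and Rbar = W for this loss

dA = randn(ComplexF64, n, n)
h = 1e-6
fd = (loss(A + h * dA) - loss(A - h * dA)) / (2h)
an = real(dot(Abar, dA))
println("finite difference: ", fd, "   backward rule: ", an)
```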

