
I'm trying to figure out why the derivative of the solve operation is this in the ChainRules package. After reading this, I started writing the derivation below, but I can't seem to get past what I've written.

I know that we have $$ Ax=b\\ x=A\backslash b\\ x = A^{-1}b $$

Furthermore, $$ \frac{\partial L}{\partial A_{ij}} = \sum_k\frac{\partial L}{\partial x_k}\frac{\partial x_k}{\partial A_{ij}} $$

where $L$ is the loss (a function of $x$). Let $c_k$ be $\frac{\partial L}{\partial x_k}$. Then

$$ \begin{align}
\frac{\partial L}{\partial A_{ij}} &= \sum_k c_k\frac{\partial (A^{-1}b)_k}{\partial A_{ij}}\\
&= c^\top \frac{\partial A^{-1}b}{\partial A_{ij}}\\
&= c^\top \left(\frac{\partial A^{-1}}{\partial A_{ij}} b + A^{-1}\frac{\partial b}{\partial A_{ij}}\right) \\
&= c^\top \left(-A^{-1}\frac{\partial A}{\partial A_{ij}}A^{-1} b + A^{-1}\frac{\partial b}{\partial A_{ij}}\right) \\
&= c^\top \left(-A^{-1}1_{ij}x + A^{-1}\frac{\partial b}{\partial A_{ij}}\right) \\
&= c^\top A^{-1}\left(\frac{\partial b}{\partial A_{ij}} - 1_{ij}x\right)
\end{align} $$
where $1_{ij}$ is the matrix whose $(i,j)$ entry is one and whose other entries are zero.

The thing I'm missing here is how we can calculate $\frac{\partial b}{\partial A_{ij}}$ during backprop, since we can't know it.

--edit--

For the sake of completeness, ChainRules.jl uses
$$ \partial A = -(A')^{-1} c x'\\ \partial b = A' \backslash c $$ where $'$ is the conjugate transpose.
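A sketch of how the last line of my derivation would line up with these formulas, under two assumptions that I have not confirmed from the package source: that the rule treats $b$ as an independent input (so $\frac{\partial b}{\partial A_{ij}} = 0$ inside the rule), and that everything is real, so $'$ is a plain transpose. Writing $1_{ij} = e_i e_j^\top$ with $e_i$ the $i$-th standard basis vector:

$$ \begin{align}
\frac{\partial L}{\partial A_{ij}} &= -c^\top A^{-1} 1_{ij} x \\
&= -\left(c^\top A^{-1} e_i\right)\left(e_j^\top x\right) \\
&= -\left((A^\top)^{-1} c\right)_i x_j
\end{align} $$

Collecting all $i,j$ would give $\partial A = -(A^\top)^{-1} c\, x^\top$. The same steps with $\frac{\partial x}{\partial b_i} = A^{-1}e_i$ would give $\frac{\partial L}{\partial b_i} = \left((A^\top)^{-1}c\right)_i$, i.e. $\partial b = (A^\top)^{-1} c$.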

Sobhan
  • Assuming $b$ is independent of $A$, isn't $\frac{\partial b}{\partial A_{ij}}=0$? – greg Jan 22 '21 at 19:58
  • That's the thing: it doesn't have to be. You could have something like x = A \ diag(A), and it should still hold. In the code they have a general solution for this. – Sobhan Jan 22 '21 at 22:11
  • What's the problem? If $b={\rm diag}(A)$, then in your notation $\frac{\partial b}{\partial A_{ij}}={\rm diag}({\tt1}_{ij})$ – greg Jan 22 '21 at 22:40
  • This rule can't be used in backpropagation. – Sobhan Jan 22 '21 at 23:07
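A minimal numerical sketch of the x = A \ diag(A) case discussed in these comments, written in NumPy rather than Julia. It assumes real matrices (so the conjugate transpose is a plain transpose) and a made-up loss $L = \tfrac12\lVert x\rVert^2$, so that $c = x$. All function names (`solve_pullback`, `loss_diag`, `grad_diag`, `finite_diff`) are hypothetical and are not the ChainRules.jl API. The idea it illustrates: the solve rule only returns the two cotangents for its own inputs $A$ and $b$, and whatever produced $b$ (here `diag`) pushes its share back to $A$ separately, where the two contributions are summed.

```python
import numpy as np


def solve_pullback(A, b, c):
    """Cotangents of x = solve(A, b), treating A and b as independent inputs.

    Real-valued sketch of: db = (A')^{-1} c,  dA = -(A')^{-1} c x'.
    """
    x = np.linalg.solve(A, b)
    db = np.linalg.solve(A.T, c)
    dA = -np.outer(db, x)
    return dA, db


def loss_diag(A):
    """Hypothetical loss L = 0.5 * ||x||^2 with x = solve(A, diag(A))."""
    x = np.linalg.solve(A, np.diag(A))
    return 0.5 * x @ x


def grad_diag(A):
    """Backprop through x = solve(A, diag(A)) by summing both paths into A."""
    b = np.diag(A)
    x = np.linalg.solve(A, b)
    c = x  # dL/dx for L = 0.5 * ||x||^2
    dA, db = solve_pullback(A, b, c)
    # Pullback of b = diag(A): place db on the diagonal, then add to dA.
    return dA + np.diag(db)


def finite_diff(f, A, eps=1e-6):
    """Central finite-difference gradient of a scalar function of a matrix."""
    G = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            E = np.zeros_like(A)
            E[i, j] = eps
            G[i, j] = (f(A + E) - f(A - E)) / (2 * eps)
    return G


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # shifted to keep A well conditioned

# Compare the summed backprop gradient against finite differences.
print(np.allclose(grad_diag(A), finite_diff(loss_diag, A), atol=1e-6))
```

In this sketch $\frac{\partial b}{\partial A_{ij}}$ never appears inside `solve_pullback`; it only shows up implicitly in the `np.diag(db)` step that belongs to the `diag` operation.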
