chain rule ugly output


mohab meshref

Feb 19, 2018, 2:25:34 PM
to sympy
So I am implementing a simple chain-rule differentiation, given two functions m2 and m with:

m2 = r0**2 + r1**2 + r2**2

m = m2**(0.5)

I want to get the derivative of m with respect to r0, so I calculate diff(m wrt m2) * diff(m2 wrt r0).

Expected output:
r0 * m2**(-0.5).
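Worked by hand, assuming m = m2**0.5 and m2 = r0**2 + r1**2 + r2**2 as stated above, the chain rule gives exactly that expected result:

$$\frac{\partial m}{\partial r_0} = \frac{dm}{dm_2}\,\frac{\partial m_2}{\partial r_0} = 0.5\, m_2^{-0.5} \cdot 2 r_0 = r_0\, m_2^{-0.5}$$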

SymPy output:
1.0*r0*m2(r0**2 + r1**2 + r2**2)**(-0.5)*Subs(Derivative(m(_xi_1), _xi_1), (_xi_1,), (m2(r0**2 + r1**2 + r2**2)**0.5,))*Subs(Derivative(m2(_xi_1), _xi_1), (_xi_1,), (r0**2 + r1**2 + r2**2,))

---------------------------------------------------------------

First, I want to know why, instead of writing just m2, it writes m2(r0**2 + r1**2 + r2**2).
Second, why is this term (Subs(Derivative(m(_xi_1), _xi_1), (_xi_1,), (m2(r0**2 + r1**2 + r2**2)**0.5,))*Subs(Derivative(m2(_xi_1), _xi_1), (_xi_1,), (r0**2 + r1**2 + r2**2,))) here, and what does it actually mean?

Here's my code:

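import sympy as sp

r0, r1, r2, t0 = sp.symbols("r0,r1,r2,t0", real=True)

# m2 and m start out as undefined SymPy function classes
m2, m = sp.symbols('m2, m', cls=sp.Function)

# The Python names are then rebound: m2 now holds the function m2
# applied to r0**2 + r1**2 + r2**2, and m holds m applied to m2**0.5
m2 = m2(r0**2 + r1**2 + r2**2)
m = m(m2 ** (0.5))

# Chain rule by hand: dm/dr0 = dm/dm2 * dm2/dr0
dm2r0 = sp.Derivative(m2, r0)
dmm2 = sp.Derivative(m, m2)
dmr0 = dmm2 * dm2r0

print("dmr0:\n" + str(dmr0.doit()) + "\n\n")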

Aaron Meurer

Feb 19, 2018, 3:31:32 PM
to sy...@googlegroups.com
On Mon, Feb 19, 2018 at 6:12 AM, mohab meshref <mohab...@gmail.com> wrote:
> So I am implementing a simple chain-rule differentiation, given two
> functions m2 and m
> with:
>
> m2 = r0**2 + r1**2 + r2**2
>
> m = m2**(0.5)
>
> I want to get the derivative of m with respect to r0, so I calculate
> diff(m wrt m2) * diff(m2 wrt r0).
>
> Expected output:
> r0 * m2**(-0.5).
>
> SymPy output:
> 1.0*r0*m2(r0**2 + r1**2 + r2**2)**(-0.5)*Subs(Derivative(m(_xi_1), _xi_1),
> (_xi_1,), (m2(r0**2 + r1**2 + r2**2)**0.5,))*Subs(Derivative(m2(_xi_1),
> _xi_1), (_xi_1,), (r0**2 + r1**2 + r2**2,))
>
> ---------------------------------------------------------------
>
> First, I want to know why, instead of writing just m2, it writes
> m2(r0**2 + r1**2 + r2**2).

That is the expression your m2 variable actually holds. Note that SymPy
has no idea what you named your Python variables. See
http://docs.sympy.org/latest/tutorial/gotchas.html
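A minimal sketch of that gotcha, assuming a recent SymPy: the printed form comes from the name given to the SymPy object, not from the Python variable that holds it. The variable names `M2`, `applied`, `alias` and `m2_sym` are hypothetical, chosen only for illustration.

```python
import sympy as sp

r0, r1, r2 = sp.symbols("r0 r1 r2", real=True)

# The SymPy name "m2" belongs to the undefined function class
M2 = sp.Function("m2")

# Binding the applied function to any Python variable does not rename it;
# SymPy still prints the function applied to its argument
applied = M2(r0**2 + r1**2 + r2**2)
alias = applied
print(alias)  # m2(r0**2 + r1**2 + r2**2)

# If the output should say just "m2", use a Symbol with that name instead
m2_sym = sp.Symbol("m2")
print(m2_sym**sp.Rational(-1, 2))  # 1/sqrt(m2)
```

This is essentially what the later reply in this thread does: it keeps m2 as a plain symbol and only substitutes the sum back in at the end.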

> Second, why is this term (Subs(Derivative(m(_xi_1), _xi_1), (_xi_1,),
> (m2(r0**2 + r1**2 + r2**2)**0.5,))*Subs(Derivative(m2(_xi_1), _xi_1),
> (_xi_1,), (r0**2 + r1**2 + r2**2,))) here,
> and what does it actually mean?

This represents the derivative of the function evaluated at that
point, which is needed for the chain rule. This is the only way SymPy
has to represent something like f'(a) when 'a' is something more
complicated than a single variable.
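A small sketch of what Subs means, assuming a recent SymPy and a hypothetical undefined function f: differentiating f(x**2) gives 2*x times f' evaluated at the point x**2. Swapping in a concrete function (sin here, picked only for illustration) with `replace` and then calling `doit()` lets SymPy evaluate the Subs. The exact printed form of the Subs may differ between SymPy versions.

```python
import sympy as sp

x = sp.symbols("x")
f = sp.Function("f")  # hypothetical undefined function

# Chain rule with an undefined outer function: d/dx f(x**2)
expr = f(x**2).diff(x)
print(expr)  # a product of 2*x and a Subs(Derivative(f(_xi_1), _xi_1), ...)

# The Subs object is f'(xi) evaluated at xi = x**2
subs_part = expr.atoms(sp.Subs).pop()
print(subs_part.point)  # (x**2,)

# Replace f with a concrete function and evaluate the Subs
concrete = expr.replace(f, sp.sin).doit()
print(concrete)  # equivalent to 2*x*cos(x**2)
```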

It is perhaps clearer to see what it is if you use pprint() on the output. Typeset, it reads:

$$1.0\, r_0\, m_2\left(r_0^2 + r_1^2 + r_2^2\right)^{-0.5} \left.\frac{d}{d\xi_1} m(\xi_1)\right|_{\xi_1 = m_2\left(r_0^2 + r_1^2 + r_2^2\right)^{0.5}} \left.\frac{d}{d\xi_1} m_2(\xi_1)\right|_{\xi_1 = r_0^2 + r_1^2 + r_2^2}$$
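To tie this back to the original question, here is a sketch (assuming a recent SymPy) that rebuilds the same expression and then replaces both undefined functions with the identity. That replacement is an assumption matching the intent stated in the question, where m2 stands for r0**2 + r1**2 + r2**2 itself and m for m2**0.5. Once the Subs objects can be evaluated, the result should match r0/sqrt(r0**2 + r1**2 + r2**2). The names `M2`, `M`, `identity`, `evaluated` and `value` are hypothetical.

```python
import sympy as sp

r0, r1, r2 = sp.symbols("r0 r1 r2", real=True)
M2, M = sp.symbols("m2 m", cls=sp.Function)

# Same construction as in the original post
m2 = M2(r0**2 + r1**2 + r2**2)
m = M(m2**0.5)
dmr0 = (sp.Derivative(m, m2) * sp.Derivative(m2, r0)).doit()

# The unevaluated pieces: m'(...) and m2'(...) wrapped in Subs
print(dmr0.atoms(sp.Subs))

# Hypothetical choice: treat both undefined functions as the identity
def identity(arg):
    return arg

evaluated = dmr0.replace(M2, identity).replace(M, identity).doit()
print(evaluated)  # should be equivalent to r0/sqrt(r0**2 + r1**2 + r2**2)

# Numerical spot check at r0=1, r1=2, r2=3 against 1/sqrt(14)
value = float(evaluated.subs({r0: 1, r1: 2, r2: 3}))
print(value, 1 / 14**0.5)
```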

Aaron Meurer


Chris Smith

Feb 20, 2018, 9:02:02 AM
to sympy
Did you mean this?

>>> from sympy import var, sqrt
>>> var('r:4 m2')
(r0, r1, r2, r3, m2)
>>> m = sqrt(m2)
>>> _m2 = r0**2 + r1**2 + r2**2
>>> m.diff(m2)*_m2.diff(r0)
r0/sqrt(m2)
>>> (m.diff(m2)*_m2.diff(r0)).xreplace({m2: _m2})
r0/sqrt(r0**2 + r1**2 + r2**2)
Without the chain rule, this can be computed directly as:

>>> sqrt(r0**2 + r1**2 + r2**2).diff(r0)
r0/sqrt(r0**2 + r1**2 + r2**2)

Chris Smith

Feb 20, 2018, 9:04:30 AM
to sympy
Hmmm...try again:


>>> from sympy import var, sqrt
>>> var('r:4 m2')
(r0, r1, r2, r3, m2)
>>> m = sqrt(m2)
>>> _m2 = r0**2 + r1**2 + r2**2
>>> m.diff(m2)*_m2.diff(r0)
r0/sqrt(m2)
>>> (m.diff(m2)*_m2.diff(r0)).xreplace({m2: _m2})
r0/sqrt(r0**2 + r1**2 + r2**2)
>>> sqrt(r0**2 + r1**2 + r2**2).diff(r0)
r0/sqrt(r0**2 + r1**2 + r2**2)


