Using the hint and the functions from @Davide_sd, I wrote a generic method that lets me control fairly easily how the sub-steps are split up. Basically, I'm manually deriving the functions I split off but, much like cse, keeping the results in a dictionary so they are shared among all occurrences.
The base expressions that make up the calculation are passed in and never modified; the derivation list is seeded with what you want to derive (multiple expressions are OK), and the method recursively derives them, looking up the expression list as required.
At the end, I can still use cse to a) bring the result into that format should you require it, and b) factor out even more common occurrences.
It works decently well with my small example; I may update it as I add more complexity to the function I need derived.
from sympy import *
def find_derivatives(expression):
    """Collect the unevaluated ``Derivative`` nodes contained in ``expression``.

    Args:
        expression: A SymPy object. ``Derivative`` instances, other ``Basic``
            instances and ``MatrixBase`` instances are handled; anything else
            contributes nothing.

    Returns:
        A list of the ``Derivative`` nodes found, in traversal order.
        Duplicates are kept. A ``Derivative`` is reported as a whole; its own
        arguments are not searched.
    """
    if isinstance(expression, Derivative):
        return [expression]

    found = []
    if isinstance(expression, Basic):
        for arg in expression.args:
            found.extend(find_derivatives(arg))
    elif isinstance(expression, MatrixBase):
        # Reached only for matrix types that are not ``Basic`` (presumably the
        # mutable ``Matrix``); walk the entries by index.
        for i in range(expression.rows):
            for j in range(expression.cols):
                found.extend(find_derivatives(expression[i, j]))
    return found
def derive_recursively(expression_list, derive_done, derive_todo):
    """Resolve the derivatives in ``derive_todo``, level by level.

    Every ``Derivative`` found in a todo expression whose differentiated
    object (``d.expr``) is a key of ``expression_list`` is evaluated against
    the mapped expression and simplified. Non-zero results are queued and
    processed by a recursive call; zero results are substituted directly into
    the expression that contained them.

    Args:
        expression_list: Mapping from placeholder to defining expression.
            Only read, never modified.
        derive_done: Mapping of entries handled so far. Updated in place and
            returned.
        derive_todo: Mapping of entries still to scan. Values may be replaced
            in place when a contained derivative evaluates to zero.

    Returns:
        ``derive_done``, extended with ``derive_todo`` and everything derived
        in deeper recursion levels.
    """
    newly_derived = {}
    for s, e in derive_todo.items():
        print(f"Handling derivatives in {e!s}")
        for d in find_derivatives(e):
            # Skip anything already queued, being handled, or finished.
            # Checking derive_done avoids re-deriving (and re-simplifying) a
            # derivative that an earlier recursion level already resolved.
            if (
                d in newly_derived
                or d in derive_todo
                or d in derive_done
                or d in expression_list
            ):
                continue
            if d.expr not in expression_list:
                print(f"Did NOT find base expression {d.expr!s} in provided expression list!")
                continue

            expression = expression_list[d.expr]
            print(f" Deriving {d.expr!s} w.r.t. {d.variables!s}")
            print(f" Expression: {expression!s}")
            derivative = Derivative(expression, *d.variable_count).doit().simplify()
            print(f" Derivative: {derivative!s}")
            if derivative == 0:
                # A vanishing derivative needs no entry of its own; fold the
                # zero into the expression that referenced it. Rebinding an
                # existing key does not resize the dict, so this is safe
                # while iterating over it.
                e = e.subs(d, 0)
                derive_todo[s] = e
                print(f" Replacing main expression with: {e!s}")
                continue
            newly_derived[d] = derivative

    derive_done |= derive_todo
    if not newly_derived:
        return derive_done
    return derive_recursively(expression_list, derive_done, newly_derived)
# Example: derivative of a pose quaternion w.r.t. an incremental rotation.

# Incremental rotation vector (aX, aY, aZ) as symbols and as a 3x1 matrix.
incRot_c = symbols('aX aY aZ')
incRot_s = Matrix(3, 1, incRot_c)

# theta_s is an opaque placeholder function; theta_e is its definition,
# the Euclidean norm of the rotation vector.
theta_s = Function("theta")(*incRot_c)
theta_e = sqrt((incRot_s.T @ incRot_s)[0, 0])

# Incremental quaternion: placeholder components iW, iX, iY, iZ and the
# defining expression built from the axis incRot_s/theta_s and angle 2*theta_s.
incQuat_c = [Function(f"i{i}")(*incRot_c) for i in "WXYZ"]
incQuat_s = Quaternion(*incQuat_c)
incQuat_e = Quaternion.from_axis_angle(incRot_s / theta_s, theta_s * 2)

# Base quaternion symbols. Quaternion(a, b, c, d) treats its first argument
# as the scalar part, so qW has to be passed first to match the WXYZ order
# used for incQuat_c and poseQuat_c.
baseQuat_c = symbols('qX qY qZ qW')
baseQuat_s = Quaternion(baseQuat_c[3], *baseQuat_c[:3])

# Pose quaternion: placeholder components pW, pX, pY, pZ and the defining
# product of incremental and base quaternion.
poseQuat_c = [Function(f"p{i}")(*incRot_c, *baseQuat_c) for i in "WXYZ"]
poseQuat_s = Quaternion(*poseQuat_c)
# Could also do it like this and in expressions just refer poseQuat_s to poseQuat_e, but output is less readable
#poseQuat_s = Function(f"pq")(*incRot_c, *baseQuat_c)
poseQuat_e = incQuat_s * baseQuat_s

# Map every placeholder to the expression that defines it.
expressions = (
    {theta_s: theta_e}
    | {incQuat_c[i]: incQuat_e.to_Matrix()[i] for i in range(4)}
    | {poseQuat_c[i]: poseQuat_e.to_Matrix()[i] for i in range(4)}
)

# Seed the derivation with d(poseQuat)/d(aX) and resolve it recursively.
derivatives = derive_recursively(expressions, {}, {symbols('res'): diff(poseQuat_s, incRot_c[0])})
print(derivatives)

# Run common-subexpression elimination over base expressions and derivatives.
elements = cse(list(expressions.values()) + list(derivatives.values()))
pprint(elements)