Based on ti7's very helpful answer, I've slightly modified their solution so that the replacement `Pow(a, b) -> a` happens for any symbol of type `Binary`, regardless of whether `a` is an element of an `IndexedBase` or just a plain `Symbol`. After all, the identity `x**n == x` holds for either (for any positive exponent `n`, since a binary variable is only ever 0 or 1).
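The identity is a two-case check, assuming "binary" means the variable only takes the values 0 and 1:

$$
x \in \{0, 1\},\ n > 0 \quad\Longrightarrow\quad x^n = x, \qquad \text{since } 0^n = 0 \text{ and } 1^n = 1.
$$

It breaks for $n \le 0$: $0^0 = 1$ under the usual convention (SymPy also returns 1), and $0^{-1}$ is undefined (SymPy gives `zoo`).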
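```python
from sympy import *

class Binary(Symbol):
    '''Empty class for tagging variables as binary'''
    pass

def simplify_binary_powers(expr):
    '''
    Remove exponents of binary variables by replacing Pow(a, b) with a,
    where a is a Binary symbol (or an element of an IndexedBase labelled
    by one) and b is a positive number.
    '''
    # Base: a Binary symbol, or an Indexed element whose base label is Binary
    a = Wild("a", properties=[lambda a: isinstance(a, Binary) or
                              (isinstance(a, Indexed) and isinstance(a.base.label, Binary))])
    # Exponent: a positive number (x**b == x for x in {0, 1} only when b > 0)
    b = Wild("b", properties=[lambda b: isinstance(b, Number) and b.is_positive])
    # replace() passes the matched Wilds to the callable as keyword arguments a, b
    return expr.replace(Pow(a, b), lambda a, b: a)
```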
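The predicate on the base `Wild` decides what counts as a binary variable, and it pays to keep it tight: a predicate such as `a.atoms(Binary)` accepts any expression that merely *contains* a `Binary` atom, so `Pow(a, b)` would also match `(y + 1)**2` and rewrite it to `y + 1`, even though `(y + 1)**2 = 3*y + 1` for binary `y`. The sketch below compares that loose predicate with a tighter one using `match()`. The helper `is_binary_var` is a hypothetical name, and the tight check assumes the `Binary` label is kept as-is inside `IndexedBase` (the `IndexedBase` example in this answer relies on that too).

```python
from sympy import Indexed, IndexedBase, Pow, Symbol, Wild


class Binary(Symbol):
    '''Empty class for tagging variables as binary (as in the answer).'''
    pass


def is_binary_var(e):
    '''Hypothetical tight predicate: a Binary symbol, or an element of an
    IndexedBase whose label is a Binary symbol.'''
    return isinstance(e, Binary) or (
        isinstance(e, Indexed) and isinstance(e.base.label, Binary))


# Loose predicate: accepts any expression that merely contains a Binary atom.
loose = Wild('loose', properties=[lambda e: bool(e.atoms(Binary))])
# Tight predicate: accepts only the binary variable itself.
tight = Wild('tight', properties=[is_binary_var])
# Exponent: any numeric literal.
n = Wild('n', properties=[lambda e: e.is_Number])

y = Binary('y')
x = IndexedBase(Binary('x'))
compound = (y + 1)**2  # equals 3*y + 1 for y in {0, 1}, not y + 1

m = compound.match(Pow(loose, n))
print(m[loose], m[n])                            # y + 1 2  (would be rewritten to y + 1)
print(compound.match(Pow(tight, n)))             # None     (left alone)
print((x[0, 0]**3).match(Pow(tight, n))[tight])  # x[0, 0]
```

The tight version still handles both cases from the question, a plain `Binary` symbol and an `Indexed` element, without touching compound bases.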
It works for `IndexedBase`:
```python
x = IndexedBase(Binary("x"), integer=True)
expr = x[0, 0]**3 + x[1, 0]**2 + x[2, 0]
print(simplify_binary_powers(expr))
```
Output: x[0, 0] + x[1, 0] + x[2, 0]
and it also works for a single `Binary` variable:
```python
y = Binary('y')
expr2 = y**3
print(simplify_binary_powers(expr2))
```
Output: y
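Because a pattern like this is easy to make too broad, one cheap sanity check on small expressions is to compare the expression before and after simplification on every 0/1 assignment of its binary variables. The sketch below does that by brute force. `agrees_on_binary` is a hypothetical helper, and it takes the variables explicitly rather than discovering them, so that the `Binary` label inside an `IndexedBase` is never substituted by mistake.

```python
from itertools import product

from sympy import IndexedBase, Symbol


class Binary(Symbol):
    '''Empty class for tagging variables as binary (as in the answer).'''
    pass


def agrees_on_binary(before, after, variables):
    '''Hypothetical helper: True if `before` and `after` take the same value
    for every 0/1 assignment of `variables` (exhaustive, so keep it small).'''
    for values in product((0, 1), repeat=len(variables)):
        assignment = dict(zip(variables, values))
        if before.subs(assignment) != after.subs(assignment):
            return False
    return True


y = Binary('y')
x = IndexedBase(Binary('x'))

# Dropping exponents of binary variables preserves every 0/1 value ...
print(agrees_on_binary(x[0, 0]**3 + x[1, 0]**2 + y**5,
                       x[0, 0] + x[1, 0] + y,
                       [x[0, 0], x[1, 0], y]))   # True
# ... but dropping the exponent of a compound base like (y + 1) does not.
print(agrees_on_binary((y + 1)**2, y + 1, [y]))  # False
```

With `simplify_binary_powers` from above, `agrees_on_binary(expr, simplify_binary_powers(expr), [x[0, 0], x[1, 0], x[2, 0]])` should return `True` for the `IndexedBase` example. The check enumerates 2^k assignments for k variables, so it only suits small test expressions.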