Expand a function of an arbitrary number of arguments if it is linear in each argument


Is there a way to write a replacement rule for a function f with an arbitrary number of arguments that makes it linear in all its arguments? An example for when f has three arguments:

  1. f(x1 + x4, x2, x3) = f(x4, x2, x3) + f(x1, x2, x3)
  2. f(x1, x2 + x4, x3) = f(x1, x2, x3) + f(x1, x4, x3)
  3. f(x1, x2, x3 + x4) = f(x1, x2, x3) + f(x1, x2, x4)
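Written for a general number of arguments n, the three rules above are a single additivity rule applied at each argument position i (u and v are placeholder names, not from the question). Applying it repeatedly shows that if argument i is a sum of m_i terms, f expands into m_1 m_2 ... m_n terms, one for each way of picking a summand from every argument:

```latex
f(x_1,\dots,x_{i-1},\,u+v,\,x_{i+1},\dots,x_n)
  = f(x_1,\dots,x_{i-1},\,u,\,x_{i+1},\dots,x_n)
  + f(x_1,\dots,x_{i-1},\,v,\,x_{i+1},\dots,x_n),
  \qquad i = 1,\dots,n

f\Big(\sum_{k_1=1}^{m_1} t_{1,k_1},\,\dots,\,\sum_{k_n=1}^{m_n} t_{n,k_n}\Big)
  = \sum_{k_1=1}^{m_1}\cdots\sum_{k_n=1}^{m_n} f(t_{1,k_1},\dots,t_{n,k_n})
```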

Using "Wild" works partially:

from sympy import Add, Function, Wild, symbols

f = Function('f')
x1, x2, x3, x4 = symbols("x1:5")
a = Wild("a")
b = Wild("b")
# exclude=[0] stops A or B from matching 0, so A + B must split a genuine sum
A = Wild('A', exclude=[0])
B = Wild('B', exclude=[0])

print("This one works")
expr = f(x1, x2 + x4, x3)
print(expr, '->', expr.replace(f(a, Add(A, B), b), f(a, A, b) + f(a, B, b)))
# f(x1, x2 + x4, x3) -> f(x1, x2, x3) + f(x1, x4, x3)
print("This one doesn't on the last entry")
expr = f(x1, x2, x3 + x4)
print(expr, '->', expr.replace(f(a, Add(A, B), b), f(a, A, b) + f(a, B, b)))
# f(x1, x2, x3 + x4) -> f(x1, x2, x3 + x4)
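One way to push the same Wild idea through every slot is to build one pattern per argument position, with plain Wilds filling the other positions, and reapply all of them until the expression stops changing. This is only a sketch: the helper name split_sums and its nargs/max_passes parameters are hypothetical, the caller has to supply the number of arguments, and splitting sums of more than two terms depends on how SymPy's Add matching assigns terms to A and B, so the loop is capped rather than trusted to converge.

```python
from sympy import Add, Function, Wild, symbols

f = Function('f')
x1, x2, x3, x4 = symbols("x1:5")


def split_sums(expr, func, nargs, max_passes=50):
    """Split sums inside func(...) at every argument position (hypothetical helper).

    Builds one pattern per position i, func(w0, ..., A + B, ..., w_{nargs-2}),
    and reapplies all patterns until nothing changes or max_passes is reached.
    """
    A = Wild('A', exclude=[0])
    B = Wild('B', exclude=[0])
    # nargs - 1 plain Wilds fill the positions that are not being split
    others = [Wild('w%d' % k) for k in range(nargs - 1)]
    for _ in range(max_passes):
        new = expr
        for i in range(nargs):
            before, after = others[:i], others[i:]
            pattern = func(*before, Add(A, B), *after)
            template = func(*before, A, *after) + func(*before, B, *after)
            new = new.replace(pattern, template)
        if new == expr:
            break
        expr = new
    return expr
```

With this, split_sums(f(x1, x2, x3 + x4), f, 3) should give f(x1, x2, x3) + f(x1, x2, x4), the case the single pattern missed. It still needs nargs up front, which is the limitation the Mathematica sequence patterns avoid.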

I know I could iterate in a variety of ways over the arguments of the function while altering the replacement, but I was hoping the functionality was built into "Wild" or "replace" already. Mathematica, for example, has "wildcards" like "a___,b___,A___,B___" which mean that "a___" could be an empty sequence, or a single argument, or a sequence of multiple arguments. For example, in Mathematica,

expr /. f[a___, A_Plus, b___] :> f[a, A[[1]], b] + f[a, A[[2;;]], b]

would correctly expand both test cases, and would work for an f with any number of arguments.

Is there something similar, or is this as close as SymPy gets?

Alternatively, might this be possible to do with argument unpacking on a recursive definition starting from something like def f(*args):?
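A recursive definition with argument unpacking should be possible if f is a SymPy Function subclass whose eval classmethod takes *args: SymPy calls eval every time the function is constructed, and (as far as I know) a variadic eval lets the class accept any number of arguments. The sketch below assumes a reasonably recent SymPy; the class name LinearF is hypothetical. It peels the first summand off the first argument that is a sum and recurses through the constructor until no argument is an Add, so the expansion happens automatically.

```python
from sympy import Add, Function, symbols

x1, x2, x3, x4 = symbols("x1:5")


class LinearF(Function):
    """Hypothetical function that splits sums in any argument when it is built.

    Because eval takes *args, SymPy accepts any number of arguments.
    """

    @classmethod
    def eval(cls, *args):
        for i, arg in enumerate(args):
            if arg.is_Add:
                # Split off the first summand; both halves recurse via cls(...)
                first, rest = arg.args[0], Add(*arg.args[1:])
                before, after = args[:i], args[i + 1:]
                return cls(*before, first, *after) + cls(*before, rest, *after)
        # No argument is a sum: returning None leaves LinearF(...) unevaluated
        return None
```

For example, LinearF(x1, x2, x3 + x4) evaluates to LinearF(x1, x2, x3) + LinearF(x1, x2, x4) as soon as it is created. The trade-off is that the rule is baked into the function, so it can no longer be switched off the way an explicit replace call can.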

Answer by AudioBubble:

Instead of Wild matching, I would detect which arguments of f are sums (Add) and expand those using itertools.product:

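import itertools
from sympy import Add
# One tuple of summands per argument of expr (a non-Add argument gives a 1-tuple)
term_groups = [term.args if term.func is Add else (term,) for term in expr.args]
# Pick one summand from each argument in every possible way and add up the results
expanded = Add(*[expr.func(*args) for args in itertools.product(*term_groups)])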

For example, if expr is f(x1+x2+x4, x2+x4, x3*x1), then term_groups is [(x1, x2, x4), (x2, x4), (x1*x3,)] where the last argument yields a 1-element tuple since it's not an Add. And expanded is

f(x1, x2, x1*x3) + f(x1, x4, x1*x3) + f(x2, x2, x1*x3) + f(x2, x4, x1*x3) + f(x4, x2, x1*x3) + f(x4, x4, x1*x3)
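The two lines above expand a single f(...) expression. A sketch of how they might be wrapped so that every f inside a larger expression gets expanded, using Expr.replace with a predicate and a callable; the helper names expand_multilinear and expand_f are hypothetical, and Add.make_args is used as shorthand for the `term.func is Add` test (it returns term.args for a sum and (term,) otherwise).

```python
import itertools

from sympy import Add, Function, sin, symbols

f = Function('f')
x1, x2, x3, x4 = symbols("x1:5")


def expand_multilinear(expr):
    """Expand one applied function over the sums in its arguments."""
    term_groups = [Add.make_args(term) for term in expr.args]
    return Add(*[expr.func(*args) for args in itertools.product(*term_groups)])


def expand_f(expr, func):
    """Apply expand_multilinear to every func(...) found anywhere in expr."""
    # replace() walks the tree bottom-up and calls the value function on each match
    return expr.replace(lambda e: isinstance(e, func), expand_multilinear)
```

For instance, expand_f(sin(f(x1 + x2, x3)) + x4, f) should give sin(f(x1, x3) + f(x2, x3)) + x4, leaving everything that is not an f untouched. Because the argument count is read from expr.args, no pattern has to be written per arity.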