numpy einsum for y = (x*A)@A.T

np.einsum seems nice, but I find it hard to use on more complicated expressions. What would be the equivalent einsum expression here?

import numpy as np
x = np.array([1.,2.,3.])
A = np.array([[3.,4.,5],
               [2.,2.,2.]])
y = (x*A)@A.T  # scale column i of A by x[i], then multiply by A.T

# y2 = np.einsum('what to write here', x, A)
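A minimal sketch of a single-call version, obtained by writing the product out elementwise: (x*A)[j,i] = x[i]*A[j,i], and multiplying by A.T sums over i against A[k,i], so y[j,k] = sum over i of x[i]*A[j,i]*A[k,i]. Each factor then gets its own labels: i for x, ji for the first A, ki for the second A (passed a second time), and jk for the output. The label i is summed because it does not appear after '->'.

```python
import numpy as np

x = np.array([1., 2., 3.])
A = np.array([[3., 4., 5.],
              [2., 2., 2.]])

# Reference result from the original expression
y = (x * A) @ A.T

# y2[j, k] = sum_i x[i] * A[j, i] * A[k, i]
# A is passed twice: once labelled 'ji', once labelled 'ki'
y2 = np.einsum('i,ji,ki->jk', x, A, A)

# Both give [[116, 52], [52, 24]] for this input
assert np.allclose(y, y2)
```

Note that the second operand is A itself, not A.T: the transpose is expressed by the label order 'ki' instead of by an array operation.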

I can recreate each basic step as below, but I struggle to do everything in one np.einsum().

xA = np.einsum('i, ji->ji', x, A)  # x*A
At = np.einsum('ij->ji', A)  # A.T
y2 = np.einsum('ij, jk', xA, At)  # xA @ At
assert np.allclose(y, y2)  # passes; allclose avoids spurious failures from float rounding
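The three step-by-step calls can be fused by keeping the labels consistent across them. In the third call the letters are reused with a different meaning: xA is labelled ij and At is labelled jk, so j is contracted and, in implicit mode, the output is ik (labels that appear once, in alphabetical order). Renaming to match the first call, xA[j,i] = x[i]*A[j,i] and At[i,k] = A[k,i], gives either form below. This sketch checks both against the original expression.

```python
import numpy as np

x = np.array([1., 2., 3.])
A = np.array([[3., 4., 5.],
              [2., 2., 2.]])
y = (x * A) @ A.T

# Step 1 ('i,ji->ji') gives xA[j, i] = x[i] * A[j, i]
# Step 3 contracts i between xA[j, i] and At[i, k]

# Form 1: pass the transpose explicitly, keeping At's labels 'ik'
y_t = np.einsum('i,ji,ik->jk', x, A, A.T)

# Form 2: skip the transpose; At[i, k] is A[k, i], so label A as 'ki'
y_no_t = np.einsum('i,ji,ki->jk', x, A, A)

assert np.allclose(y, y_t)
assert np.allclose(y, y_no_t)
```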

Note that in the example above A is a 2x3 matrix, but in my actual code it can be a 1000x3 matrix or larger.
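For A of shape (n, 3), y has shape (n, n), so the result itself grows quadratically: n = 1000 gives 10^6 float64 entries (8 MB), while n = 100000 would need about 80 GB. np.einsum takes an optimize argument (default False, which evaluates the whole sum in one loop); with optimize=True it may split the expression into pairwise contractions and dispatch some of them to BLAS-backed routines. np.einsum_path reports the contraction order it would pick. The sketch below uses a random 1000x3 A as hypothetical stand-in data and only checks that the variants agree; it makes no claim about which is fastest, so timing on the real shapes is still worth doing.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
x = rng.random(3)
A = rng.random((n, 3))  # hypothetical stand-in for the real 1000x3 matrix

# Original expression: broadcasting, then a BLAS-backed matmul
y_ref = (x * A) @ A.T

# Single einsum call; optimize=True lets NumPy choose a contraction order
y_opt = np.einsum('i,ji,ki->jk', x, A, A, optimize=True)

# Inspect the contraction order einsum would use for these operands
path, report = np.einsum_path('i,ji,ki->jk', x, A, A, optimize='greedy')
print(report)

assert y_opt.shape == (n, n)
assert np.allclose(y_ref, y_opt)
```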
