Raising an exception in a custom ufunc in a NumPy ndarray subclass


I'm implementing a NumPy ndarray subclass with a custom equal() universal function.

Here is a distilled example in which the custom equal() just calls the original np.equal:

import numpy as np

class MyArray(np.ndarray):

    def __new__(cls, data):
        return np.array(data).view(cls)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Strip the subclass from array inputs; leave scalars untouched
        inputs = [i.view(np.ndarray) if isinstance(i, np.ndarray) else i
                  for i in inputs]
        if ufunc is np.equal and method == "__call__":
            return self._custom_equal(*inputs, **kwargs)
        return super().__array_ufunc__(ufunc, method, *inputs, **kwargs)

    @staticmethod
    def _custom_equal(a, b, **kwargs):
        # Placeholder: delegate to the standard ufunc
        return np.equal(a, b, **kwargs)

This works:

>>> a = MyArray([(1,2,3), (4,5,6)])
>>> b = MyArray([(1,2,3), (4,5,6)])
>>> a == b
array([[ True,  True,  True],
       [ True,  True,  True]])

But when the arrays have different shapes, a DeprecationWarning is issued and the comparison returns False:

>>> c = MyArray([(1,2,3,4,5,6)])
>>> a == c
DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
  a == c
False
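
For contrast, calling the ufunc directly on plain arrays with these shapes does not produce False: (2, 3) and (1, 6) cannot be broadcast together, so np.equal raises ValueError. A minimal check, using plain ndarrays with no subclass involved:

```python
import numpy as np

a = np.array([(1, 2, 3), (4, 5, 6)])  # shape (2, 3)
c = np.array([(1, 2, 3, 4, 5, 6)])    # shape (1, 6)

try:
    np.equal(a, c)  # the ufunc itself, not the == operator
except ValueError as exc:
    print("np.equal raised ValueError:", exc)
```

So the False appears to come from how the == operator handles the failed ufunc call, not from np.equal itself.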

If _custom_equal() is updated to print any exception before re-raising it:

    @staticmethod
    def _custom_equal(a, b, **kwargs):
        try:
            eq = np.equal(a, b, **kwargs)
        except Exception as e:
            print(e)
            raise
        return eq

it can be seen that the exception is raised, and then the comparison is retried with the operands swapped:

>>> a == c
operands could not be broadcast together with shapes (2,3) (1,6) 
operands could not be broadcast together with shapes (1,6) (2,3) 
False
DeprecationWarning: elementwise comparison failed; this will raise an error in the future.
  a == c
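
The two messages, with the shapes in opposite order, are consistent with Python's rich-comparison protocol: if a.__eq__(c) returns NotImplemented, Python tries the reflected c.__eq__(a), and if that also returns NotImplemented, == falls back to an identity test (a is c), which is False. That ndarray.__eq__ returns NotImplemented after the failed ufunc call is an inference from the output above, not something the question shows directly. The fallback itself can be reproduced in plain Python with a hypothetical class Declines whose __eq__ always gives up:

```python
class Declines:
    """Hypothetical class whose __eq__ always declines to compare."""

    calls = []  # records which operand's __eq__ was consulted, in order

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        Declines.calls.append(self.name)
        return NotImplemented  # ask Python to try the other operand


x = Declines("x")
y = Declines("y")
result = x == y
print(Declines.calls)  # ['x', 'y']: forward call, then reflected call
print(result)          # False: both declined, so Python compared identity
```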

Checking the array shapes in _custom_equal() and explicitly raising an exception produces the same result.

Is there a way to make a == c raise the exception instead of returning False?
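
One possible workaround, sketched under the assumption that surfacing the error from == is all that is needed: override __eq__ in the subclass so that == calls the np.equal ufunc directly. The call still goes through __array_ufunc__ and _custom_equal, but not through ndarray.__eq__, so nothing converts the ValueError into NotImplemented and then False. Non-array inputs such as scalars are passed through without calling .view() on them. The warning text says this will become an error in a future NumPy anyway; this override does not depend on which version is installed.

```python
import numpy as np


class MyArray(np.ndarray):

    def __new__(cls, data):
        return np.array(data).view(cls)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Strip the subclass from array inputs; leave scalars untouched
        inputs = [i.view(np.ndarray) if isinstance(i, np.ndarray) else i
                  for i in inputs]
        if ufunc is np.equal and method == "__call__":
            return self._custom_equal(*inputs, **kwargs)
        return super().__array_ufunc__(ufunc, method, *inputs, **kwargs)

    def __eq__(self, other):
        # Call the ufunc directly instead of ndarray.__eq__, so a
        # broadcasting error propagates to the caller as ValueError
        return np.equal(self, other)

    @staticmethod
    def _custom_equal(a, b, **kwargs):
        return np.equal(a, b, **kwargs)


a = MyArray([(1, 2, 3), (4, 5, 6)])
b = MyArray([(1, 2, 3), (4, 5, 6)])
c = MyArray([(1, 2, 3, 4, 5, 6)])

print(a == b)
try:
    a == c
except ValueError as exc:
    print("ValueError:", exc)
```

A drawback of this design is that only == is affected: != and the ordering operators are still inherited from ndarray, so they would presumably need the same treatment (for example, an __ne__ that calls np.not_equal).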
