Is there a simple way to subclass python's set without redefining all operators?


Is there a way to subclass set so that the binary operators return the subclass type, without redefining each of them?

Example:

class A(set):
    pass

a = A([1,2,3]) & A([1,2,4])


a.__class__ == A # it's False, and I would like it to be true without redefining all operators

Note that the question "What is the correct (or best) way to subclass the Python set class, adding a new instance variable?" is 10 years old, and its answers are only relevant for Python 2.x. This is why I am asking again for Python 3.x (especially Python ≥ 3.8).


There are 2 answers

Answer by JonSG

I believe the answer is "No".

Skimming the reference implementation (Objects/setobject.c in CPython), every path through set_and() appears to end by returning NULL, NotImplemented (when the other operand is not a set), or the result of calling make_new_set_basetype().

That function is currently implemented as:

static PyObject *
make_new_set_basetype(PyTypeObject *type, PyObject *iterable)
{
    if (type != &PySet_Type && type != &PyFrozenSet_Type) {
        if (PyType_IsSubtype(type, &PySet_Type))
            type = &PySet_Type;
        else
            type = &PyFrozenSet_Type;
    }
    return make_new_set(type, iterable);
}
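The type choice made by make_new_set_basetype() can be mirrored in Python and checked against what the operators actually return. This is only a sketch: the helper name result_type_for is hypothetical and is not a CPython API.

```python
def result_type_for(cls):
    """Hypothetical helper mirroring the type choice in make_new_set_basetype()."""
    # Exact set/frozenset types are kept as they are.
    if cls is set or cls is frozenset:
        return cls
    # Any other type collapses to its built-in base type.
    return set if issubclass(cls, set) else frozenset


class A(set):
    pass


class F(frozenset):
    pass


print(result_type_for(A))                            # <class 'set'>
print(result_type_for(F))                            # <class 'frozenset'>
print(type(A([1]) | A([2])) is result_type_for(A))   # True
```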

Thus, without overriding __and__() in your subclass, you get back a plain set (or a plain frozenset, for frozenset subclasses).
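A small check, assuming CPython 3.8 or later and reusing the class name A from the question. The binary operators and the non-mutating methods such as union() and copy() all build a new object of the base type. In-place operators such as &= mutate the object and return it, so they keep the subclass type (this last point goes slightly beyond the answer above).

```python
class A(set):
    pass


a = A([1, 2, 3])
b = A([1, 2, 4])

# Binary operators and non-mutating methods build a new object of the base type.
for result in (a & b, a | b, a - b, a ^ b, a.union(b), a.copy()):
    print(type(result).__name__)      # set, printed six times

# In-place operators mutate `a` and return it, so its type is unchanged.
a &= b
print(type(a).__name__, sorted(a))    # A [1, 2]
```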

Answer by James

It looks like you do need to redefine the set methods, but only the ones that return a new set object. There are only 13 of them, and they are all boilerplate.

Here is an example:

class ASet(set):
    def __repr__(self):
        return f'ASet({set(self)})'

    def _wrap(self, result):
        # Pass NotImplemented through so Python can try the other operand;
        # otherwise convert the plain set built by the base class.
        if result is NotImplemented:
            return NotImplemented
        return self.__class__(result)

    def __and__(self, other):
        return self._wrap(super().__and__(other))

    def __or__(self, other):
        return self._wrap(super().__or__(other))

    def __rand__(self, other):
        return self._wrap(super().__rand__(other))

    def __ror__(self, other):
        return self._wrap(super().__ror__(other))

    def __rsub__(self, other):
        return self._wrap(super().__rsub__(other))

    def __rxor__(self, other):
        return self._wrap(super().__rxor__(other))

    def __sub__(self, other):
        return self._wrap(super().__sub__(other))

    def __xor__(self, other):
        return self._wrap(super().__xor__(other))

    # copy() takes no argument.
    def copy(self):
        return self.__class__(super().copy())

    # difference(), intersection() and union() accept any number of iterables.
    def difference(self, *others):
        return self.__class__(super().difference(*others))

    def intersection(self, *others):
        return self.__class__(super().intersection(*others))

    def symmetric_difference(self, other):
        return self.__class__(super().symmetric_difference(other))

    def union(self, *others):
        return self.__class__(super().union(*others))
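If writing the 13 wrappers by hand feels too repetitive, they can also be generated in a loop. This is a sketch and not part of the original answer: the tuple _SET_RETURNING and the helper _make_wrapper are hypothetical names, and the approach assumes every listed set method returns either a new set or NotImplemented.

```python
class ASet(set):
    def __repr__(self):
        return f'{type(self).__name__}({set(self)})'


# Methods of set that build and return a new set object.
_SET_RETURNING = (
    '__and__', '__or__', '__sub__', '__xor__',
    '__rand__', '__ror__', '__rsub__', '__rxor__',
    'copy', 'difference', 'intersection', 'symmetric_difference', 'union',
)


def _make_wrapper(name):
    base_method = getattr(set, name)

    def wrapper(self, *args):
        result = base_method(self, *args)
        # Let Python fall back to the other operand when unsupported.
        if result is NotImplemented:
            return NotImplemented
        # Re-wrap the plain set in the subclass type.
        return type(self)(result)

    wrapper.__name__ = name
    wrapper.__doc__ = base_method.__doc__
    return wrapper


# Assigning to the class after creation also updates the operator slots.
for _name in _SET_RETURNING:
    setattr(ASet, _name, _make_wrapper(_name))


a = ASet([1, 2, 3]) & ASet([1, 2, 4])
print(a)                                # ASet({1, 2})
print(type(a.union([5])).__name__)      # ASet
```

The trade-off is brevity versus explicitness: the hand-written class above is easier to read and to type-check, while the loop keeps the list of wrapped methods in one place.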