How to distinguish between Base Class and Derived Class in generics using Typing and Mypy

Consider the following code:

from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


T = TypeVar("T", A, B)


def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A())  # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A(), x2=B())  # Will throw TypeError, how can I get mypy to say this is an error?

Mypy is not seeing any problem here. It seems like it is always interpreting the passed object as a base class object of type A.

Is there a way to make the generic even more strict in the sense that it is sensitive to the exact class type? Such that if x1 is of type B, then also x2 must be exactly of type B? If x1 is of type A then also x2 must be exactly of type A?

There is 1 answer

Mark On BEST ANSWER

This was a fun question - at first I considered solving it in the following way:

from typing import overload
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


@overload
def fun(x1: B, x2: B) -> int:
    ...


@overload
def fun(x1: A, x2: A) -> int:
    ...


def fun(
    x1: A | B,
    x2: A | B,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


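fun(x1=A(), x2=A())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A())  # Raises TypeError at runtime, but mypy still accepts it
fun(x1=A(), x2=B())  # Raises TypeError at runtime, but mypy still accepts it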

At first I thought this might be a quirk of how TypeVar works, but even when the overloads state that the arguments must be A, A or B, B, mypy still does not report an error on the final two calls. Those calls simply match the A, A overload, because an instance of B is also a valid A. The type system does not distinguish between direct instances of a class and instances of its subclasses. You could enforce a structural type with a Protocol, so long as there was a structural difference between A and B.

Even if you made the function argument a list[A], instances of B would still be valid in the list for the same reason.
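To make the substitutability point concrete, here is a minimal sketch. The helper `count_items` and the variable names are made up for illustration; only A and B come from the question. It shows that a B is accepted wherever an A is expected, and that only an exact runtime check such as `type(...) is A` tells the two apart:

```python
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


def count_items(items: list[A]) -> int:
    # Hypothetical helper: accepts any list whose elements are A or a subclass of A.
    return len(items)


# B is a subtype of A, so a B instance is a valid element of a list[A].
mixed: list[A] = [A(), B()]

n_items = count_items(mixed)
b_is_an_a = isinstance(B(), A)    # subclass instances count as A
b_is_exactly_a = type(B()) is A   # only the exact runtime class differs
```

Static annotations work like the `isinstance` check here, not like the `type(...) is` check, which is why neither the TypeVar nor the overloads can reject the mixed calls.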

If you're trying to get B to pick up all the attributes of A while keeping the two as distinct types, I would instead do it this way, with A now being a hidden base class and A2 the class exposed to the end user:

from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass

@dataclasses.dataclass
class A2(A):
    # Note, this class would be empty in practice as well
    pass



T = TypeVar("T", A2, B)

def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A2:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A2(), x2=A2())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A2())  # Raises TypeError at runtime; mypy now reports an error here
fun(x1=A2(), x2=B())  # Raises TypeError at runtime; mypy now reports an error here
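As an optional variation on the function body, the if/elif chain could be swapped for a lookup keyed by the exact runtime class. This is only a sketch: `fun_exact` and `_RESULTS` are names invented here, and the return values 5 and 10 are taken from the example above.

```python
import dataclasses
from typing import TypeVar


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class A2(A):
    pass


@dataclasses.dataclass
class B(A):
    pass


# Sibling classes: neither is a subtype of the other, so T cannot mix them.
T = TypeVar("T", A2, B)

# Hypothetical lookup table keyed by the exact runtime class.
_RESULTS: dict[type, int] = {A2: 5, B: 10}


def fun_exact(x1: T, x2: T) -> int:
    # Compare classes by identity so subclasses never match by accident.
    if type(x1) is not type(x2):
        raise TypeError("must be same type!")
    try:
        return _RESULTS[type(x1)]
    except KeyError:
        raise TypeError("Type not handled") from None
```

The runtime checks stay in place as a safety net, for example for callers that do not run a type checker.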

Hope this is useful!