How to correctly type-hint different callback functions for a class

122 views Asked by At

I'm trying to implement a class that would be able to use two types of callbacks:

  • a callback with an integer parameter named exactly as I want (to force API users to use a specific parameter name when implementing these callbacks)
  • a callback without any parameters.

The needed callback will be chosen by my class based on criteria passed to one of the instance methods.

For this purpose, my first idea was to use protocols. Here is what I came up with:

import inspect
from typing import (
    NotRequired,
    Protocol,
    TypedDict,
    TypeGuard,
    Unpack,
    cast,
    runtime_checkable,
)


class ActionPerformer(TypedDict):
    action_performer_id: NotRequired[int]


@runtime_checkable
class PerformableAction(Protocol):
    def __call__(self, action_performer_id: int) -> list[str]: ...


class NonPerformableAction(Protocol):
    def __call__(self) -> list[str]: ...


class MyClass:
    def __init__(self, data_loader_callback: PerformableAction | NonPerformableAction):
        self._data_loader_callback = data_loader_callback

    def run(self, **kwargs: Unpack[ActionPerformer]):
        data = self._load_data(**kwargs)
        print(f"{data=}")

    def _load_data(self, **kwargs: Unpack[ActionPerformer]) -> list[str]:
        if isinstance(self._data_loader_callback, PerformableAction):
            print("calling PerformableAction...")
            return self._data_loader_callback(**kwargs)
        else:
            print("calling NonPerformableAction...")
            return self._data_loader_callback()


def parametrized_data_loader(action_performer_id: int) -> list[str]:
    """Example callback that needs the performer id; logs it and returns canned data."""
    message = f"Parametrized call with param {action_performer_id=}"
    print(message)
    return "data from parametrized call".split()


def non_parametrized_data_loader() -> list[str]:
    """Example callback that takes no arguments; logs the call and returns canned data."""
    print("Non-parametrized call")
    words = ("data", "from", "non-parametrized", "call")
    return list(words)


# Demo 1: callback that takes ``action_performer_id``; the id is passed to run().
my_class_instance_1 = MyClass(parametrized_data_loader)
my_class_instance_1.run(action_performer_id=123)

print()  # blank line between the two demos

# Demo 2: zero-argument callback; run() is called without an id.
my_class_instance_2 = MyClass(non_parametrized_data_loader)
my_class_instance_2.run()

It looks like I've covered everything:

  • I specified my intention to use different types of callbacks using protocols
  • data loader callback type hinted using these protocols
  • I pass an optional kwarg to the callback
  • I try to satisfy static type checkers (mypy/pyright) by using a protocol instance check and calling the appropriate callback form for each situation

In practice, though, the first branch is always taken at runtime:

calling PerformableAction...
Parametrized call with param action_performer_id=123
data=['data', 'from', 'parametrized', 'call']

calling PerformableAction...
Non-parametrized call
data=['data', 'from', 'non-parametrized', 'call']

so this part seems to be a necessary hack just to keep static type checkers happy:

        if isinstance(self._data_loader_callback, PerformableAction):
            print("calling PerformableAction...")
            return self._data_loader_callback(**kwargs)
        else:
            print("calling NonPerformableAction...")
            return self._data_loader_callback()

If I replace this check with anything else, e.g.:

return self._data_loader_callback(**kwargs)  # this is the 26th line

then static type checkers start complaining:

mypy proto.py --enable-incomplete-feature=Unpack
proto.py: note: In member "_load_data" of class "MyClass":
proto.py:26: error: Extra argument "action_performer_id" from **args for "__call__" of "NonPerformableAction"  [misc]
Found 1 error in 1 file (checked 1 source file)

pyright proto.py
proto.py
  proto.py:26:45 - error: Unable to match unpacked TypedDict argument to parameters
    No parameter named "action_performer_id" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations

I believe there is a simple and straightforward solution for this, but I still can't figure it out. Can you help me, please?

1

There is 1 answer

0
dbzix On

For now, I have a much more useful solution:

from typing import (
    Callable,
    Optional,
    TypeAlias,
    overload,
)

PerformableAction: TypeAlias = Callable[[int], list[str]]
NonPerformableAction: TypeAlias = Callable[[], list[str]]


class MyClass:
    def __init__(self, data_loader_callback: PerformableAction | NonPerformableAction):
        self._data_loader_callback = data_loader_callback

    @overload
    def run(self) -> None: ...

    @overload
    def run(self, action_performer_id: int) -> None: ...

    def run(self, action_performer_id: Optional[int] = None) -> None:
        data = self._load_data(action_performer_id)
        print(f"{data=}")

    def _load_data(self, action_performer_id: Optional[int] = None) -> list[str]:
        if action_performer_id is not None:
            print(f"calling ParametrizedAction with param {action_performer_id}")
            return self._data_loader_callback(action_performer_id)  # type: ignore
        else:
            print("calling NonParametrizedAction...")
            return self._data_loader_callback()  # type: ignore


def parametrized_data_loader(action_performer_id: int) -> list[str]:
    """Sample parametrized callback: logs its argument and returns fixed words."""
    print(f"Parametrized call with param {action_performer_id=}")
    result = "data,from,parametrized,call".split(",")
    return result


def non_parametrized_data_loader() -> list[str]:
    """Sample zero-argument callback: logs the call and returns fixed words."""
    words = "data from non-parametrized call".split()
    print("Non-parametrized call")
    return words


# Demo 1: callback that takes the id; run() is given it by keyword.
my_class_instance_1 = MyClass(parametrized_data_loader)
my_class_instance_1.run(action_performer_id=123)

print()  # separate the two demos with a blank line

# Demo 2: zero-argument callback; run() is called with no id.
my_class_instance_2 = MyClass(non_parametrized_data_loader)
my_class_instance_2.run()

Pros:

  • The API user sees the possible signatures of the callback functions
  • The API user knows the possible ways to call run(...)
  • I expose the desired parameter name in run(...)

Cons:

  • I can't force the user to use the desired parameter name in the callback that takes a parameter
  • I have to silence type checkers with # type: ignore in places where they can't infer whether the parameter is needed.