Type checking an optional attribute whose value is related to that of another attribute

I have a function which computes a Result and this computation can either be successful or not. In case it was successful, some data that summarizes the result of the computation will be returned as well. In case it was unsuccessful, this data will be None. Now the problem is that even though I verify the status of the computation (success), the type checker (mypy) cannot infer the coupling between success and data. This is summarized by the following code:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Result:
    success: bool
    data: Optional[int]  # This is not None if `success` is True.


def compute(inputs: str) -> Result:
    if inputs.startswith('!'):  # Oops, some condition that prevents the computation.
        return Result(success=False, data=None)
    return Result(success=True, data=len(inputs))


def check(inputs: str) -> bool:
    return (result := compute(inputs)).success and result.data > 2


assert check('123')
assert not check('12')
assert not check('!123')

Running mypy against this code gives the following error:

test.py:18: error: Unsupported operand types for < ("int" and "None")  [operator]
test.py:18: note: Left operand is of type "Optional[int]"

I considered the following solutions, but I'm not really happy with either of them. So, I'm wondering if there's a better way to solve this.

typing.cast

The function check could be modified to use cast(int, result.data) to enforce the logical relationship between success and data. However, having to resort to cast feels like a sign of something being wrong with the code (in this case, at least). Also, I would have to use cast each time this relationship is used in the code. It would be better to solve it in one place.
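For concreteness, a minimal sketch of the cast variant is shown below. It repeats Result and compute from above so that it runs on its own; only check changes. Note that cast performs no runtime check, so the relationship between success and data is still enforced only by convention.

```python
from dataclasses import dataclass
from typing import Optional, cast


@dataclass
class Result:
    success: bool
    data: Optional[int]  # Not None whenever `success` is True.


def compute(inputs: str) -> Result:
    if inputs.startswith('!'):  # Some condition that prevents the computation.
        return Result(success=False, data=None)
    return Result(success=True, data=len(inputs))


def check(inputs: str) -> bool:
    result = compute(inputs)
    # cast() only informs the type checker; it does nothing at runtime.
    return result.success and cast(int, result.data) > 2
```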

Check result.data is not None

In the above example, the relationship between success and data is quite simple: success == (data is not None). So I could remove the success attribute altogether and check result.data is not None instead.

def check(inputs: str) -> bool:
    return (result := compute(inputs)).data is not None and result.data > 2

While this works, the real use case is more complex and there are various data fields, e.g., data_x, data_y, and data_z. In this case, success == all(d is not None for d in [data_x, data_y, data_z]). Using this as a check is too verbose and so I would refactor it into a property of the Result class. For the above example this would be:

@dataclass
class Result:
    data: Optional[int]

    @property
    def success(self) -> bool:
        return self.data is not None

However, once the is not None check is moved into a property, mypy can no longer infer that result.data is not None when result.success is True.
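To make the problem concrete, here is a minimal sketch of a caller that relies on the property. The check function taking a Result directly is only an illustration. The code behaves correctly at runtime, but mypy still treats result.data as Optional[int] on the marked line.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Result:
    data: Optional[int]

    @property
    def success(self) -> bool:
        return self.data is not None


def check(result: Result) -> bool:
    # Fine at runtime: `success` short-circuits before the comparison.
    # For mypy, `result.data` is still Optional[int] here, because the
    # narrowing done inside the property is not visible to the caller.
    return result.success and result.data > 2
```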

1 Answer

joel (best answer):

Your initial question is extremely close to this other one, but the overall question is not just about that, so here goes.

Here's one option of several.

The fact that you have success == all(d is not None for d in [data_x, data_y, data_z]) in your general case suggests to me that what you really want is a simple Result type whose values you can compose. There's a very well established pattern for this in Haskell/Rust etc, though it's usually called Maybe or Option there. In Python (3.12+, using the built-in generic syntax) this would look like

@dataclass
class Success[T]:
    data: T

class Fail:
    pass

# Generic type alias (PEP 695, Python 3.12+).
type Result[T] = Success[T] | Fail

and you'd use it like so

def compute(inputs: str) -> Result[int]:
    if inputs.startswith('!'):  # Oops, some condition that prevents the computation.
        return Fail()
    return Success(len(inputs))

Now, it's not clear to me what role check would play here, as it depends on your needs. The simplest way you could do it is

def check(inputs: str) -> bool:
    match compute(inputs):
        case Success(x):
            return x > 2
        case Fail():
            return False

but you might find functions like

from collections.abc import Callable

def is_success[T](r: Result[T]) -> bool:
    return isinstance(r, Success)

def map[T, U](result: Result[T], f: Callable[[T], U]) -> Result[U]:
    match result:
        case Success(x):
            return Success(f(x))
        case Fail():  # Class pattern; a bare `Fail` would be a capture pattern.
            return Fail()

useful more generally, in which case you might do

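def check(inputs: str) -> bool:
    # Map the predicate over the result, then require a Success holding True.
    # (is_success alone would also accept Success(False).)
    return map(compute(inputs), lambda data: data > 2) == Success(True)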

That's not very pretty, partly because I don't really know what check is for.

When you have multiple values, you could combine multiple results with a combinator like

def map2[T, U, V](r0: Result[T], r1: Result[U], f: Callable[[T, U], V]) -> Result[V]:
    match (r0, r1):
        case (Success(x0), Success(x1)):
            return Success(f(x0, x1))
        case _:
            return Fail()

@dataclass
class TwoThings:
    data0: int
    data1: int

hopefully_two_things: Result[TwoThings] = map2(compute("foo"), compute("bar"), TwoThings)