Typing a class decorator

111 views Asked by At

I'm writing some MicroPython code (MicroPython doesn't have dataclasses) and want generic classes that can be serialized and deserialized. I wrote a decorator function that adds these two methods to a class:

import pickle
from typing import Any
from abc import ABCMeta, abstractmethod

class Serializable(metaclass=ABCMeta):
    """Abstract interface for objects that round-trip through a string.

    ``ABCMeta`` is applied as the *metaclass* instead of being inherited
    from. Inheriting from ``ABCMeta`` would make ``Serializable`` itself a
    metaclass: its subclasses would be metaclasses too, could not be
    instantiated as plain objects, and the abstract methods below would
    never be enforced.
    """

    @abstractmethod
    def serialize(self) -> str:
        """Return the string representation of this object."""

    @staticmethod
    @abstractmethod
    def deserialize(data: str) -> 'Serializable':
        """Build an instance from a string produced by ``serialize``."""


def serializable(cls):
    """Class decorator that installs ``__init__``, ``serialize`` and ``deserialize``.

    The decorated class must define ``fields``: an ordered sequence of
    ``(name, converter)`` pairs. ``converter`` is called with the raw string
    token when deserializing (e.g. ``int`` or ``str``).

    The wire format is ``"<ClassName> <value> <value> ..."`` with single
    spaces as separators, so a field whose ``str()`` contains a space will
    not round-trip.

    Args:
        cls: the class to decorate; it is modified in place.

    Returns:
        The same class object, with the three members attached.

    Raises:
        ValueError: if ``cls`` has no ``fields`` attribute.
    """
    if not hasattr(cls, 'fields'):
        raise ValueError('fields is required')

    def __init__(self, **kwargs: Any) -> None:
        # Every declared field must be supplied by keyword; keyword
        # arguments that are not declared fields are ignored.
        for fname, _ in cls.fields:
            if fname not in kwargs:
                raise ValueError(f'{fname} is required')
            setattr(self, fname, kwargs[fname])
    cls.__init__ = __init__

    # The annotations are quoted so the decorator does not need the
    # Serializable name to be resolvable when it runs.
    def serialize(self: 'Serializable') -> str:
        # Join name and values in one pass so a class with no fields
        # serializes to just its name (no trailing separator) and can be
        # parsed back by deserialize.
        tokens = [cls.__name__]
        tokens.extend(str(getattr(self, fname)) for fname, _ in cls.fields)
        return ' '.join(tokens)
    cls.serialize = serialize

    def deserialize(data: str) -> 'Serializable':
        parts = data.split(' ')
        if parts[0] != cls.__name__:
            raise ValueError(f'invalid data not a {cls.__name__}')
        raw_values = parts[1:]
        # Raised for too many tokens as well as too few.
        if len(raw_values) != len(cls.fields):
            raise ValueError('invalid data not enough fields')
        return cls(**{
            fname: ftype(raw)
            for (fname, ftype), raw in zip(cls.fields, raw_values)
        })
    # Wrap in staticmethod so the function is not bound when looked up on an
    # instance; otherwise ``obj.deserialize(text)`` would pass ``obj`` as
    # ``data``.
    cls.deserialize = staticmethod(deserialize)

    return cls

Since the project uses strict type checking with mypy, mypy currently complains that some functions (like the decorator) are untyped. It also reports errors when the class is used like so:

# Example usage: the decorator reads `fields` and installs __init__,
# serialize and deserialize on the class.
@serializable
class Test:
    # (attribute name, type) pairs; when deserializing, the type is called on
    # the raw string token to convert it.
    fields = [
        ('a', int),
        ('b', str)
    ]

# The generated __init__ requires every declared field as a keyword argument.
x = Test(a=1, b='hello')
print(x.a)  # prints 1

How do I get mypy to typecheck this correctly?

1

There is 1 answer

2
Mark On

As noted in the comments, dataclass_transform is what you're looking for, but you'll need to use __annotations__ instead of the custom fields in your example. The good news is, the syntax is actually easier:

from typing import Any, dataclass_transform
from abc import ABCMeta, abstractmethod

class Serializable(metaclass=ABCMeta):
    """Abstract interface for objects that round-trip through a string.

    ``ABCMeta`` is applied as the *metaclass* instead of being inherited
    from. Inheriting from ``ABCMeta`` would make ``Serializable`` itself a
    metaclass: its subclasses would be metaclasses too, could not be
    instantiated as plain objects, and the abstract methods below would
    never be enforced.
    """

    @abstractmethod
    def serialize(self) -> str:
        """Return the string representation of this object."""

    @staticmethod
    @abstractmethod
    def deserialize(data: str) -> 'Serializable':
        """Build an instance from a string produced by ``serialize``."""

@dataclass_transform()
def serializable(cls):
    def __init__(self, **kwargs: Any) -> None:
        for fname, _ in cls.__annotations__.items():
            if fname not in kwargs:
                raise ValueError(f'{fname} is required')
            setattr(self, fname, kwargs[fname])
    cls.__init__ = __init__

    def serialize(self: Serializable) -> str:
        return f"{cls.__name__} {' '.join(str(getattr(self, fname)) for fname, _ in cls.fields)}"
    cls.serialize = serialize

    def deserialize(data: str) -> Serializable:
        parts = data.split(' ')
        if parts[0] != cls.__name__:
            raise ValueError(f'invalid data not a {cls.__name__}')
        if len(parts) - 1 != len(cls.fields):
            raise ValueError('invalid data not enough fields')
        return cls(**{fname: ftype(parts[i+1]) for i, (fname, ftype) in enumerate(cls.fields)})
    cls.deserialize = deserialize

    return cls

# Example usage: fields are declared as class-level annotations, which the
# decorator reads from `__annotations__`.
@serializable
class Test:
    a: int
    b: str

# The generated __init__ requires each annotated name as a keyword argument.
x = Test(a=1, b='hello')
print(x.a)  # prints 1