`Pydantic` `self.model_dump()` doesn't pick up subclass instances passed to a field

I am trying to use a custom `__repr__` method to produce a readable output of the object. I want to annotate a field with the superclass so that all subclass instances passed to the field are validated correctly. However, `self.model_dump()` in the `__repr__` method doesn't include the subclass's fields unless I explicitly annotate the field with that exact subclass.

WORKING CODE (explicitly typed field)

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: subQueryParams = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)

YourModel(command_params=subQueryParams())

Returns:

YourModel

command_params: {'test': 'test'}

BUG CODE (changed the field type to the superclass)

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: QueryParams = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)


YourModel(command_params=subQueryParams())

Returns:

YourModel

command_params: {}

How can I annotate the field with the superclass while still getting the output from the first code example?

Hacky Solution:

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]

        # Needed to extract subclass items: dump the field value through its
        # own (runtime) class instead of the annotated base class
        if self.command_params is not None:
            add_item = self.command_params.model_dump()
            for i, item in enumerate(items):
                if item.startswith("command_params:"):
                    items[i] = f"command_params: {add_item}"
                    break  # Only one 'command_params:' entry, so stop after updating

        return f"{self.__class__.__name__}\n\n" + "\n".join(items)
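The hack works because the field still holds the `subQueryParams` instance; only the parent's `model_dump()` serializes it with the annotated base-class schema. A minimal sketch that shows this, assuming Pydantic v2 with default settings (instances are not revalidated, so the subclass object is stored as-is); the `__repr__` is left out to isolate the behaviour:

```python
from pydantic import BaseModel


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    command_params: QueryParams  # annotated with the base class


m = YourModel(command_params=subQueryParams())

# The stored value is still the subclass instance...
print(type(m.command_params).__name__)  # subQueryParams

# ...but the parent's model_dump() uses the annotated type's fields only
print(m.model_dump())  # {'command_params': {}}

# Dumping the field value directly uses its runtime class (what the hack relies on)
print(m.command_params.model_dump())  # {'test': 'test'}
```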

Answer by Axel Donath

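This does not work in the simple way you suggested. When the field is annotated with the base class, Pydantic by default keeps the `subQueryParams` instance you pass in, but `model_dump()` serializes it using the schema of the annotated type (`QueryParams`), so fields that exist only on the subclass are left out. However, you can work with a Union: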

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict
from typing import Union


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: Union[QueryParams, subQueryParams] = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)


YourModel(command_params=subQueryParams())

Which prints:

YourModel

command_params: {'test': 'test'}

Or, even better, you could work with a discriminated union. See https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-str-discriminators
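A minimal sketch of the discriminated-union approach, assuming Pydantic v2. The `kind` discriminator field, its tag values, and the second subclass `otherQueryParams` are hypothetical additions for illustration; every class in the union needs a `Literal` field with the same name so Pydantic can pick the class from the tag:

```python
from typing import Literal, Union

from pydantic import BaseModel, Field


class QueryParams(BaseModel):
    pass


# Hypothetical discriminator field "kind" with one tag value per subclass
class subQueryParams(QueryParams):
    kind: Literal["sub"] = "sub"
    test: str = "test"


# Hypothetical second subclass, only here to make the union meaningful
class otherQueryParams(QueryParams):
    kind: Literal["other"] = "other"
    limit: int = 10


class YourModel(BaseModel):
    # The value of "kind" decides which class validates and serializes the field
    command_params: Union[subQueryParams, otherQueryParams] = Field(discriminator="kind")


m = YourModel(command_params=subQueryParams())
print(m.model_dump())  # {'command_params': {'kind': 'sub', 'test': 'test'}}

# The tag also selects the right class when the input is plain data
m2 = YourModel(command_params={"kind": "other", "limit": 5})
print(type(m2.command_params).__name__)  # otherQueryParams
```

The trade-off is that every allowed subclass must be listed in the union, but the tag lets Pydantic build the right class even from a plain dict, which a bare base-class annotation cannot do.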

I hope this helps!