DRF serializer validation of inputs across multiple nested fields

46 views Asked by At

I'm using DRF serializers for models with measurements with units. Here's my model:

class Cargo(models.Model):
    """Cargo record holding float measurements plus the unit each kind was entered in."""

    # Measurements. NOTE(review): the serializer comments in this post refer to a
    # "database unit", which suggests these are stored in one canonical unit per
    # unit type -- confirm against the UNITS table.
    length = models.FloatField()
    height = models.FloatField()
    weight = models.FloatField()
    # Unit the caller used for input, constrained to the LengthUnits / WeightUnits
    # choices. There is a single length unit column, shared by every
    # length-type measurement on the row.
    input_length_unit = models.IntegerField(choices=LengthUnits.choices)
    input_weight_unit = models.IntegerField(choices=WeightUnits.choices)

I have the following serializers to convert data like {"length": {"value": 10, "unit": 1}, ...} to my model schema:

class FloatWithUnitSerializer(serializers.Serializer):
    """Nested field mapping ``{"value": <float>, "unit": <int>}`` onto model attributes.

    Meant to be declared with ``source="*"`` so it reads from, and writes into,
    the parent object directly: the converted value is stored under this
    field's own name and the unit under ``input_<unit_type>_unit``.

    Args:
        unit_type: Key into ``UNITS`` and infix of the unit attribute name
            (e.g. ``"length"`` -> ``input_length_unit``).
    """

    value = serializers.FloatField()
    unit = serializers.IntegerField()

    def __init__(self, unit_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.unit_type = unit_type

    def to_representation(self, instance):
        """Return ``{"value", "unit"}`` with the stored value expressed in the input unit."""
        value = getattr(instance, self.field_name)
        unit = getattr(instance, f"input_{self.unit_type}_unit")
        # convert from database unit to input unit
        value = value * UNITS[self.unit_type][unit]

        return {"value": value, "unit": unit}

    def to_internal_value(self, data):
        """Validate the payload and convert its value from input unit to database unit.

        Returns:
            A dict with the converted value under ``self.field_name`` and the
            unit under ``input_<unit_type>_unit``.

        Raises:
            serializers.ValidationError: If ``value``/``unit`` fail the declared
                field validation, or the unit is not known for this unit type.
        """
        # Let the declared FloatField/IntegerField validate and coerce the raw
        # payload first, so malformed input surfaces as a validation error
        # instead of a KeyError/TypeError from indexing the raw data.
        validated = super().to_internal_value(data)
        unit = validated["unit"]

        try:
            factor = UNITS[self.unit_type][unit]
        except (KeyError, IndexError):
            # KeyError for mapping-style tables, IndexError for sequence-style.
            raise serializers.ValidationError(
                {"unit": [f"Unknown {self.unit_type} unit: {unit}."]}
            ) from None

        # convert from input unit to database unit
        value = validated["value"] / factor

        return {self.field_name: value, f"input_{self.unit_type}_unit": unit}

class CargoSerializer(serializers.ModelSerializer):
    """Serializer for ``Cargo`` exposing each measurement as a nested value/unit object."""

    # The first positional argument is the unit type. ``height`` uses the
    # "length" unit type as well, so length and height both read and write
    # ``input_length_unit``.
    # ``source="*"`` hands the whole object to the nested serializer, which
    # then picks out its own attributes by name.
    length = FloatWithUnitSerializer("length", required=False, source="*")
    height = FloatWithUnitSerializer("length", required=False, source="*")
    weight = FloatWithUnitSerializer("weight", required=False, source="*")

    class Meta:
        model = models.Cargo
        fields = ["length", "height", "weight", "input_length_unit", "input_weight_unit"]

This works, but now I want to prevent unit-mixing. For example, the unit type "length" covers both the "length" and "height" fields (units are meters, feet, etc.), so clients need to post the same unit for both, e.g. {"length": {"value": 10, "unit": 1}, "height": {"value": 15, "unit": 1}}. If they pass two fields of the same unit type with different units, e.g. {"length": {"value": 10, "unit": 1}, "height": {"value": 15, "unit": 2}}, I want to raise a ValidationError. I will be adding other unit types like area and volume too. I can't validate this inside the FloatWithUnitSerializer since I only have data for a single field there, and I'm not sure how I would validate this in the CargoSerializer - by the time my validate function is called, FloatWithUnitSerializer.to_internal_value has already run, so I just have the length, height, input_length_unit, etc. fields, and I can't tell whether multiple length units were passed. How can I validate this, or is there a simpler way I could structure this? Thanks.

1

There is 1 answer

0
Luciano On

I ended up overriding to_internal_value like so:

class CargoSerializer(serializers.ModelSerializer):
    """``Cargo`` serializer that rejects payloads mixing units within one unit type."""

    # ``height`` shares the "length" unit type with ``length``, so both write
    # ``input_length_unit``; that shared key is what the mixed-unit check uses.
    length = FloatWithUnitSerializer("length", required=False, source="*")
    height = FloatWithUnitSerializer("length", required=False, source="*")
    weight = FloatWithUnitSerializer("weight", required=False, source="*")

    class Meta:
        model = models.Cargo
        fields = ["length", "height", "weight", "input_length_unit", "input_weight_unit"]

    def to_internal_value(self, data):
        """Validate ``data`` field by field and reject mixed units.

        Returns:
            An ``OrderedDict`` of validated values.

        Raises:
            ValidationError: If ``data`` is not a mapping, any field fails
                validation, or two fields of the same unit type were given
                different units.
        """
        # Per the original author, this is a copy of DRF's default
        # Serializer.to_internal_value; the only addition is called out below.
        if not isinstance(data, Mapping):
            message = self.error_messages["invalid"].format(datatype=type(data).__name__)
            raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="invalid")

        ret = OrderedDict()
        errors = OrderedDict()
        fields = self._writable_fields

        for field in fields:
            validate_method = getattr(self, "validate_" + field.field_name, None)
            primitive_value = field.get_value(data)
            try:
                validated_value = field.run_validation(primitive_value)
                if validate_method is not None:
                    validated_value = validate_method(validated_value)
            except ValidationError as exc:
                errors[field.field_name] = exc.detail
            except DjangoValidationError as exc:
                errors[field.field_name] = get_error_detail(exc)
            except SkipField:
                pass
            else:
                # The only addition: don't allow multiple units per unit type.
                # An earlier field of the same unit type has already put its
                # ``input_<type>_unit`` into ``ret`` (set_value below is
                # expected to merge the dict for source="*" fields), so a
                # differing value here means the client mixed units.
                if isinstance(field, FloatWithUnitSerializer):
                    for key, unit in validated_value.items():
                        if not key.startswith("input_"):
                            continue
                        current_unit = ret.get(key)
                        if current_unit is not None and current_unit != unit:
                            # A list, matching the non-field error raised above.
                            errors[api_settings.NON_FIELD_ERRORS_KEY] = [
                                "Received mixed units."
                            ]
                set_value(ret, field.source_attrs, validated_value)

        if errors:
            raise ValidationError(errors)

        return ret