Django MPTT efficiently serializing relational data with DRF


I have a Category model that is an MPTT model. It has a many-to-many relationship with Group, and I need to serialize the tree with related counts. Imagine my Category tree is this:

Root (related to 1 group)
 - Branch (related to 2 groups) 
    - Leaf (related to 3 groups)
...

So the serialized output would look like this:

{
    "id": 1,
    "name": "root1",
    "full_name": "root1",
    "group_count": 6,
    "children": [
        {
            "id": 2,
            "name": "branch1",
            "full_name": "root1 - branch1",
            "group_count": 5,
            "children": [
                {
                    "id": 3,
                    "name": "leaf1",
                    "full_name": "root1 - branch1 - leaf1",
                    "group_count": 3,
                    "children": []
                }
            ]
        }
    ]
}

This is my current super inefficient implementation:

Model

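from mptt.models import MPTTModel


class Category(MPTTModel):
    name = ...
    parent = ...  # MPTT parent foreign key with related_name='children'

    def get_full_name(self):
        # One query per call: the names of all ancestors plus self, root first
        names = self.get_ancestors(include_self=True).values_list('name', flat=True)
        return ' - '.join(names)

    def get_group_count(self):
        # One query per call; distinct() so a group linked to several
        # categories in this subtree is only counted once
        cats = self.get_descendants(include_self=True)
        return Group.objects.filter(categories__in=cats).distinct().count()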

View

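from mptt.templatetags.mptt_tags import cache_tree_children
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet


class CategoryViewSet(ModelViewSet):
    def list(self, request):
        # Only root nodes are passed in, so their children are not cached
        # and are still fetched node by node through the 'children' relation
        tree = cache_tree_children(Category.objects.filter(level=0))
        serializer = CategorySerializer(tree, many=True)
        return Response(serializer.data)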

Serializer

from rest_framework import serializers


# Django REST framework 2.x API (to_native, serializers.Field)
class RecursiveField(serializers.Serializer):
    def to_native(self, value):
        # Serialize each child with the enclosing CategorySerializer
        return self.parent.to_native(value)


class CategorySerializer(serializers.ModelSerializer):
    children = RecursiveField(many=True, required=False)
    full_name = serializers.Field(source='get_full_name')
    group_count = serializers.Field(source='get_group_count')

    class Meta:
        model = Category
        fields = ('id', 'name', 'children', 'full_name', 'group_count')

This works, but it hits the database with an enormous number of queries, and there are additional relations besides Group. Is there a way to make this efficient? How can I write my own serializer?

Kevin Brown-Silva answered:

You are definitely running into an N+1 query issue, which I have covered in detail in another Stack Overflow answer. I would recommend reading up on optimizing queries in Django, as this is a very common issue.

Django MPTT also has a few N+1 pitfalls of its own that you will need to work around. Both self.get_ancestors and self.get_descendants create a new queryset, and in your case that happens for every object you serialize. You may want to find a way to avoid these calls; I've described possible improvements below.
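For background: django-mptt stores each node's position as a nested-set interval (the lft and rght columns within a tree_id), and get_ancestors()/get_descendants() are essentially range filters over those columns, so every call is its own query. Once all rows are loaded in a single query, the same range tests can run in memory. A minimal pure-Python sketch of that idea, using a hypothetical Node class in place of model instances:

```python
from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical stand-in for a Category row; tree_id, lft and rght
    # mirror the columns django-mptt adds by default.
    id: int
    name: str
    tree_id: int
    lft: int
    rght: int


def ancestors(node, nodes, include_self=False):
    """Same-tree nodes whose [lft, rght] interval encloses node's, root first."""
    found = [
        n for n in nodes
        if n.tree_id == node.tree_id and n.lft <= node.lft and n.rght >= node.rght
    ]
    if not include_self:
        found = [n for n in found if n is not node]
    return sorted(found, key=lambda n: n.lft)


def descendants(node, nodes, include_self=False):
    """Same-tree nodes whose interval lies inside node's interval."""
    found = [
        n for n in nodes
        if n.tree_id == node.tree_id and n.lft >= node.lft and n.rght <= node.rght
    ]
    if not include_self:
        found = [n for n in found if n is not node]
    return sorted(found, key=lambda n: n.lft)


# The tree from the question, loaded once as a flat list.
rows = [
    Node(1, "root1", tree_id=1, lft=1, rght=6),
    Node(2, "branch1", tree_id=1, lft=2, rght=5),
    Node(3, "leaf1", tree_id=1, lft=3, rght=4),
]
root, branch, leaf = rows

full_name = " - ".join(n.name for n in ancestors(leaf, rows, include_self=True))
print(full_name)  # root1 - branch1 - leaf1
print([n.name for n in descendants(root, rows)])  # ['branch1', 'leaf1']
```

Each lookup here scans the whole list, so for large trees a single top-down walk is cheaper; the point is only that no per-node query is required.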

In your get_full_name method, you are calling self.get_ancestors in order to build the name chain. Since you always have the parent available while generating the output, you may benefit from moving this to a SerializerMethodField that reuses the parent object to build the name. Something like the following may work:

from rest_framework import serializers


class RecursiveField(serializers.Serializer):

    def to_native(self, value):
        # value is one child node. Pass along its parent and the serializer
        # that rendered that parent, so the parent's full name can be reused.
        serializer = CategorySerializer(value, context={
            "parent": value.parent,
            "parent_serializer": self.parent,
        })
        return serializer.data


class CategorySerializer(serializers.ModelSerializer):
    children = RecursiveField(many=True, required=False)
    full_name = serializers.SerializerMethodField("get_full_name")
    group_count = serializers.Field(source='get_group_count')

    class Meta:
        model = Category
        fields = ('id', 'name', 'children', 'full_name', 'group_count')

    def get_full_name(self, obj):
        name = obj.name

        if "parent" in self.context:
            parent = self.context["parent"]

            # Walks up the chain of parent serializers, not the database
            parent_name = self.context["parent_serializer"].get_full_name(parent)

            name = "%s - %s" % (parent_name, name)

        return name

You may need to edit this code slightly, but the general idea is that you don't always need to get the ancestors because you will have the ancestor chain already.
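The same idea, pushing the parent's full name down instead of looking the ancestors up, can be shown without DRF. A minimal sketch over plain dicts; the dict layout is an assumption that mirrors the serialized output in the question:

```python
def attach_full_names(node, parent_full_name=None, sep=" - "):
    """Set node["full_name"] on a node and all of its descendants.

    Each node reuses its parent's already computed full name, so no
    ancestor lookup is needed at any depth.
    """
    if parent_full_name is None:
        node["full_name"] = node["name"]
    else:
        node["full_name"] = parent_full_name + sep + node["name"]
    for child in node["children"]:
        attach_full_names(child, node["full_name"], sep)
    return node


tree = {"name": "root1", "children": [
    {"name": "branch1", "children": [
        {"name": "leaf1", "children": []},
    ]},
]}
attach_full_names(tree)
leaf = tree["children"][0]["children"][0]
print(leaf["full_name"])  # root1 - branch1 - leaf1
```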

This doesn't solve the Group queries, which you may not be able to optimize, but it should at least reduce your queries. Recursive queries are incredibly difficult to optimize, and they usually take a lot of planning to figure out how you can best get the required data without falling back to N+1 situations.
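One possible way around the per-node Group queries, sketched under assumptions: fetch every (category_id, group_id) pair in a single query, then union the group ids up the tree in Python. With the default auto-created through model, something like Group.categories.through.objects.values_list('category_id', 'group_id') should return those pairs, but treat that as an assumption about your models. The sketch below uses plain ids so it runs without Django:

```python
from collections import defaultdict


def subtree_group_counts(children, pairs, root_ids):
    """Count distinct groups attached to each node or any of its descendants.

    children: parent id -> list of child ids (the already cached tree)
    pairs:    iterable of (category_id, group_id), e.g. the m2m through rows
    root_ids: ids of the top-level nodes
    """
    direct = defaultdict(set)
    for category_id, group_id in pairs:
        direct[category_id].add(group_id)

    counts = {}

    def collect(node_id):
        # Union this node's groups with each child's subtree groups; a set
        # means a group linked to several categories is counted once.
        groups = set(direct[node_id])
        for child_id in children.get(node_id, []):
            groups |= collect(child_id)
        counts[node_id] = len(groups)
        return groups

    for root_id in root_ids:
        collect(root_id)
    return counts


# Tree and groups from the question: root has 1 group, branch 2, leaf 3.
children = {1: [2], 2: [3]}
pairs = [(1, "g1"), (2, "g2"), (2, "g3"), (3, "g4"), (3, "g5"), (3, "g6")]
print(subtree_group_counts(children, pairs, root_ids=[1]))
# {3: 3, 2: 5, 1: 6}
```

Holding a set per subtree costs memory proportional to the number of groups, which is usually acceptable next to one query per node.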

Petr Dlouhý answered:

I have found a solution for the counts. Thanks to django-mptt's get_cached_trees function, you can do the following:

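from django.db.models import Count
from rest_framework import serializers
from rest_framework.viewsets import ModelViewSet


class CategorySerializer(serializers.ModelSerializer):
    group_count = serializers.SerializerMethodField()

    def get_group_count(self, obj):
        # Set by annotate() in the view, so no extra query per node
        return obj.group_count

    class Meta:
        model = Category
        fields = [
            'name',
            'slug',
            'children',
            'group_count',
        ]

# The class cannot refer to itself in its own body, so the recursive field
# is attached afterwards. get_children() returns the children cached by
# get_cached_trees() instead of querying the database again.
CategorySerializer._declared_fields['children'] = CategorySerializer(
    many=True,
    source='get_children',
)

class CategoryViewSet(ModelViewSet):
    serializer_class = CategorySerializer

    def get_queryset(self):
        queryset = Category.tree.annotate(group_count=Count('group'))
        # Evaluates the queryset once and returns the root nodes
        queryset = queryset.get_cached_trees()
        return queryset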

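Here tree is django-mptt's TreeManager, as used in django-categories, for which I have written slightly more complicated code in this PR: https://github.com/callowayproject/django-categories/pull/145/files. Note that Count('group') counts only the groups attached directly to each category, not those of its descendants, so on its own it does not produce the cumulative counts shown in the question.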
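get_cached_trees() can build the whole tree from one query because the TreeManager returns nodes ordered by tree_id and then lft. A rough pure-Python imitation of that single pass follows (a hypothetical Row class stands in for annotated Category instances; this is not django-mptt's actual code), plus a bottom-up sum that turns per-category direct counts into subtree totals:

```python
from dataclasses import dataclass, field


@dataclass
class Row:
    # Hypothetical stand-in for an annotated Category instance;
    # group_count plays the role of the Count('group') annotation.
    name: str
    level: int
    group_count: int
    children: list = field(default_factory=list)
    total: int = 0


def build_trees(rows):
    """Link rows given in tree order (tree_id, then lft); return the roots.

    One pass that keeps the chain of open ancestors, in the spirit of
    get_cached_trees(); assumes the top level is 0.
    """
    roots, path = [], []
    for row in rows:
        # Nodes at this level or deeper cannot be this row's parent.
        del path[row.level:]
        if row.level == 0:
            roots.append(row)
        else:
            path[-1].children.append(row)
        path.append(row)
    return roots


def add_subtree_totals(node):
    """Post-order: total = own direct count + totals of all children."""
    node.total = node.group_count + sum(add_subtree_totals(c) for c in node.children)
    return node.total


rows = [Row("root1", 0, 1), Row("branch1", 1, 2), Row("leaf1", 2, 3)]
(root,) = build_trees(rows)
add_subtree_totals(root)
print(root.total, root.children[0].total, root.children[0].children[0].total)
# 6 5 3
```

Summing direct counts assumes no group is linked to two categories within the same subtree; if that can happen, the shared group is counted more than once, and a distinct union of group ids is needed instead.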