How to get all descendants of a node with Django treebeard?


Let's say I have these models:

from django.db import models
from treebeard.mp_tree import MP_Node

class Category(MP_Node):
    name = models.CharField(max_length=30)

class Item(models.Model):
    category = models.ForeignKey(Category, on_delete=models.CASCADE)

and I would like to find all Items belonging to any descendant of a given Category.

Usually I would write category.item_set, but that only returns the Items attached to that exact category, not the Items attached to its descendants.

Using the example tree in the treebeard tutorial: if an Item belongs to "Laptop Memory", how would I find all Items belonging to descendants of "Computer Hardware", where "Laptop Memory" is one of those descendants?


There are 3 answers

Super Kai - Kazuya Ito

For example, you can get all descendants of a category with get_descendants() in views.py and then filter the related model with category__in, as shown below. Note that get_descendants() does not include the category itself; to include it, use the classmethod Category.get_tree(parent=category) instead, which returns the category together with all of its descendants.

# "views.py"

from django.http import HttpResponse
from .models import Category, Product

def test(request):
    categories = Category.objects.get(name="Food").get_descendants()
    print(categories) 
    # <MP_NodeQuerySet [<Category: Meat>, <Category: Fish>]>

    products = Product.objects.filter(category__in=categories)
    print(products)
    # <QuerySet [<Product: Beef>, <Product: Pork>, <Product: Salmon>]>

    return HttpResponse("Test")
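To see why a single filter is enough, here is a pure-Python model of how treebeard's MP_Node stores a tree. It is a sketch with hypothetical category names and paths and needs no Django. Each node keeps a materialized path made of fixed-width steps (4 characters by default), so a node's subtree is every node whose path starts with its own path. get_descendants() leaves the node itself out, while Category.get_tree(parent=node) keeps it; the include_self flag below only imitates that difference and is not a treebeard parameter.

```python
# Hypothetical categories with MP_Node-style paths (treebeard's default
# steplen is 4, so every level appends one 4-character step).
PATHS = {
    "Food": "0001",
    "Meat": "00010001",
    "Poultry": "000100010001",
    "Fish": "00010002",
    "Drinks": "0002",
}

# Hypothetical items and the category each one is attached to.
ITEMS = [
    ("Beef", "Meat"),
    ("Pork", "Meat"),
    ("Chicken", "Poultry"),
    ("Salmon", "Fish"),
    ("Cola", "Drinks"),
]


def subtree(name, include_self=True):
    """Categories whose path starts with `name`'s path.

    include_self=True imitates Category.get_tree(parent=node);
    include_self=False imitates node.get_descendants().
    """
    root = PATHS[name]
    return [
        other
        for other, path in PATHS.items()
        if path.startswith(root) and (include_self or other != name)
    ]


def items_under(name):
    """Items attached to `name` or to any of its descendants."""
    wanted = set(subtree(name))
    return [item for item, category in ITEMS if category in wanted]


print(subtree("Food", include_self=False))  # ['Meat', 'Poultry', 'Fish']
print(items_under("Meat"))                  # ['Beef', 'Pork', 'Chicken']
```

Because every step has the same width, a prefix match cannot accidentally pick up a sibling such as "Drinks", which is why the database can answer this with one prefix (LIKE 'path%') comparison instead of walking the tree.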
Simone

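I looked at the treebeard code to see how it gets descendants of a node, and we can apply the same filters as a related-field lookup. Note that tree_id, lft and rgt are fields of treebeard's nested-set model (NS_Node); with MP_Node, as in the question, the equivalent single-query filter is on the path instead: category__path__startswith=paramcat.path.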

paramcat = Category.objects.get(id=1)  # how you actually get the category will depend on your application
# all items associated with this category OR its descendants (NS_Node fields):
items = Item.objects.filter(
    category__tree_id=paramcat.tree_id,
    category__lft__range=(paramcat.lft, paramcat.rgt - 1),
)
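Here is a pure-Python sketch of the nested-set test in that filter, using a hypothetical numbering with no database. It assumes the layout treebeard's NS_Node uses: each root gets its own tree_id and its lft/rgt numbering restarts at 1, and every descendant's lft lies between its ancestor's lft and rgt.

```python
# Hypothetical nested-set numbering for two separate trees.
NODES = {
    # name: (tree_id, lft, rgt)
    "Food": (1, 1, 8),
    "Meat": (1, 2, 5),
    "Poultry": (1, 3, 4),
    "Fish": (1, 6, 7),
    "Drinks": (2, 1, 2),
}


def subtree(name):
    """Node plus descendants, using the same test as the Django filter:
    same tree_id and lft inside the inclusive range (lft, rgt - 1)."""
    tree_id, lft, rgt = NODES[name]
    return [
        other
        for other, (t, l, _r) in NODES.items()
        # Django's __range lookup is inclusive, hence lft <= l <= rgt - 1
        if t == tree_id and lft <= l <= rgt - 1
    ]


print(subtree("Meat"))    # ['Meat', 'Poultry']
print(subtree("Drinks"))  # ['Drinks']
```

The tree_id condition matters: "Food" also has lft 1, so without it subtree("Drinks") would wrongly include the whole "Food" tree.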

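I think materializing intermediate results, for example with list(category.get_descendants()), costs an extra query and loads all descendants into memory, which defeats the purpose of using treebeard in the first place. (Passing the get_descendants() queryset itself to category__in, without list(), still runs as a single query, because Django compiles it into a subquery.)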

I'd be interested in seeing a custom lookup based on the filter above, but I'm not sure how to write one.
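Short of a real custom Lookup class, one lighter option is a plain helper that builds the filter keyword arguments for either treebeard storage type. This is a sketch: subtree_filter is a hypothetical name, not part of treebeard, and it assumes MP_Node instances expose path while NS_Node instances expose tree_id, lft and rgt. SimpleNamespace objects stand in for real nodes so it can be tried without a database.

```python
from types import SimpleNamespace


def subtree_filter(node, field="category"):
    """Build filter kwargs matching `node` and all of its descendants
    through the foreign key `field` (hypothetical helper, not treebeard API)."""
    if hasattr(node, "path"):
        # MP_Node: descendants' paths start with the node's own path.
        return {f"{field}__path__startswith": node.path}
    if hasattr(node, "lft") and hasattr(node, "rgt"):
        # NS_Node: descendants share tree_id and have lft in [lft, rgt - 1].
        return {
            f"{field}__tree_id": node.tree_id,
            f"{field}__lft__range": (node.lft, node.rgt - 1),
        }
    raise TypeError("expected an MP_Node (path) or NS_Node (tree_id/lft/rgt)")


# Stand-ins for real nodes.
mp_node = SimpleNamespace(path="0001")
ns_node = SimpleNamespace(tree_id=1, lft=2, rgt=5)
print(subtree_filter(mp_node))  # {'category__path__startswith': '0001'}
print(subtree_filter(ns_node))  # {'category__tree_id': 1, 'category__lft__range': (2, 4)}
```

In a view this would be used as Item.objects.filter(**subtree_filter(paramcat)), which keeps the whole lookup in one query. An adjacency-list AL_Node has neither field, so the helper raises TypeError for it rather than guessing.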

Massa

I just had the same problem and figured out how to do it (this code goes inside the get_queryset method of a ListView):

category = Category.objects.get(slug=self.kwargs['category'])  # assumes Category has a slug field
descendants = list(category.get_descendants())
return self.model.objects.select_related('category').filter(category__in=descendants + [category])

Another option I came up with is an OR filter built from Q objects, which passes the descendants queryset as a subquery instead of loading it into a list:

from django.db.models import Q

category = Category.objects.get(slug=self.kwargs['category'])
return self.model.objects.select_related('category').filter(
    Q(category__in=category.get_descendants()) | Q(category=category)
)