Searching a SortedList of class instances by name

I have a SortedList of instances of a class, such as...

```python
from typing import List
from sortedcontainers import SortedList

class Supplier:
    def __init__(self, name: str, id: int, sap_id: int):
        self.Name = name
        self.Id = id
        self.SapId = sap_id

class Data:
    def __init__(self):
        self.suppliers = SortedList(key=lambda x: x.Name.lower())
```

I want to be able to search the SortedList by supplier name, such as...

```python
    # Part of the Data class
    def find_supplier(self, name:str):
        index = self.suppliers.bisect_left(name)
        if index != len(self.suppliers) and self.suppliers[index].Name.lower() == name.lower():
            return self.suppliers[index]

        return None
```

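However, this does not work: because the list was created with a key function, bisect_left applies that key to its argument, so the lambda evaluates x.Name on a str and raises an AttributeError. In other words, bisect_left expects a Supplier, not a str. I can work around this by creating a temporary Supplier with the search name and passing that to bisect_left instead, such as...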

```python
    # Part of the Data class
    def find_supplier(self, name:str):
        temporary_supplier = Supplier(name, 0, 0)
        index = self.suppliers.bisect_left(temporary_supplier)
        if index != len(self.suppliers) and self.suppliers[index].Name.lower() == name.lower():
            return self.suppliers[index]

        return None
```

However, it feels like an ugly way of doing it. Is there another option that does not require writing my own binary search function?

Answer by blhsing (accepted answer):

Instead of using a custom key function for SortedList or creating a temporary Supplier when calling bisect_left, a cleaner approach may be to make Supplier itself comparable by defining the rich comparison method __lt__ (sorting and bisect_left only need the < operator) and letting it accept a string as the other operand:

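```python
class Supplier:
    def __init__(self, name: str):
        self.Name = name

    def __repr__(self):
        return self.Name

    def __lt__(self, other):
        # Accept either another Supplier or a plain str as the other operand, and
        # compare case-insensitively to match the .lower() check in find_supplier.
        other_name = other if isinstance(other, str) else other.Name
        return self.Name.lower() < other_name.lower()
```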

so that:

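```python
from sortedcontainers import SortedList


class Data:
    def __init__(self):
        # No key function: ordering comes from Supplier.__lt__.
        self.suppliers = SortedList()

    def find_supplier(self, name: str):
        # bisect_left only evaluates comparisons of the form supplier < name,
        # which Supplier.__lt__ handles.
        index = self.suppliers.bisect_left(name)
        if index != len(self.suppliers) and self.suppliers[index].Name.lower() == name.lower():
            return self.suppliers[index]
        return None


d = Data()
d.suppliers.update([Supplier('B'), Supplier('A')])
print(d.suppliers)
print(d.find_supplier('A'))
```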

outputs:

```
SortedList([A, B])
A
```
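The answer's approach works because a binary search only ever compares list elements against the target as element < target, so Supplier.__lt__ receives the str directly. A minimal, self-contained sketch using the standard-library bisect module on a plain sorted Python list (the find helper is hypothetical, and __lt__ is written case-insensitively here) exercises the same mechanism without sortedcontainers:

```python
import bisect


class Supplier:
    def __init__(self, name: str):
        self.Name = name

    def __repr__(self):
        return self.Name

    def __lt__(self, other):
        # Accept either another Supplier or a plain str as the right-hand operand,
        # comparing case-insensitively.
        other_name = other if isinstance(other, str) else other.Name
        return self.Name.lower() < other_name.lower()


def find(suppliers, name):
    """Hypothetical helper: binary-search a list of Suppliers sorted by __lt__."""
    # bisect_left only evaluates suppliers[mid] < name, i.e. Supplier.__lt__(str).
    index = bisect.bisect_left(suppliers, name)
    if index != len(suppliers) and suppliers[index].Name.lower() == name.lower():
        return suppliers[index]
    return None


suppliers = sorted([Supplier('carol'), Supplier('Alice'), Supplier('bob')])
print(suppliers)                # [Alice, bob, carol]
print(find(suppliers, 'BOB'))   # bob
print(find(suppliers, 'dave'))  # None
```

This relies on bisect_left specifically: bisect_right evaluates target < element instead, which would need a reflected __gt__ on Supplier when the target is a str.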
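If you would rather keep the key function from the question, sortedcontainers also provides SortedKeyList.bisect_key_left, which takes an already-computed key instead of a value, so no temporary Supplier is needed. A sketch, assuming sortedcontainers 2.x (where SortedList(key=...) returns a SortedKeyList), with a __repr__ added purely for readable output:

```python
from sortedcontainers import SortedList


class Supplier:
    def __init__(self, name: str, id: int, sap_id: int):
        self.Name = name
        self.Id = id
        self.SapId = sap_id

    def __repr__(self):
        return self.Name


class Data:
    def __init__(self):
        # Passing key= makes SortedList construct a SortedKeyList.
        self.suppliers = SortedList(key=lambda x: x.Name.lower())

    def find_supplier(self, name: str):
        # bisect_key_left takes a key directly, so the search key must be
        # normalised exactly as the key function does (here: lower-cased name).
        key = name.lower()
        index = self.suppliers.bisect_key_left(key)
        if index != len(self.suppliers) and self.suppliers[index].Name.lower() == key:
            return self.suppliers[index]
        return None


d = Data()
d.suppliers.update([Supplier('B', 2, 20), Supplier('a', 1, 10)])
print(list(d.suppliers))     # [a, B]
print(d.find_supplier('A'))  # a
```

If the search key is not normalised the same way as the key function, the returned index can point at the wrong position, so keeping both in sync is the main thing to watch with this variant.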