I have some structures that are very expensive to compare. (They are actually trees with distinct branches.) Computing hash values for them is also expensive.
I want to create a decorator for the eq operator that will cache some results to speed things up. This is somewhat similar to memoization.
In particular, I want something like this to happen. Suppose we have three objects: A, B and C. We compare A with B; the eq operator gets called, returns True, and the result gets stored. We compare B with C; the eq operator gets called as before. Now we compare A with C. The algorithm should detect that A is equal to B and B is equal to C, so it should return that A is equal to C without invoking the costly eq operator.
I wanted to use the union-find algorithm, but it only allows caching equalities; it doesn't allow caching inequalities.
Suppose that we have two objects equal to each other, A and B, and another pair of equal objects, C and D. The union-find algorithm will correctly group them into two classes, (A, B) and (C, D). Now suppose A is not equal to C. My algorithm should cache that fact somehow and prevent the eq operator from ever running on the pairs (A, C), (B, C), (A, D) and (B, D), since we can deduce that all of these pairs are unequal. Union-find does not allow that: it only saves the positive equalities, failing miserably when we have to compare many unequal objects.
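Concretely, the deduction I want can be sketched with a plain union-find plus a set of unequal representative pairs. This is a toy illustration (all names here are invented, and it skips re-finding pairs that go stale after later unions):

```python
parent = {}  # union-find parent pointers, keyed by the objects themselves

def find(x):
    parent.setdefault(x, x)
    root = x
    while parent[root] != root:
        root = parent[root]
    while parent[x] != root:          # path compression
        parent[x], x = root, parent[x]
    return root

def union(x, y):
    parent[find(x)] = find(y)

unequal = set()  # frozen pairs of representatives known to be unequal

def mark_unequal(x, y):
    unequal.add(frozenset((find(x), find(y))))

def known_unequal(x, y):
    return frozenset((find(x), find(y))) in unequal

union("A", "B")          # cached: A == B
union("C", "D")          # cached: C == D
mark_unequal("A", "C")   # cached: A != C

# every cross pair is now deducible without running the costly eq:
assert all(known_unequal(p, q) for p in "AB" for q in "CD")
```

One negative result per pair of classes is enough: every member pair maps to the same pair of representatives.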
My current solution looks something like this:
def optimize(original_eq):
    def optimized_eq(first, second):
        if first is second:
            return True
        if hash(first) != hash(second):
            return False
        # compare representatives by identity; using == here would
        # recursively invoke the very equality operator being optimized
        if cache.find(first) is cache.find(second):
            return True
        result = original_eq(first, second)
        if result:
            cache.union(first, second)
        else:
            pass  # no idea how to save the negative result
        return result
    return optimized_eq
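The `cache` object in that sketch is assumed, not shown. A minimal identity-based union-find with the `find`/`union` interface used above might look like this (my sketch, not the poster's actual cache):

```python
class UnionFindCache:
    """Identity-based union-find: never calls __eq__ or __hash__ on the
    cached objects themselves, which is the whole point here.

    Caveat: entries are keyed by id(), so cached objects must stay alive
    for as long as the cache is in use (ids can be reused after GC).
    """

    def __init__(self):
        self._parent = {}  # id(obj) -> parent object (also keeps parents alive)

    def find(self, obj):
        root = obj
        while id(root) in self._parent:
            root = self._parent[id(root)]
        while obj is not root:          # path compression
            nxt = self._parent[id(obj)]
            self._parent[id(obj)] = root
            obj = nxt
        return root

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a is not root_b:
            self._parent[id(root_a)] = root_b


class Tree:  # stand-in for the poster's expensive tree type
    pass

a, b, c = Tree(), Tree(), Tree()
cache = UnionFindCache()
cache.union(a, b)
cache.union(b, c)
assert cache.find(a) is cache.find(c)
```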
This solution would be OK if the hash function were easy to compute, but it isn't. With the hash check filtering out unequal pairs, we would be invoking cache.find only on objects that are highly likely to be equal, so we would rarely need to call the original equality operator. But, as I said, the hash function is very slow on my trees (it basically has to traverse the whole tree, comparing branches at each node to remove duplicates), so I want to remove it and cache the negative results instead.
Does anyone know a good solution to this problem? I need to cache not only the positive comparison results, but also negative ones.
UPDATE:
My current solution, which works for me, follows:
def memoize_hash_and_eq(cls):
    "This decorator should be applied to the class."

    def union(key1, key2):
        key1_leader = union_find(key1)
        key2_leader = union_find(key2)
        # compare leaders, not keys, so a redundant union cannot create
        # a leader that points at itself
        if key1_leader is not key2_leader:
            key1_leader._memoize_hash_and_eq__leader = key2_leader
            try:
                key1_splits = key1_leader._memoize_hash_and_eq__splits
            except AttributeError:
                pass
            else:
                # merge (rather than overwrite) the inequality sets of the leaders
                try:
                    key2_leader._memoize_hash_and_eq__splits |= key1_splits
                except AttributeError:
                    key2_leader._memoize_hash_and_eq__splits = key1_splits
                del key1_leader._memoize_hash_and_eq__splits

    def union_find(key):
        leader = key
        while True:
            try:
                leader = leader._memoize_hash_and_eq__leader
            except AttributeError:
                break
        if leader is not key:
            # path compression: point key directly at its current leader
            key._memoize_hash_and_eq__leader = leader
            try:
                key_splits = key._memoize_hash_and_eq__splits
            except AttributeError:
                pass
            else:
                # note: the attribute must be spelled out in full here; name
                # mangling of __splits only happens inside a class body
                try:
                    leader._memoize_hash_and_eq__splits |= key_splits
                except AttributeError:
                    leader._memoize_hash_and_eq__splits = key_splits
                del key._memoize_hash_and_eq__splits
        return leader

    def split(key1, key2):
        key1_leader = union_find(key1)
        key2_leader = union_find(key2)
        try:
            key1_leader._memoize_hash_and_eq__splits.add(key2_leader)
        except AttributeError:
            try:
                key2_leader._memoize_hash_and_eq__splits.add(key1_leader)
            except AttributeError:
                try:
                    key1_leader._memoize_hash_and_eq__splits = {key2_leader}
                except (AttributeError, TypeError):
                    pass

    def split_find(key1, key2):
        key1_leader = union_find(key1)
        key2_leader = union_find(key2)
        try:
            split_leaders = key2_leader._memoize_hash_and_eq__splits
            # refresh entries that are no longer leaders
            for k in list(split_leaders):
                split_leaders.add(union_find(k))
            if key1_leader in split_leaders:
                return True
        except (AttributeError, TypeError):
            pass
        try:
            split_leaders = key1_leader._memoize_hash_and_eq__splits
            for k in list(split_leaders):
                split_leaders.add(union_find(k))
            if key2_leader in split_leaders:
                return True
        except (AttributeError, TypeError):
            pass
        return False

    def memoized_hash(self):
        # hash the leader, so every member of an equivalence class hashes alike
        return original_hash(union_find(self))

    original_hash = cls.__hash__
    cls.__hash__ = memoized_hash

    def memoized_equivalence(self, other):
        if self is other:
            return True
        if union_find(self) is union_find(other):
            return True
        if split_find(self, other):
            return False
        result = original_equivalence(self, other)
        if result is NotImplemented:
            return result
        elif result:
            union(self, other)
        else:
            split(self, other)
        return result

    original_equivalence = cls.__eq__
    cls.__eq__ = memoized_equivalence
    return cls
This speeds up both eq and hash.
This isn't a very pretty solution, but how about you store, for every leader of an equivalence class (i.e. a root in the union-find structure), a binary search tree containing at least (see below) all the leaders that it is definitely unequal to.
To query `x == y`: as usual, find the leaders of both and see if they are the same. If they are not, look up one of the leaders in the BST of the other. If it is present, x and y are definitely unequal.
To union two equivalence classes x and y: merge the BSTs of their leaders and set the result as the BST of the new leader of the union of x and y.
Nodes that enter one of the BSTs and later become non-leaders are never removed from the BSTs, but that's not a huge problem: it won't cause any query to return the wrong result, it just wastes some space (and never a lot).
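A sketch of this scheme in Python, with plain sets standing in for the suggested BSTs (the names `find`, `union`, `split` and `query` are mine, for illustration only):

```python
parent = {}    # x -> parent, for the union-find
distinct = {}  # leader -> set of (possibly stale) leaders it is unequal to

def find(x):
    while parent.get(x, x) != x:
        x = parent[x]
    return x

def query(x, y):
    """True/False when the answer is cached, None when eq must really run."""
    rx, ry = find(x), find(y)
    if rx == ry:
        return True
    # stale entries are re-found rather than removed
    if any(find(k) == ry for k in distinct.get(rx, ())):
        return False
    if any(find(k) == rx for k in distinct.get(ry, ())):
        return False
    return None

def union(x, y):
    rx, ry = find(x), find(y)
    if rx != ry:
        parent[rx] = ry
        # merge the inequality sets onto the surviving leader
        distinct.setdefault(ry, set()).update(distinct.pop(rx, ()))

def split(x, y):
    # record a negative comparison between the two classes
    distinct.setdefault(find(x), set()).add(find(y))

union("A", "B"); union("C", "D")
split("A", "C")
assert query("B", "D") is False    # deduced, no eq call needed
assert query("A", "B") is True
assert query("A", "E") is None     # genuinely unknown
```

A cache miss caused by a stale entry only costs a redundant eq call followed by a fresh `split`; it never yields a wrong answer.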