Given a list of instances of class `A`, such as `[A(i) for i in range(5)]`, I want to randomly select one of them (see the following code for an example):
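```python
import random

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

def f():
    a_list = [A(i) for i in range(5)]
    # randint is inclusive at both ends, so the upper bound is len(a_list) - 1
    a = a_list[random.randint(0, len(a_list) - 1)]()
    return a

f()
```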
Is there a way to decorate `f` with `@tf.function` without changing what `f` does and without calling every item in `a_list`?
Note that simply decorating `f` with `@tf.function`, with no other change to the code above, does not work: it always returns the same result, because `random.randint` is plain Python and runs only once, while `f` is being traced. I also know this can be achieved by calling every element of `a_list` first and then indexing the results with `tf.gather_nd`. But this incurs a large overhead if calling an object of type `A` involves a deep neural network.
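A minimal sketch of the tracing pitfall, assuming TensorFlow 2.x and the question's class `A` (the function name `f_naive` is hypothetical). The Python-level `random.randint` call is evaluated once while the function is traced, so every later call replays the same graph and therefore the same index.

```python
import random

import tensorflow as tf

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

a_list = [A(i) for i in range(5)]

@tf.function
def f_naive():
    # random.randint runs in Python, i.e. only once while the function is
    # traced; the chosen element's value is baked into the graph as a constant.
    return a_list[random.randint(0, len(a_list) - 1)]()

# Every call after the first reuses the traced graph, so the "random"
# choice never changes between calls.
results = [int(f_naive()) for _ in range(20)]
print(results)
```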
I'm working on the same thing at the moment. Here's what I've got so far; if anyone knows a better way, I'd be interested to hear it too. When I run it on an expensive call, it is appropriately faster than computing and returning all of the values.
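The answer's own code is not included in this copy. As a sketch of one way to get the behavior the question asks for (not necessarily the answerer's method), assuming TensorFlow 2.x: draw the index with `tf.random.uniform` so it is a graph op that changes on every call, and dispatch with `tf.switch_case`, which traces every branch but executes only the selected one at run time.

```python
import tensorflow as tf

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

a_list = [A(i) for i in range(5)]

@tf.function
def f():
    # Random index drawn inside the graph, so it is re-sampled on each call.
    # For integer dtypes, maxval is exclusive: idx is in [0, len(a_list)).
    idx = tf.random.uniform([], minval=0, maxval=len(a_list), dtype=tf.int32)
    # One branch per element; the default argument binds each element to its
    # own lambda. Only the branch selected by idx runs at execution time.
    return tf.switch_case(idx, branch_fns=[lambda a=a: a() for a in a_list])

results = [int(f()) for _ in range(200)]
print(sorted(set(results)))
```

All branches must return the same structure and dtype (here a scalar `int32`), which holds when every `A` returns the same kind of output, for example network outputs of the same shape.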
For a speed comparison, I made the `__call__` method of `A` perform expensive matrix algebra. Then consider an alternate function that calls every element of the list:
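The answer's `f2` and `f3` are not shown here. As a hedged sketch of the "evaluate everything, then index" baseline that the question describes, assuming TensorFlow 2.x and the simple class `A` rather than the answerer's expensive version (the name `f_all` is hypothetical): every element is called before the random index is applied, so the cost grows linearly with `len(a_list)`. It uses `tf.gather` on the 1-D stacked result instead of `tf.gather_nd`.

```python
import tensorflow as tf

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

a_list = [A(i) for i in range(5)]

@tf.function
def f_all():
    # Calls every element first: with expensive elements this is the
    # overhead the question wants to avoid.
    outputs = tf.stack([a() for a in a_list])
    idx = tf.random.uniform([], minval=0, maxval=len(a_list), dtype=tf.int32)
    # Pick one of the precomputed results.
    return tf.gather(outputs, idx)

results = [int(f_all()) for _ in range(200)]
print(sorted(set(results)))
```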
Running f2 with 40 elements: 0.42643 seconds
Running f3 with 40 elements: 14.9153 seconds
That is roughly a 35x speedup (14.9153 / 0.42643 ≈ 35), close to the ~40x one would expect from evaluating only one of the 40 branches.