Select an item from a list of objects of any type when using TensorFlow 2.x


Given a list of instances of class A, [A(i) for i in range(5)], I want to randomly select one of them and call it (see the following code for an example):

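import random

class A:
    def __init__(self, a):
        self.a = a

    def __call__(self):
        return self.a

def f():
    a_list = [A(i) for i in range(5)]
    # randint is inclusive at both ends, so the upper bound must be len(a_list) - 1
    a = a_list[random.randint(0, len(a_list) - 1)]()
    return a

f()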

Is there a way to decorate f with @tf.function without changing what f does and without calling every item in a_list?

Note that directly decorating f with @tf.function, with no other change to the code above, does not work: random.randint is ordinary Python, so it runs only once, while the function is traced, and every later call returns the same result. I also know that this can be achieved by calling every element of a_list first and then indexing the results with tf.gather_nd, but that incurs a large overhead if calling an object of type A involves a deep neural network.
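A minimal sketch of the trace-time problem described above, assuming TensorFlow 2.x. The function names naive_pick and tf_pick are hypothetical. The first function draws its index with Python's random module, which is evaluated once during tracing; the second draws it with tf.random.uniform, which is a graph op and produces a fresh value on each call.

```python
import random
import tensorflow as tf

@tf.function
def naive_pick():
    # random.randint is plain Python: it runs once, while the function is
    # traced, and its result is baked into the graph as a constant.
    return tf.constant(random.randint(0, 4))

@tf.function
def tf_pick():
    # tf.random.uniform is a graph op, so a new index is drawn on every call.
    return tf.random.uniform(shape=[], maxval=5, dtype=tf.int32)

naive_results = {int(naive_pick()) for _ in range(50)}
tf_results = {int(tf_pick()) for _ in range(50)}
print(naive_results)  # a single value, repeated on every call
print(tf_results)     # several different values
```

This is why the index has to come from a TensorFlow op rather than from the random module once f is wrapped in tf.function.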

Best answer, by Michael Potter:

I'm working on the same thing at the moment. Here's what I have so far; if anyone knows a better way, I'd be interested to hear it too. When each call is expensive, this is correspondingly faster than computing and returning all of the values.

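import tensorflow as tf

@tf.function
def f2():
    a_list = [A(i) for i in range(5)]
    # Integer index in [0, 5); casting a float drawn with maxval=4 never picks 4
    idx = tf.random.uniform(shape=[], maxval=len(a_list), dtype=tf.int32)
    # Only the branch selected by idx is executed when the graph runs
    return tf.switch_case(idx, a_list)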
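To check that tf.switch_case really runs only the selected element, here is a sketch that counts executions. CountingA and the counters list are hypothetical additions, not part of the answer: each instance increments its own tf.Variable when it is actually executed, so after N calls the counters should sum to N rather than 5 × N.

```python
import tensorflow as tf

class CountingA:
    """Hypothetical variant of A that records how often it actually runs."""

    def __init__(self, a, counter):
        self.a = a
        self.counter = counter

    def __call__(self):
        # assign_add only takes effect when this branch is chosen at run time
        self.counter.assign_add(1)
        return tf.constant(self.a)

n = 5
counters = [tf.Variable(0) for _ in range(n)]
a_list = [CountingA(i, counters[i]) for i in range(n)]

@tf.function
def pick():
    idx = tf.random.uniform(shape=[], maxval=n, dtype=tf.int32)
    return tf.switch_case(idx, a_list)

values = [int(pick()) for _ in range(200)]
counts = [int(c.numpy()) for c in counters]
print(counts, sum(counts))  # the counts sum to 200: one branch per call
```

Because every branch returns a scalar int32 tensor, the branches have matching output structure, which tf.switch_case requires.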

For a speed comparison, I made A's __call__ method perform expensive matrix algebra, then compared against an alternative function that calls every element:

@tf.function
def f3():
    a_list = [A(i) for i in range(40)]
    results = [a() for a in a_list]
    return results

Running f2 with 40 elements: 0.42643 seconds

Running f3 with 40 elements: 14.9153 seconds

That is roughly a 35x speedup (14.9153 / 0.42643 ≈ 35), close to the 40x you would expect from evaluating only one of 40 branches.
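The answer does not show its benchmark code or the expensive version of A, so the following is only a sketch of how such a comparison could be set up. ExpensiveA, its size and reps parameters, pick_one, call_all and time_it are all hypothetical; timings will depend on hardware and will not reproduce the numbers above.

```python
import time
import tensorflow as tf

class ExpensiveA:
    """Hypothetical stand-in for A whose __call__ does costly matrix algebra."""

    def __init__(self, seed, size=128, reps=5):
        # Fixed random matrix per instance (stateless op, so it is reproducible)
        self.m = tf.random.stateless_normal([size, size], seed=[seed, 0])
        self.scale = float(size) ** 0.5  # keeps values from growing each step
        self.reps = reps

    def __call__(self):
        x = self.m
        for _ in range(self.reps):  # Python loop: unrolled at trace time
            x = tf.linalg.matmul(x, self.m) / self.scale
        return tf.reduce_sum(x)

n = 40
a_list = [ExpensiveA(i) for i in range(n)]

@tf.function
def pick_one():
    # Only the selected branch is executed at run time
    idx = tf.random.uniform(shape=[], maxval=n, dtype=tf.int32)
    return tf.switch_case(idx, a_list)

@tf.function
def call_all():
    # Every element is evaluated on each call
    return tf.stack([a() for a in a_list])

def time_it(fn, repeats=5):
    fn()  # the first call traces the function; exclude it from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        fn().numpy()  # .numpy() waits until the result is available
    return (time.perf_counter() - start) / repeats

t_one = time_it(pick_one)
t_all = time_it(call_all)
print(f"pick_one: {t_one:.5f} s, call_all: {t_all:.5f} s, ratio {t_all / t_one:.1f}x")
```

Excluding the first call matters here: tracing builds all 40 branches in both functions, so including it would hide most of the run-time difference.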