Serve TensorFlow models in parallel with Ray

I was looking at this StackOverflow thread on using ray.serve to have a saved TF model predict in parallel: https://stackoverflow.com/a/62459372

I tried something similar with the following:

import ray
from ray import serve; serve.init()
import tensorflow as tf

class A:
    def __init__(self):
        self.model = tf.constant(1.0) # dummy example

    @serve.accept_batch
    def __call__(self, *, input_data=None):
        print(input_data) # test if method is entered
        # do stuff, serve model

if __name__ == '__main__':
    serve.create_backend("tf", A,
        # configure resources
        ray_actor_options={"num_cpus": 2},
        # configure replicas
        config={
            "num_replicas": 2, 
            "max_batch_size": 24,
            "batch_wait_timeout": 0.1
        }
    )
    serve.create_endpoint("tf", backend="tf")
    handle = serve.get_handle("tf")

    args = [1,2,3]

    futures = [handle.remote(input_data=i) for i in args]
    result = ray.get(futures)

However, I get the following error: TypeError: __call__() takes 1 positional argument but 2 positional arguments (and 1 keyword-only argument) were given. There's something wrong with the arguments passed into __call__.

This seems like a simple mistake. How should I change the args array so that the __call__ method is actually entered?
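
For reference, the same error message can be reproduced in plain Python whenever a positional argument is passed to a method whose only other parameter is keyword-only, so it looks like Serve is handing __call__ a positional argument that my signature doesn't accept. The snippet below only illustrates the mechanism and is not Serve internals:

class A:
    def __call__(self, *, input_data=None):
        pass

A()("positional payload", input_data=1)
# TypeError: __call__() takes 1 positional argument but 2 positional
# arguments (and 1 keyword-only argument) were given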

1 Answer

Simon Mo (Best Answer)

The API was updated for Ray 1.0. Please see the migration guide: https://gist.github.com/simon-mo/6d23dfed729457313137aef6cfbc7b54

For the specific code sample you posted, you can update it to:

import ray
from ray import serve
import tensorflow as tf

class A:
    def __init__(self):
        self.model = tf.constant(1.0) # dummy example

    @serve.accept_batch
    def __call__(self, requests):
        for req in requests:
            print(req.data) # test if method is entered
        
        # do stuff, serve model

if __name__ == '__main__':
    client = serve.start()
    client.create_backend("tf", A,
        # configure resources
        ray_actor_options={"num_cpus": 2},
        # configure replicas
        config={
            "num_replicas": 2, 
            "max_batch_size": 24,
            "batch_wait_timeout": 0.1
        }
    )
    client.create_endpoint("tf", backend="tf")
    handle = client.get_handle("tf")

    args = [1,2,3]

    futures = [handle.remote(i) for i in args]
    result = ray.get(futures)
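
To make this concrete for an actual saved model, here is a sketch of the same pattern with a Keras model. The model path, the input shape, and the assumption that each handle call carries one input row (and that the batched __call__ returns one result per request, in order) are placeholders for your own setup, not something the code above prescribes:

import numpy as np
import ray
from ray import serve
import tensorflow as tf

class TFModel:
    def __init__(self):
        # hypothetical path; replace with your own saved model
        self.model = tf.keras.models.load_model("/path/to/saved_model")

    @serve.accept_batch
    def __call__(self, requests):
        # stack each request's payload into one array for a single predict() call
        batch = np.stack([np.asarray(req.data) for req in requests])
        preds = self.model.predict(batch)
        # one result per request, preserving order
        return [p.tolist() for p in preds]

if __name__ == '__main__':
    client = serve.start()
    client.create_backend("tf_model", TFModel,
        ray_actor_options={"num_cpus": 2},
        config={
            "num_replicas": 2,
            "max_batch_size": 24,
            "batch_wait_timeout": 0.1
        }
    )
    client.create_endpoint("tf_model", backend="tf_model")
    handle = client.get_handle("tf_model")

    # each call sends one input row; Serve batches up to max_batch_size of them
    inputs = [np.random.rand(4).tolist() for _ in range(3)]
    futures = [handle.remote(x) for x in inputs]
    print(ray.get(futures))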