Are there ways to speed up ZMQ's recv_multipart()?

I have multiple clients that send dicts of Numpy arrays to a ZMQ server. I managed to pack the dicts into a multi-part message to avoid memcpys during deserialization, which doubled the throughput.

However, the vast majority of the time is now spent in ZMQ's recv_multipart() function, which presumably also copies the data from the network interface into RAM. Are there any ways to remove this second bottleneck as well?

For example, is the time spent on allocating a new buffer and then copying the message into it? If so, is there a way to reuse receive buffers in ZMQ? Or is this simply a fundamental limitation of going through TCP that cannot be optimized much further?

Total Samples 30400
GIL: 73.00%, Active: 73.00%, Threads: 1

  %Own   %Total  OwnTime  TotalTime  Function (filename:line)
 70.00%  70.00%   203.7s    203.7s   recv_multipart (zmq/sugar/socket.py:808)
  1.00%   1.00%    3.01s     4.13s   recv_multipart (zmq/sugar/socket.py:807)
  0.00%   0.00%    2.62s     2.62s   <listcomp> (zmq_gbs_dict_seq.py:37)
  0.00%   0.00%    2.49s     2.49s   send (zmq/sugar/socket.py:696)
  0.00%   0.00%    1.32s     1.32s   unpack (zmq_gbs_dict_seq.py:35)
  0.00%   0.00%   0.690s     1.22s   __call__ (enum.py:717)
  0.00%  72.00%   0.520s    209.9s   server (zmq_gbs_dict_seq.py:82)
  1.00%   1.00%   0.500s    0.840s   inner (typing.py:341)
  0.00%   0.00%   0.500s     5.32s   server (zmq_gbs_dict_seq.py:83)
  0.00%   1.00%   0.400s     1.33s   recv_multipart (zmq/sugar/socket.py:812)
  1.00%   1.00%   0.360s     3.07s   send_multipart (zmq/sugar/socket.py:751)
  0.00%   0.00%   0.350s    0.350s   __new__ (enum.py:1106)
  0.00%   0.00%   0.300s    0.300s   __hash__ (typing.py:1352)
  0.00%   0.00%   0.270s    0.270s   <genexpr> (zmq_gbs_dict_seq.py:93)
  0.00%   0.00%   0.260s    0.260s   server (zmq_gbs_dict_seq.py:101)
  0.00%   0.00%   0.250s    0.660s   server (zmq_gbs_dict_seq.py:92)
  0.00%   0.00%   0.250s     3.04s   unpack (zmq_gbs_dict_seq.py:36)
  0.00%   0.00%   0.210s    0.210s   unpack (zmq_gbs_dict_seq.py:38)
  0.00%   0.00%   0.210s    0.210s   server (zmq_gbs_dict_seq.py:91)
  0.00%   0.00%   0.200s    0.200s   unpack (zmq_gbs_dict_seq.py:39)
  0.00%   1.00%   0.200s     4.04s   server (zmq_gbs_dict_seq.py:99)

import multiprocessing
import pickle
import time

import numpy as np
import zmq


def client(port):
  socket = zmq.Context.instance().socket(zmq.DEALER)
  socket.set_hwm(0)
  socket.connect(f'tcp://localhost:{port}')
  data = {
      'foo': np.zeros((1024, 64, 64, 3), np.uint8),
      'bar': np.zeros((1024, 1024), np.float32),
      'baz': np.zeros((1024,), np.float32),
  }
  parts = pack(data)
  while True:
    socket.send_multipart(parts)
    msg = socket.recv()
    assert msg == b'done'
  socket.close()


def server(port):
  socket = zmq.Context.instance().socket(zmq.ROUTER)
  socket.set_hwm(0)
  socket.bind(f'tcp://*:{port}')
  time.sleep(3)
  print('Start')
  start = time.time()
  steps = 0
  nbytes = 0
  poller = zmq.Poller()
  poller.register(socket, zmq.POLLIN)
  while True:
    if poller.poll():
      addr, *parts = socket.recv_multipart(zmq.NOBLOCK)
      data = unpack(parts)
      steps += data['foo'].shape[0]
      nbytes += sum(v.nbytes for v in data.values())
      socket.send_multipart([addr, b'done'])
    duration = time.time() - start
    if duration > 1:
      fps = steps / duration
      gbs = (nbytes / 1024 / 1024 / 1024) / duration
      print(f'{fps/1e3:.2f}k fps {gbs:.2f} gb/s')
      start = time.time()
      steps = 0
      nbytes = 0
  socket.close()


def pack(data):
  items = sorted(data.items(), key=lambda x: x[0])
  keys, vals = zip(*items)
  dtypes = [v.dtype.name for v in vals]
  shapes = [v.shape for v in vals]
  buffers = [v.tobytes() for v in vals]
  meta = (keys, dtypes, shapes)
  parts = [pickle.dumps(meta), *buffers]
  return parts


def unpack(parts):
  meta, *buffers = parts
  keys, dtypes, shapes = pickle.loads(meta)
  vals = [
      np.frombuffer(b, d).reshape(s)
      for d, s, b in zip(dtypes, shapes, buffers)]
  data = dict(zip(keys, vals))
  return data


def main():
  mp = multiprocessing.get_context('spawn')
  workers = []
  for _ in range(32):
    workers.append(mp.Process(target=client, args=(5555,)))
  workers.append(mp.Process(target=server, args=(5555,)))
  [x.start() for x in workers]
  [x.join() for x in workers]


if __name__ == '__main__':
  main()

1 Answer

Answered by Azmisov:

Setting copy=False on both send and receive will help. However, the "zero copy" feature doesn't really behave the way you might be thinking.

The way the library is set up, a ZMQ message object is constructed to be sent to the output socket (e.g. TCP, if that's what you've chosen). Imagine you were working in a lower-level language like C. You might serialize your data into a byte buffer; once serialization is complete, you copy that buffer into the ZMQ message object to be sent to the socket. In many cases, the byte buffer will never be written to again, so we'd really like to tell ZMQ to just reference it directly and avoid the copy into the message object.

ZMQ's "zero copy" serves this purpose. The zero copy tells ZMQ it can reference your program's raw byte data as the ZMQ message contents, rather than copying over to a separate ZMQ message buffer. Note however that ZMQ still copies data when going over TCP, or crossing between kernel/user OS memory boundaries; so the term can be somewhat misleading. I believe the only case you get a true end-to-end zero copy is when using the in-process (inproc) transport.

When using PyZMQ, you need to pass copy=False to use the zero-copy feature. Your program is not doing so, so it is not actually benefiting from it. While you didn't share your previous code, I believe the speedup you're seeing is simply because Numpy arrays use a binary representation: even though they are still copied, there is less data to send than before, when they were presumably serialized as JSON or pickle objects.
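As a rough, untested sketch of what that could look like in your pack()/client() code (the np.ascontiguousarray call is my own addition; it is a no-op for arrays that are already C-contiguous):

def pack(data):
  items = sorted(data.items(), key=lambda x: x[0])
  keys, vals = zip(*items)
  meta = (keys, [v.dtype.name for v in vals], [v.shape for v in vals])
  # Contiguous arrays expose the buffer protocol, so PyZMQ can wrap their
  # memory in a message frame without the intermediate tobytes() copy.
  buffers = [np.ascontiguousarray(v) for v in vals]
  return [pickle.dumps(meta), *buffers]

# In client(): copy=False tells PyZMQ to reference the buffers instead of
# copying them into the outgoing message. Don't mutate the arrays until
# ZMQ is done with them.
socket.send_multipart(parts, copy=False)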

If you are manually managing a byte buffer in Python (not the case for Numpy), you would need to use PyZMQ's MessageTracker feature, which can notify you when ZMQ has finished reading from your buffer so that it can be freed or reused.
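In PyZMQ that looks roughly like this (track=True only has an effect together with copy=False):

buf = bytearray(1024 * 1024)                     # a buffer you want to reuse
tracker = socket.send(buf, copy=False, track=True)
tracker.wait()   # blocks until ZMQ no longer references buf
# buf can now be safely overwritten and sent again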

PyZMQ also supports "zero copy" on the receiving side, although I don't believe this is really a standard ZMQ concept. When you receive data, the message has to be copied out of the socket into a buffer. Depending on how the receiving loop is designed, a library might hand you that raw buffer or give you a transient copy of it. If you set copy=False on the receiving side in PyZMQ, it returns that first raw buffer; otherwise, PyZMQ makes a copy of the data. I don't see any advantage to setting copy=True on the receiving side.

Your program isn't using copy=False on the receiving side, so it does do an extra copy. But as noted earlier, it won't be possible to do a zero copy directly from the OS kernel's TCP receive buffers, as you might have been imagining.
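For completeness, here is a sketch (untested) of what copy=False could look like in your server()/unpack() code. With copy=False, recv_multipart() returns zmq.Frame objects, and frame.buffer is a memoryview of the message data:

# In server():
addr, *parts = socket.recv_multipart(zmq.NOBLOCK, copy=False)
data = unpack(parts)
socket.send_multipart([addr.bytes, b'done'])    # addr is a zmq.Frame here

def unpack(parts):
  meta, *buffers = parts
  keys, dtypes, shapes = pickle.loads(meta.buffer)
  # These arrays are read-only views into the ZMQ frames, so the frames
  # must stay alive for as long as the arrays are in use.
  vals = [
      np.frombuffer(b.buffer, d).reshape(s)
      for d, s, b in zip(dtypes, shapes, buffers)]
  return dict(zip(keys, vals))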
