dask handle delayed failures


How can I port the following function to dask in order to parallelize it?

from time import sleep
from dask.distributed import Client
from dask import delayed
client = Client(n_workers=4)
from tqdm import tqdm

# linear
things = [1,2,3]
_x = []
_y = []

def my_slow_function(foo):
    sleep(2)
    x = foo
    y = 2 * foo
    assert y < 5
    return x, y

for foo in tqdm(things):
    try:
        x_v, y_v = my_slow_function(foo)
        _x.append(x_v)
        if y_v is not None: _y.append(y_v)
    except AssertionError:
        print(f'failed: {foo}')

X = _x
y = _y

print(X)
print(y)

I am particularly unsure about how to handle the state and failures in the delayed futures.

So far I only have:

from dask.diagnostics import ProgressBar
ProgressBar().register()

@delayed(nout=2)
def my_slow_function(foo):
    sleep(2)
    x = foo
    y = 2 * foo
    assert y < 5
    return x, y


for foo in tqdm(things):
    try:
        x_v, y_v = delayed(my_slow_function(foo))
        _x.append(x_v)
        if y_v is not None: _y.append(y_v)
    except AssertionError:
        print(f'failed: {foo}')

X = _x
y = _y

print(X)
print(y)

delayed(sum)(X).compute()

But:

  • the try/except no longer works, i.e. it no longer catches the exceptions
  • I have 2 lists of delayed results, but not 2 lists of computed values
    • for these 2 lists I am unsure how to call compute without computing the shared tasks twice
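(For reference on the duplicate-computation concern: passing both lists to a single `dask.compute` call evaluates the shared task graph once, so each task runs only one time even though both output lists depend on it. A minimal sketch, with illustrative names and without the failing input:)

```python
import dask
from dask import delayed

@delayed(nout=2)
def slow(foo):
    # returns a pair, like my_slow_function but without the assertion
    return foo, 2 * foo

pairs = [slow(f) for f in [1, 2]]   # each element is a 2-tuple of Delayed objects
xs = [p[0] for p in pairs]
ys = [p[1] for p in pairs]

# One compute call walks the graph once; slow(foo) is not re-executed
# for the second list, because both lists share the same underlying tasks.
xs_v, ys_v = dask.compute(xs, ys)
```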

edit

futures = client.map(my_slow_function, things)
results = client.gather(futures)

obviously fails, as the exception is no longer handled; so far I am not really sure what the right way of catching exceptions from dask is.

How to prevent dask client from dying on worker exception? might be similar

1 answer

Georg Heiler (best answer)

It is a design goal of dask to cancel the whole task graph in case of a failure. Instead, the concurrent futures API should be used (https://docs.dask.org/en/latest/futures.html), which allows handling failures on the driver:

from dask.distributed import wait

# map the plain (undecorated) function; client.map submits it directly,
# so my_slow_function must not carry the @delayed decorator here
futures = client.map(my_slow_function, things)
wait(futures)  # block until every future has either finished or errored

for future in futures:
    try:
        # .result() re-raises the worker's exception on the driver;
        # unpack and append to the result lists to match the original loop 1:1
        x_v, y_v = future.result()
        _x.append(x_v)
        if y_v is not None:
            _y.append(y_v)
    except AssertionError:
        print(f'failed: {future.key}')  # implement logging here
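
Tying this back to the original loop, here is a self-contained sketch; the worker count and the in-process `processes=False` cluster are illustrative choices, not requirements:

```python
from dask.distributed import Client

def my_slow_function(foo):
    # plain function: raising AssertionError marks the future as errored
    x = foo
    y = 2 * foo
    assert y < 5
    return x, y

client = Client(n_workers=2, processes=False)  # lightweight in-process cluster
futures = client.map(my_slow_function, [1, 2, 3])

_x, _y = [], []
for future in futures:
    try:
        x_v, y_v = future.result()  # raises AssertionError here for foo = 3
        _x.append(x_v)
        if y_v is not None:
            _y.append(y_v)
    except AssertionError:
        print(f'failed: {future.key}')

client.close()
print(_x)  # [1, 2]
print(_y)  # [2, 4]
```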