How to return AggregateException from async method


I have an async method that works like an enhanced Task.WhenAll. It takes a bunch of tasks and returns when all of them have completed.

public async Task MyWhenAll(Task[] tasks) {
    ...
    await Something();
    ...

    // all tasks are completed
    if (someTasksFailed)
        throw ??
}

My question is: how do I make the method return a Task that looks like the one returned from Task.WhenAll when one or more of the tasks have failed?

If I collect the exceptions and throw an AggregateException, it gets wrapped in another AggregateException.

Edit: Full Example

async Task Main() {
    try {
        Task.WhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }

    try {
        MyWhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }
}

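public async Task MyWhenAll(Task t1, Task t2) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    try {
        await Task.WhenAll(t1, t2);
    }
    catch {
        // Only faulted tasks have a non-null Exception, and AggregateException
        // rejects null entries, so skip tasks that did not fault
        throw new AggregateException(new[] { t1, t2 }
            .Where(t => t.IsFaulted)
            .Select(t => t.Exception));
    }
}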
public async Task Throw(int id) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    throw new InvalidOperationException("Inner" + id);
}

For Task.WhenAll, the caught exception is an AggregateException with two inner exceptions.

For MyWhenAll, the caught exception is an AggregateException with a single inner AggregateException, which in turn holds the two inner exceptions.
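The extra level can be reproduced with nothing but a single async method. In this minimal sketch (the names NestingDemo and ThrowAggregateAsync are made up for illustration), the async method throws one AggregateException with two inner exceptions; the returned task stores that AggregateException as its single exception, so .Wait() wraps it once more.

```csharp
using System;
using System.Threading.Tasks;

public static class NestingDemo
{
    // Hypothetical async method that throws an AggregateException with two inners
    static async Task ThrowAggregateAsync()
    {
        await Task.Yield();
        throw new AggregateException(
            new InvalidOperationException("Inner1"),
            new InvalidOperationException("Inner2"));
    }

    // Returns the AggregateException thrown by .Wait(), or null if none was thrown
    public static AggregateException CatchFromWait()
    {
        try
        {
            ThrowAggregateAsync().Wait();
            return null;
        }
        catch (AggregateException ex)
        {
            return ex;
        }
    }

    public static void Main()
    {
        AggregateException outer = CatchFromWait();
        var inner = (AggregateException)outer.InnerExceptions[0];
        Console.WriteLine(outer.InnerExceptions.Count); // 1: the thrown AggregateException
        Console.WriteLine(inner.InnerExceptions.Count); // 2: Inner1 and Inner2
    }
}
```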

Edit: Why I am doing this

I often need to call paging APIs and want to limit the number of simultaneous connections.

The actual method signatures are

public static async Task<TResult[]> AsParallelAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel)
public static async Task<TResult[]> AsParallelUntilAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel, Func<Task<TResult>, bool> predicate)

This lets me do paging like this:

var pagedRecords = await Enumerable.Range(1, int.MaxValue)
                                   .Select(x => GetRecordsAsync(pageSize: 1000, pageNumber: x))
                                   .AsParallelUntilAsync(maxParallel: 5, x => x.Result.Count < 1000);
var records = pagedRecords.SelectMany(x => x).ToList();

It all works fine; the aggregate within an aggregate is just a minor inconvenience.
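For context, here is a hypothetical sketch of how an AsParallelAsync with the first signature above could throttle the work; it is not the asker's implementation. It assumes the source sequence is lazy, so each MoveNext() creates and starts the next task, and it uses Task.WhenAny to keep at most maxParallel tasks in flight. The ThrottleDemo class and its WorkAsync helper exist only for the demonstration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class ParallelExtensions
{
    // Hypothetical sketch, not the asker's code: pulls tasks lazily from
    // `source` and keeps at most `maxParallel` of them in flight.
    public static async Task<TResult[]> AsParallelAsync<TResult>(
        this IEnumerable<Task<TResult>> source, int maxParallel)
    {
        var all = new List<Task<TResult>>();
        var running = new HashSet<Task<TResult>>();
        foreach (Task<TResult> task in source) // each MoveNext() starts a task
        {
            all.Add(task);
            running.Add(task);
            if (running.Count >= maxParallel)
            {
                // Wait for any in-flight task to finish (successfully or not)
                // before pulling the next one from the source
                Task<TResult> finished = await Task.WhenAny(running);
                running.Remove(finished);
            }
        }
        // Awaiting here rethrows only the first failure, so the returned task
        // carries a single exception rather than all of them
        return await Task.WhenAll(all);
    }
}

public static class ThrottleDemo
{
    static int _current, _max;

    // Made-up unit of work that records how many calls overlap
    static async Task<int> WorkAsync(int x)
    {
        int now = Interlocked.Increment(ref _current);
        int seen;
        while (now > (seen = Volatile.Read(ref _max)))
            Interlocked.CompareExchange(ref _max, now, seen);
        await Task.Delay(50);
        Interlocked.Decrement(ref _current);
        return x * 10;
    }

    public static (int[] Results, int MaxConcurrent) Run()
    {
        int[] results = Enumerable.Range(1, 6)
            .Select(x => WorkAsync(x))
            .AsParallelAsync(maxParallel: 2)
            .GetAwaiter().GetResult();
        return (results, Volatile.Read(ref _max));
    }

    public static void Main()
    {
        var (results, max) = Run();
        Console.WriteLine(string.Join(",", results)); // 10,20,30,40,50,60
        Console.WriteLine(max);                       // 2
    }
}
```

As the final comment notes, a plain async implementation like this one reports only the first failure; keeping every exception is exactly what the answers below address.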


There are 3 answers

Accepted answer by Servy

async methods are designed to set at most a single exception on the returned task, never multiple.

This leaves you with two options. The first is not to make the method async at all, and to compose it by other means instead:

public Task MyWhenAll(Task t1, Task t2)
{
    return Task.Delay(TimeSpan.FromMilliseconds(100))
        .ContinueWith(_ => Task.WhenAll(t1, t2))
        .Unwrap();
}

If your method is more complex and would be hard to write without await, you'll need to unwrap the nested AggregateExceptions instead, which is tedious, although not overly complex:

public static Task UnwrapAggregateException(this Task taskToUnwrap)
{
    var tcs = new TaskCompletionSource<bool>();

    taskToUnwrap.ContinueWith(task =>
    {
        if (task.IsCanceled)
            tcs.SetCanceled();
        else if (task.IsFaulted)
            // task.Exception is always an AggregateException here; replace it
            // with the leaf exceptions of every nested AggregateException
            tcs.SetException(Flatten(task.Exception));
        else // successful
            tcs.SetResult(true);
    }, TaskScheduler.Default);

    IEnumerable<Exception> Flatten(AggregateException exception)
    {
        var stack = new Stack<AggregateException>();
        stack.Push(exception);
        while (stack.Count > 0)
        {
            var next = stack.Pop();
            foreach (Exception inner in next.InnerExceptions)
            {
                if (inner is AggregateException innerAggregate)
                    stack.Push(innerAggregate); // descend into nested aggregates
                else
                    yield return inner;
            }
        }
    }

    return tcs.Task;
}
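For comparison, the framework already provides AggregateException.Flatten(), which recursively replaces nested AggregateExceptions with their inner exceptions, so the hand-written Flatten helper could be swapped for it. The sketch below assumes that substitution; FlattenExtensions and FlattenDemo are illustrative names, and the async MyWhenAll mirrors the question's version that produces the nested aggregates.

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

public static class FlattenExtensions
{
    // Same contract as UnwrapAggregateException above, but using the built-in
    // AggregateException.Flatten() instead of a hand-written stack walk
    public static Task UnwrapAggregateException(this Task taskToUnwrap)
    {
        var tcs = new TaskCompletionSource<bool>();
        taskToUnwrap.ContinueWith(task =>
        {
            if (task.IsCanceled)
                tcs.SetCanceled();
            else if (task.IsFaulted)
                tcs.SetException(task.Exception.Flatten().InnerExceptions);
            else
                tcs.SetResult(true);
        }, TaskScheduler.Default);
        return tcs.Task;
    }
}

public static class FlattenDemo
{
    static async Task Throw(int id)
    {
        await Task.Delay(100);
        throw new InvalidOperationException("Inner" + id);
    }

    // An async MyWhenAll like the question's, which produces nested aggregates
    static async Task MyWhenAll(Task t1, Task t2)
    {
        try
        {
            await Task.WhenAll(t1, t2);
        }
        catch
        {
            throw new AggregateException(
                new[] { t1, t2 }.Where(t => t.IsFaulted).Select(t => t.Exception));
        }
    }

    // Messages of the exceptions seen by .Wait() on the unwrapped task
    public static string[] Run()
    {
        try
        {
            MyWhenAll(Throw(1), Throw(2)).UnwrapAggregateException().Wait();
            return Array.Empty<string>();
        }
        catch (AggregateException ex)
        {
            return ex.InnerExceptions.Select(e => e.Message).ToArray();
        }
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(", ", Run())); // Inner1, Inner2
    }
}
```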
Answer by canton7

Use a TaskCompletionSource.

The outermost exception is created by .Wait() or .Result: they are documented as wrapping the exception stored in the Task in an AggregateException. (This preserves the original stack trace; the behaviour was introduced before ExceptionDispatchInfo existed.)

However, a Task can actually contain many exceptions. When it does, .Wait() and .Result throw an AggregateException with multiple InnerExceptions. You can create such a Task with TaskCompletionSource.SetException(IEnumerable<Exception> exceptions).

So you do not want to create your own AggregateException. Set multiple exceptions on the Task, and let .Wait() and .Result create that AggregateException for you.

So:

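var tcs = new TaskCompletionSource<object>();
// Pass the leaf exceptions: t1.Exception and t2.Exception are themselves
// AggregateExceptions and would bring the nesting back
tcs.SetException(t1.Exception.InnerExceptions.Concat(t2.Exception.InnerExceptions));
return tcs.Task;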
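A small self-contained check of this behaviour (TcsDemo and the exception messages are illustrative): two exceptions passed to SetException come back out of .Wait() as a single, flat AggregateException.

```csharp
using System;
using System.Threading.Tasks;

public static class TcsDemo
{
    // Returns the AggregateException thrown by .Wait(), or null if none was thrown
    public static AggregateException Run()
    {
        var tcs = new TaskCompletionSource<object>();
        // Store two separate exceptions on the task, not one AggregateException
        tcs.SetException(new Exception[]
        {
            new InvalidOperationException("Inner1"),
            new InvalidOperationException("Inner2"),
        });
        try
        {
            tcs.Task.Wait();
            return null;
        }
        catch (AggregateException ex)
        {
            return ex;
        }
    }

    public static void Main()
    {
        AggregateException ex = Run();
        Console.WriteLine(ex.InnerExceptions.Count);      // 2
        Console.WriteLine(ex.InnerExceptions[0].Message); // Inner1
    }
}
```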

Of course, if you then call await MyWhenAll(..) or MyWhenAll(..).GetAwaiter().GetResult(), then it will only throw the first exception. This matches the behaviour of Task.WhenAll.

This means you need to pass tcs.Task up as your method's return value, which means your method can't be async. You end up doing ugly things like this (adjusting the sample code from your question):

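public static Task MyWhenAll(Task t1, Task t2)
{
    var tcs = new TaskCompletionSource<object>();
    _ = Impl(); // fire and forget: Impl completes tcs itself
    return tcs.Task;

    async Task Impl()
    {
        await Task.Delay(10);
        try
        {
            await Task.WhenAll(t1, t2);
            tcs.SetResult(null);
        }
        catch
        {
            // Leaf exceptions of the faulted tasks only (t.Exception is null otherwise)
            var exceptions = new[] { t1, t2 }
                .Where(t => t.IsFaulted)
                .SelectMany(t => t.Exception.InnerExceptions)
                .ToList();
            if (exceptions.Count > 0)
                tcs.SetException(exceptions);
            else
                tcs.SetCanceled(); // nothing faulted, so a task was canceled
        }
    }
}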

At this point, though, I'd start to query why you're trying to do this, and why you can't use the Task returned from Task.WhenAll directly.
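On canton7's note that await MyWhenAll(..) only throws the first exception: a caller that awaits such a task can still reach every failure through the task's Exception property. A minimal sketch of that pattern, reusing the question's Throw helper (AwaitDemo is a made-up name):

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

public static class AwaitDemo
{
    // Copied from the question
    static async Task Throw(int id)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        throw new InvalidOperationException("Inner" + id);
    }

    // Returns (message seen by await, messages still stored on the task)
    public static async Task<(string Caught, string[] All)> RunAsync()
    {
        Task whenAll = Task.WhenAll(Throw(1), Throw(2));
        try
        {
            await whenAll;
            return (null, Array.Empty<string>());
        }
        catch (InvalidOperationException ex)
        {
            // await rethrows only the first exception; the task still holds all of them
            return (ex.Message, whenAll.Exception.InnerExceptions.Select(e => e.Message).ToArray());
        }
    }

    public static void Main()
    {
        var (caught, all) = RunAsync().GetAwaiter().GetResult();
        Console.WriteLine(caught);                 // Inner1
        Console.WriteLine(string.Join(", ", all)); // Inner1, Inner2
    }
}
```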

Answer by Theodor Zoulias

I deleted my previous answer, because I found a simpler solution. This solution does not involve the pesky ContinueWith method or the TaskCompletionSource type. The idea is to return a nested Task<Task> from a local function, and Unwrap() it from the outer container function. Here is a basic outline of this idea:

public Task<T[]> GetAllAsync<T>()
{
    return LocalAsyncFunction().Unwrap();

    async Task<Task<T[]>> LocalAsyncFunction()
    {
        var tasks = new List<Task<T>>();
        // ...
        await SomethingAsync();
        // ...
        Task<T[]> whenAll = Task.WhenAll(tasks);
        return whenAll;
    }
}

The GetAllAsync method is not async. It delegates all the work to the async LocalAsyncFunction, then unwraps the resulting nested task and returns it. Because the unwrapped task is just a facade over the inner Task.WhenAll task, its .Exception.InnerExceptions property contains the exceptions of all the tasks.
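Applied to the MyWhenAll from the question, the idea could look like this minimal sketch. UnwrapDemo is an illustrative name, the Task.Delay stands in for whatever async work precedes the WhenAll, and Throw is copied from the question.

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

public static class UnwrapDemo
{
    // The question's MyWhenAll, restructured with the nested-task idea: the async
    // local function returns the WhenAll task instead of awaiting it
    public static Task MyWhenAll(Task t1, Task t2)
    {
        return Impl().Unwrap();

        async Task<Task> Impl()
        {
            await Task.Delay(TimeSpan.FromMilliseconds(100)); // stand-in for the real async work
            return Task.WhenAll(t1, t2);
        }
    }

    // Copied from the question
    static async Task Throw(int id)
    {
        await Task.Delay(TimeSpan.FromMilliseconds(100));
        throw new InvalidOperationException("Inner" + id);
    }

    // Messages of the inner exceptions of the AggregateException thrown by .Wait()
    public static string[] Run()
    {
        try
        {
            MyWhenAll(Throw(1), Throw(2)).Wait();
            return Array.Empty<string>();
        }
        catch (AggregateException ex)
        {
            return ex.InnerExceptions.Select(e => e.Message).ToArray();
        }
    }

    public static void Main()
    {
        Console.WriteLine(string.Join(", ", Run())); // Inner1, Inner2
    }
}
```

Note that if the work before the WhenAll throws, the unwrapped task faults with that single exception, which is the usual async behaviour.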

Let's demonstrate a more practical realization of this idea. The AsParallelUntilAsync method below lazily enumerates the source sequence and projects its items to Task<TResult>s, until an item satisfies the predicate. It also limits the concurrency of the asynchronous operations. The difficulty is that enumerating the IEnumerable<TSource> could throw an exception too. The correct behavior in that case is to await all the running tasks before propagating the enumeration error, and to return a faulted task whose AggregateException contains both the enumeration error and all the task errors that may have occurred in the meantime. Here is how it can be done:

public static Task<TResult[]> AsParallelUntilAsync<TSource, TResult>(
    this IEnumerable<TSource> source, Func<TSource, Task<TResult>> action,
    Func<TSource, bool> predicate, int maxConcurrency)
{
    return Implementation().Unwrap();

    async Task<Task<TResult[]>> Implementation()
    {
        var tasks = new List<Task<TResult>>();

        async Task<TResult> EnumerateAsync()
        {
            var semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
            using var enumerator = source.GetEnumerator();
            while (true)
            {
                await semaphore.WaitAsync();
                if (!enumerator.MoveNext()) break;
                var item = enumerator.Current;
                if (predicate(item)) break;

                async Task<TResult> RunAndRelease(TSource item)
                {
                    try { return await action(item); }
                    finally { semaphore.Release(); }
                }

                tasks.Add(RunAndRelease(item));
            }
            return default; // Dummy value; the result of this task is never used
        }

        Task<TResult> enumerateTask = EnumerateAsync();

        try
        {
            await enumerateTask; // Make sure that the enumeration succeeded
            Task<TResult[]> whenAll = Task.WhenAll(tasks);
            await whenAll; // Make sure that all the tasks succeeded
            return whenAll;
        }
        catch
        {
            // Return a faulted task that contains ALL the errors!
            return Task.WhenAll(tasks.Prepend(enumerateTask));
        }
    }
}