Airflow: Complete all tasks in a TaskGroup before moving on to the next one and avoid dependencies between TaskGroups


I would like to set up a DAG in which all tasks in one TaskGroup finish before the next TaskGroup starts. In the example (see screenshot), Workflow_FRA must complete its tasks run_task_FRA and run_next_task_FRA before Workflow_BEL runs, and so on.

Below is the DAG script, but it runs the tasks in parallel regardless of the TaskGroup.

[screenshot]

import time
from datetime import datetime
from airflow.utils.task_group import TaskGroup
from airflow.decorators import task
from airflow import DAG

with DAG(
        dag_id="dev_dag",
        concurrency=1,
        start_date=datetime(2024, 2, 27),
        schedule_interval='*/1 * * * *',
        catchup=False
) as dag:

    @task(task_id="start_task")
    def start_task():
        print("start")

    start_task = start_task()

    @task(task_id="end_task")
    def end_task():
        print("end")

    end_task = end_task()

    for country in ["FR", "BE", "SP", "EN"]:
        with TaskGroup(group_id=f"workflow_{country}") as workflow:
            @task(task_id=f"run_task_{country}")
            def run_task():
                time.sleep(5)
                print("run task")


            @task(task_id=f"run_next_task_{country}")
            def run_next_task():
                time.sleep(5)
                print("run next task")

            start_task >> run_task() >> run_next_task() >> end_task
    start_task >> workflow >> end_task

What I want to achieve is this: if the TaskGroup workflow_BE fails, the following TaskGroups should still be able to run, and I would like to be able to clear the tasks of the failed group without running the subsequent TaskGroups again.

1 answer

ARCrow (best answer)

You can use the chain function, which links each item in a list to the next one in sequence:

import time
from datetime import datetime
from airflow.utils.task_group import TaskGroup
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow import DAG

with DAG(
        dag_id="dev_dag",
        concurrency=1,
        start_date=datetime(2024, 2, 27),
        schedule_interval=None,
        catchup=False
) as dag:

    @task(task_id="start_task")
    def start_task():
        print("start")

    start_task = start_task()

    @task(task_id="end_task")
    def end_task():
        print("end")

    end_task = end_task()

    # Collect everything that must run in sequence, starting with start_task.
    tasks = [start_task]
    for country in ["FR", "BE", "SP", "EN"]:
        with TaskGroup(group_id=f"workflow_{country}") as workflow:
            @task(task_id=f"run_task_{country}")
            def run_task():
                time.sleep(5)
                print("run task")

            # all_done: runs once run_task has finished, whether it succeeded or failed.
            @task(
                task_id=f"run_next_task_{country}",
                trigger_rule="all_done"
            )
            def run_next_task():
                time.sleep(5)
                print("run next task")

            start_task >> run_task() >> run_next_task() >> end_task
        tasks.append(workflow)
    tasks.append(end_task)
    # chain(a, b, c) is equivalent to a >> b >> c.
    chain(*tasks)

The dependency graph will look like this: [screenshot]
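
To reason about what happens when a task fails in this layout, here is a small plain-Python model of the dependency edges and of the two trigger rules involved. It is only a sketch and does not use Airflow: build_upstreams, simulate and NEXT_TASKS are hypothetical names, only the all_success (default) and all_done rules are modelled, and retries, skips and concurrency are ignored. The task ids assume the default TaskGroup behaviour of prefixing each task id with the group id.

```python
COUNTRIES = ["FR", "BE", "SP", "EN"]

# Tasks that carry trigger_rule="all_done" in the DAG above.
NEXT_TASKS = {f"workflow_{c}.run_next_task_{c}" for c in COUNTRIES}


def build_upstreams():
    """Return {task_id: set of upstream task_ids} mirroring the DAG above."""
    upstream = {"start_task": set(), "end_task": set()}
    previous_leaf = "start_task"
    for country in COUNTRIES:
        first = f"workflow_{country}.run_task_{country}"
        second = f"workflow_{country}.run_next_task_{country}"
        # start_task >> run_task() >> run_next_task() >> end_task,
        # plus the chain(...) edge from the previous link's last task.
        upstream[first] = {"start_task", previous_leaf}
        upstream[second] = {first}
        upstream["end_task"].add(second)
        previous_leaf = second
    # The chain edge into end_task (last group >> end_task) is already present.
    return upstream


def simulate(failing=(), all_done=()):
    """Assign a final state to every task.

    failing:  task ids that fail if they get to run (assumed for illustration).
    all_done: task ids using the all_done rule; the rest use all_success.
    """
    upstream = build_upstreams()
    state = {}
    while len(state) < len(upstream):
        for task, ups in upstream.items():
            # Skip tasks already decided or still waiting on an upstream.
            if task in state or any(u not in state for u in ups):
                continue
            upstreams_ok = all(state[u] == "success" for u in ups)
            if task in all_done or upstreams_ok:
                state[task] = "failed" if task in failing else "success"
            else:
                state[task] = "upstream_failed"
    return state


if __name__ == "__main__":
    result = simulate(failing={"workflow_BE.run_task_BE"}, all_done=NEXT_TASKS)
    for task_id, task_state in result.items():
        print(f"{task_id}: {task_state}")
```

In this model, a failure in run_task_BE does not stop workflow_SP, because run_next_task_BE still runs and succeeds. A failure in run_next_task_BE, however, leaves run_task_SP in upstream_failed. If that matters, one option to try is setting trigger_rule="all_done" on the first task of each group as well; adding those ids to all_done lets you check the effect in the model before changing the DAG.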