I have the following use case:

I am trying to create an Airflow DAG to automate a historical data load. I pass five arguments to this DAG: jar_path, main_class_name, start_param, end_param, and max_run_range. From start_param, end_param, and max_run_range, a list of dates is built, and for each date in that list I want to trigger a DatabricksSubmitRunOperator task. I have the code below, but when I execute it, the dynamic tasks it creates fail with the error shown after the code.
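For reference, a sample run configuration (the dag_run.conf payload) could look like the one below; the jar path and class name are placeholder values. The conf can be supplied through the UI's "Trigger DAG w/ config" dialog or through airflow dags trigger --conf:

{
    "jar_path": "dbfs:/FileStore/jars/historical_load.jar",
    "main_class_name": "com.example.HistoricalLoadMain",
    "start_param": "2024-03-01",
    "end_param": "2024-03-03",
    "max_run_range": 1
}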

Airflow DAG code:

import pendulum
from datetime import datetime, timedelta

from airflow.decorators import dag, task
from airflow.operators.python import get_current_context
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

# runParams, jobCluster, and accessControlList are defined elsewhere in the file

@dag(default_args=runParams,
     schedule_interval=None,
     start_date=pendulum.datetime(2024, 4, 22, tz="US/Pacific"),
     tags=["historical_load"]
     )
def CAN_HIST_LOAD_TEST():

    @task
    def get_args():
        context = get_current_context()
        jar_path = context['dag_run'].conf['jar_path']
        main_class_name = context['dag_run'].conf['main_class_name']
        start_param = context['dag_run'].conf['start_param']
        end_param = context['dag_run'].conf['end_param']
        max_run_range = context['dag_run'].conf['max_run_range']

        return {
            'jar_path': jar_path,
            'main_class_name': main_class_name,
            'start_param': start_param,
            'end_param': end_param,
            'max_run_range': max_run_range
        }

    @task
    def process_args(args) -> list[dict]:
        start_date = args['start_param']
        end_date = args['end_param']
        interval = args['max_run_range']

        date_params = []
        idx = 0
        while start_date <= end_date:
            dict_item = {idx: start_date}
            date_params.append(dict_item)
            idx += 1
            start_date = datetime.strptime(start_date, '%Y-%m-%d').date() + timedelta(days=interval)
            start_date = str(start_date)

        args['date_params_dict'] = date_params

        return date_params

    @task
    def spark_task(args, dt_dict):
        dt = ""
        for key, val in dt_dict.items():
            dt = val
        DatabricksSubmitRunOperator(
            task_id=f"process_task_{dt}",
            new_cluster=jobCluster,
            dag=dag,
            databricks_conn_id="${DBxConnID}",
            access_control_list=accessControlList,
            spark_jar_task={"main_class_name": args['main_class_name'], "parameters": dt},
            libraries=[{"jar": args['jar_path']}],
        )

    spark_task.partial(args=get_args()).expand(dt_dict=process_args(get_args()))

CAN_HIST_LOAD_TEST_DAG = CAN_HIST_LOAD_TEST()

Error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/airflow/decorators/base.py", line 188, in execute
    return_value = super().execute(context)
  File "/usr/local/lib/python3.9/site-packages/airflow/operators/python.py", line 175, in execute
    return_value = self.execute_callable()
  File "/usr/local/lib/python3.9/site-packages/airflow/operators/python.py", line 193, in execute_callable
    return self.python_callable(*self.op_args, **self.op_kwargs)
  File "/usr/local/airflow/dags/can/CAN_HIST_LOAD_TEST.py", line 122, in spark_task
    DatabricksSubmitRunOperator(
  File "/usr/local/lib/python3.9/site-packages/airflow/models/baseoperator.py", line 376, in apply_defaults
    task_group = TaskGroupContext.get_current_task_group(dag)
  File "/usr/local/lib/python3.9/site-packages/airflow/utils/task_group.py", line 489, in get_current_task_group
    return dag.task_group
AttributeError: 'function' object has no attribute 'task_group'

Even though I am not using a TaskGroup, I am not sure why it is giving this error.

Environment: Azure Databricks, Astronomer for Airflow (Airflow version 2.4.3); the jar is built from Spark/Scala code.

I have also tried passing the dates list directly into the .expand() method, but that resulted in the same error. For a sample run with 3 dates (start_param='2024-03-01', end_param='2024-03-03', max_run_range=1), the DAG is able to create 3 mapped instances, but each instance fails with the error above. (Ignore the start and end tasks.)
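For reference, with those sample arguments the process_args task returns the list below (one single-entry dict per date, keyed by index), which is what .expand() maps the three instances over:

[{0: '2024-03-01'}, {1: '2024-03-02'}, {2: '2024-03-03'}]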

For the other variations I tried, the code failed without even creating the dynamic tasks.

1 Answer


I was able to get a successful run with some changes; posting the working code for others' reference. The key change is to stop instantiating DatabricksSubmitRunOperator inside a @task function. In the original code, the name dag inside spark_task resolves to the @dag decorator function rather than a DAG object, which is what produces the AttributeError, and operators constructed inside a task callable are not registered with the DAG anyway. Instead, the operator is now mapped at the DAG level with DatabricksSubmitRunOperator.partial(...).expand(...), and process_args returns a list of complete spark_jar_task dictionaries, one per date.

# Imports and the runParams / jobCluster / accessControlList definitions
# are the same as in the question above.

@dag(default_args=runParams,
     schedule_interval=None,
     start_date=pendulum.datetime(2024, 4, 22, tz="US/Pacific"),
     tags=["CAN", "data_layer", "com-can", "historical_load"]
     )
def CAN_HIST_LOAD_TEST():

    @task
    def get_args():
        context = get_current_context()
        jar_path = context['dag_run'].conf['jar_path']
        main_class_name = context['dag_run'].conf['main_class_name']
        start_param = context['dag_run'].conf['start_param']
        end_param = context['dag_run'].conf['end_param']
        max_run_range = context['dag_run'].conf['max_run_range']

        return {
            'jar_path': jar_path,
            'main_class_name': main_class_name,
            'start_param': start_param,
            'end_param': end_param,
            'max_run_range': max_run_range
        }

    @task
    def process_args(args) -> list:
        start_date = args['start_param']
        end_date = args['end_param']
        interval = args['max_run_range']
        class_name = args['main_class_name']

        date_params = []
        while start_date <= end_date:
            list_item = {
                "main_class_name": class_name,
                "parameters": [start_date]
            }
            date_params.append(list_item)
            start_date = datetime.strptime(start_date, '%Y-%m-%d').date() + timedelta(days=interval)
            start_date = str(start_date)

        return date_params

    args = get_args()
    date_params = process_args(args)

    spark_task = DatabricksSubmitRunOperator.partial(
        task_id="process_task",
        new_cluster=jobCluster,
        databricks_conn_id="allyuen",
        access_control_list=accessControlList,
        libraries=[{"jar": "{{ task_instance.xcom_pull(task_ids='get_args')['jar_path'] }}"}],
    ).expand(spark_jar_task=date_params)

CAN_HIST_LOAD_TEST_DAG = CAN_HIST_LOAD_TEST()
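With the same sample arguments as in the question (start_param='2024-03-01', end_param='2024-03-03', max_run_range=1), process_args now returns complete spark_jar_task payloads, so .expand() creates one mapped DatabricksSubmitRunOperator instance per date (the class name below is just a placeholder):

[
    {"main_class_name": "com.example.HistoricalLoadMain", "parameters": ["2024-03-01"]},
    {"main_class_name": "com.example.HistoricalLoadMain", "parameters": ["2024-03-02"]},
    {"main_class_name": "com.example.HistoricalLoadMain", "parameters": ["2024-03-03"]},
]

The jar path is resolved at run time through the Jinja template in libraries; this should work because the operator's named arguments are merged into its json payload, which is a templated field on the Databricks operators.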