From cd1c90e3840074863b8c5648353b38d980b67632 Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Thu, 7 Aug 2025 14:38:29 +0200
Subject: [PATCH 1/9] feat: fetch the logs from cloudwatch before
 BatchOperator.execute_complete

---
 .../airflow/operators/awsbatch_operator.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 1b56ca7..2b6950f 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -92,16 +92,13 @@ def monitor_job(self, context: Context):
 
     def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
         """Execute when the trigger fires - fetch logs and complete the task."""
-        # Call parent's execute_complete first
-        job_id = super().execute_complete(context, event)
-
-        # Only fetch logs if we're in deferrable mode and awslogs are enabled
-        # In non-deferrable mode, logs are already fetched by monitor_job()
-        if self.deferrable and self.awslogs_enabled and job_id:
+        # Fetch logs before calling parent's execute_complete for both success and failure cases
+        if self.deferrable and self.awslogs_enabled and event and event.get("job_id"):
+            job_id = event["job_id"]
             # Set job_id for our log fetching methods
             self.job_id = job_id
 
-            # Get job logs and display them
+            # Get job logs and display them for both successful and failed jobs
             try:
                 # Use the log fetcher to display container logs
                 log_fetcher = self._get_batch_log_fetcher(job_id)
@@ -133,7 +130,6 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
                     aws_partition=self.hook.conn_partition,
                     **awslogs[0],
                 )
-
-        self.log.info("AWS Batch job (%s) succeeded", self.job_id)
-
-        return job_id
+
+        # Call parent's execute_complete which will handle success/failure logic
+        return super().execute_complete(context, event)

From 6733f4a703da64a05d7edb86aac8947b76893bfb Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Thu, 7 Aug 2025 16:37:42 +0200
Subject: [PATCH 2/9] feat: add fetch logs if the execute fails

---
 .../airflow/operators/awsbatch_operator.py | 48 ++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 2b6950f..9eca269 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferralError
 from airflow.providers.amazon.aws.links.batch import (
     BatchJobDefinitionLink,
     BatchJobQueueLink,
@@ -133,3 +133,49 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
 
         # Call parent's execute_complete which will handle success/failure logic
         return super().execute_complete(context, event)
+
+    def _fetch_batch_logs(self):
+        """Fetch and display batch job logs for debugging failed jobs."""
+        if not self.job_id or not self.awslogs_enabled:
+            return
+
+        try:
+            # Use the log fetcher to display container logs
+            log_fetcher = self._get_batch_log_fetcher(self.job_id)
+            if log_fetcher:
+                # Get the last 50 log messages
+                self.log.info("Fetch the latest 50 messages from cloudwatch:")
+                log_messages = log_fetcher.get_last_log_messages(50)
+                for message in log_messages:
+                    self.log.info(message)
+        except Exception as e:
+            self.log.warning("Could not fetch batch job logs: %s", e)
+
+        # Get CloudWatch log links
+        awslogs = []
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+        except AirflowException as ae:
+            self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
+
+        if awslogs:
+            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+            for log in awslogs:
+                self.log.info(self._format_cloudwatch_link(**log))
+
+    def execute(self, context: Context):
+        """Override execute to handle failures and fetch logs."""
+        try:
+            return super().execute(context)
+        except (TaskDeferralError, AirflowException) as e:
+            # When deferred task fails or other batch-related errors occur, fetch logs if we have a job_id
+            if self.deferrable and self.job_id and self.awslogs_enabled:
+                self.log.info("Task failed (deferrable mode), attempting to fetch batch job logs...")
+                self._fetch_batch_logs()
+            raise
+        except Exception as e:
+            # For any other unexpected exception, still try to fetch logs if we have job info
+            if self.deferrable and self.job_id and self.awslogs_enabled:
+                self.log.info("Unexpected error in deferrable batch task, attempting to fetch logs...")
+                self._fetch_batch_logs()
+            raise
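Note: the failure path added in this patch can be exercised without AWS. A minimal pytest sketch, not part of the series; it bypasses the constructor via __new__ and hand-sets only the attributes the method reads (the job id and logger name are made up):

    import logging
    from unittest import mock

    import pytest
    from airflow.exceptions import AirflowException
    from airflow.providers.amazon.aws.operators.batch import BatchOperator

    from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator


    def test_execute_failure_fetches_logs():
        # Skip __init__ so the sketch does not depend on constructor arguments.
        op = AWSBatchOperator.__new__(AWSBatchOperator)
        op._log = logging.getLogger("test")  # backing field for self.log
        op.deferrable = True
        op.awslogs_enabled = True
        op.job_id = "11111111-2222-3333-4444-555555555555"  # made-up job id

        # Force the parent execute() to fail and assert the log fetch hook ran.
        with mock.patch.object(
            BatchOperator, "execute", side_effect=AirflowException("Batch job FAILED")
        ), mock.patch.object(op, "_fetch_batch_logs") as fetch_logs:
            with pytest.raises(AirflowException):
                op.execute(context={})

        fetch_logs.assert_called_once()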
From 007dcbac0655054c061805765988b4bc1164c51d Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Fri, 8 Aug 2025 12:51:15 +0200
Subject: [PATCH 5/9] fix: handle all failure scenarios

---
 .../airflow/operators/awsbatch_operator.py | 155 +++++++++++---------
 1 file changed, 79 insertions(+), 76 deletions(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 9eca269..b4173e2 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from airflow.exceptions import AirflowException, TaskDeferralError
 from airflow.providers.amazon.aws.links.batch import (
@@ -10,6 +10,18 @@
 from airflow.utils.context import Context
 
 
+def _format_extra_info(error_msg: str, last_logs: list[str], cloudwatch_link: Optional[str]) -> str:
+    """Format the enhanced error message with logs and link."""
+    extra_info = []
+    if cloudwatch_link:
+        extra_info.append(f"CloudWatch Logs: {cloudwatch_link}")
+    if last_logs:
+        extra_info.append("Last log lines:\n" + "\n".join(last_logs[-5:]))
+    if extra_info:
+        return f"{error_msg}\n\n" + "\n".join(extra_info)
+    return error_msg
+
+
 class AWSBatchOperator(BatchOperator):
     @staticmethod
     def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
@@ -90,92 +102,83 @@ def monitor_job(self, context: Context):
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
 
-    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
-        """Execute when the trigger fires - fetch logs and complete the task."""
-        # Fetch logs before calling parent's execute_complete for both success and failure cases
-        if self.deferrable and self.awslogs_enabled and event and event.get("job_id"):
-            job_id = event["job_id"]
-            # Set job_id for our log fetching methods
-            self.job_id = job_id
-
-            # Get job logs and display them for both successful and failed jobs
+    def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list[str], Optional[str]]:
+        """
+        Fetch CloudWatch logs for the given job_id, log them to Airflow,
+        and return (last_logs, cloudwatch_link).
+        """
+        last_logs: list[str] = []
+        cloudwatch_link: Optional[str] = None
+
+        if self.awslogs_enabled:
+            # Fetch last 50 log messages
             try:
-                # Use the log fetcher to display container logs
                 log_fetcher = self._get_batch_log_fetcher(job_id)
                 if log_fetcher:
-                    # Get the last 50 log messages
-                    self.log.info("Fetch the latest 50 messages from cloudwatch:")
-                    log_messages = log_fetcher.get_last_log_messages(50)
-                    for message in log_messages:
+                    self.log.info("Fetching the latest 50 messages from CloudWatch:")
+                    last_logs = log_fetcher.get_last_log_messages(50)
+                    for message in last_logs:
                         self.log.info(message)
             except Exception as e:
                 self.log.warning("Could not fetch batch job logs: %s", e)
-
-            # Get CloudWatch log links
-            awslogs = []
+
+            # Fetch CloudWatch log link
             try:
-                awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+                awslogs = self.hook.get_job_all_awslogs_info(job_id)
             except AirflowException as ae:
-                self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
-
-            if awslogs:
-                self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-                for log in awslogs:
-                    self.log.info(self._format_cloudwatch_link(**log))
-
-                CloudWatchEventsLink.persist(
-                    context=context,
-                    operator=self,
-                    region_name=self.hook.conn_region_name,
-                    aws_partition=self.hook.conn_partition,
-                    **awslogs[0],
+                self.log.warning("Cannot determine where to find the AWS logs: %s", ae)
+                awslogs = []
+            else:
+                if awslogs:
+                    cloudwatch_link = self._format_cloudwatch_link(**awslogs[0])
+                    self.log.info("AWS Batch job (%s) CloudWatch Events details found:", job_id)
+                    for log in awslogs:
+                        self.log.info(self._format_cloudwatch_link(**log))
+                    CloudWatchEventsLink.persist(
+                        context=context,
+                        operator=self,
+                        region_name=self.hook.conn_region_name,
+                        aws_partition=self.hook.conn_partition,
+                        **awslogs[0],
+                    )
+
+        return last_logs, cloudwatch_link
+
+    def execute(self, context: Context) -> Union[str, None]:
+        """Submit and monitor an AWS Batch job, including early failures."""
+        try:
+            result = super().execute(context)
+            return result
+        except TaskDeferralError as e:
+            # Trigger itself failed — try to fetch logs if job_id is available
+            if self.deferrable and self.job_id:
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
+                raise AirflowException(
+                    _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link)
                 )
+            raise
+        except AirflowException as e:
+            # Covers immediate failures before deferral (job already FAILED)
+            if self.job_id:
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
+                raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link))
+            raise
 
-        # Call parent's execute_complete which will handle success/failure logic
-        return super().execute_complete(context, event)
+    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
+        """Execute when the trigger fires - fetch logs first, then check job status."""
+        job_id = event.get("job_id") if event else None
+        if not job_id:
+            raise AirflowException("No job_id found in event data from trigger.")
 
-    def _fetch_batch_logs(self):
-        """Fetch and display batch job logs for debugging failed jobs."""
-        if not self.job_id or not self.awslogs_enabled:
-            return
+        self.job_id = job_id
 
-        try:
-            # Use the log fetcher to display container logs
-            log_fetcher = self._get_batch_log_fetcher(self.job_id)
-            if log_fetcher:
-                # Get the last 50 log messages
-                self.log.info("Fetch the latest 50 messages from cloudwatch:")
-                log_messages = log_fetcher.get_last_log_messages(50)
-                for message in log_messages:
-                    self.log.info(message)
-        except Exception as e:
-            self.log.warning("Could not fetch batch job logs: %s", e)
-
-        # Get CloudWatch log links
-        awslogs = []
-        try:
-            awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
-        except AirflowException as ae:
-            self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
+        # Always fetch logs before checking status
+        last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, job_id)
 
-        if awslogs:
-            self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-            for log in awslogs:
-                self.log.info(self._format_cloudwatch_link(**log))
-
-    def execute(self, context: Context):
-        """Override execute to handle failures and fetch logs."""
         try:
-            return super().execute(context)
-        except (TaskDeferralError, AirflowException) as e:
-            # When deferred task fails or other batch-related errors occur, fetch logs if we have a job_id
-            if self.deferrable and self.job_id and self.awslogs_enabled:
-                self.log.info("Task failed (deferrable mode), attempting to fetch batch job logs...")
-                self._fetch_batch_logs()
-            raise
-        except Exception as e:
-            # For any other unexpected exception, still try to fetch logs if we have job info
-            if self.deferrable and self.job_id and self.awslogs_enabled:
-                self.log.info("Unexpected error in deferrable batch task, attempting to fetch logs...")
-                self._fetch_batch_logs()
-            raise
+            self.hook.check_job_success(job_id)
+        except AirflowException as e:
+            raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link))
+
+        self.log.info("AWS Batch job (%s) succeeded", job_id)
+        return job_id
From 78b47726ff1bc54171f42fe37a6f47ad9245a99a Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Fri, 8 Aug 2025 14:16:43 +0200
Subject: [PATCH 6/9] fix: overwrite resume_execution

---
 .../airflow/operators/awsbatch_operator.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index b4173e2..9f80dec 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -182,3 +182,24 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
 
         self.log.info("AWS Batch job (%s) succeeded", job_id)
         return job_id
+
+    def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context):
+        """Override resume_execution to handle trigger failures and fetch logs."""
+        self.log.info(f"AWSBatchOperator.resume_execution called with next_method='{next_method}'")
+        self.log.info(f"job_id available: {hasattr(self, 'job_id') and bool(self.job_id)}")
+        self.log.info(f"awslogs_enabled: {getattr(self, 'awslogs_enabled', False)}")
+
+        try:
+            return super().resume_execution(next_method, next_kwargs, context)
+        except TaskDeferralError as e:
+            # When trigger fails, try to fetch logs if job_id is available
+            if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled:
+                self.log.info("Trigger failed - attempting to fetch batch job logs...")
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
+                # Re-raise with enhanced error message
+                raise AirflowException(
+                    _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link)
+                )
+            else:
+                self.log.warning(f"Cannot fetch logs: job_id={getattr(self, 'job_id', None)}, awslogs_enabled={getattr(self, 'awslogs_enabled', False)}")
+            raise

From 892e6189cba22a36216af8032f5cc2b0d3fee464 Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Fri, 8 Aug 2025 14:51:05 +0200
Subject: [PATCH 7/9] fix: save job_id in xcom

---
 .../airflow/operators/awsbatch_operator.py | 60 +++++++++++--------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 9f80dec..9374813 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -146,23 +146,27 @@ def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list
 
     def execute(self, context: Context) -> Union[str, None]:
         """Submit and monitor an AWS Batch job, including early failures."""
-        try:
-            result = super().execute(context)
-            return result
-        except TaskDeferralError as e:
-            # Trigger itself failed — try to fetch logs if job_id is available
-            if self.deferrable and self.job_id:
-                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
-                raise AirflowException(
-                    _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link)
-                )
-            raise
-        except AirflowException as e:
-            # Covers immediate failures before deferral (job already FAILED)
-            if self.job_id:
-                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
-                raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link))
-            raise
+        # First call parent execute, which will submit the job and possibly defer
+        result = super().execute(context)
+
+        # If we reach here without exception, the task completed (didn't defer)
+        return result
+
+    def defer(self, *, trigger, method_name: str = "execute_complete", kwargs=None, timeout=None):
+        """Override defer to store job_id in XCom before deferring."""
+        # Store job_id in XCom so it's available when the task resumes
+        if hasattr(self, 'job_id') and self.job_id:
+            # Get task instance from current context
+            from airflow.operators.python import get_current_context
+            try:
+                context = get_current_context()
+                context['task_instance'].xcom_push(key='batch_job_id', value=self.job_id)
+                self.log.info(f"Stored job_id in XCom before deferring: {self.job_id}")
+            except Exception as e:
+                self.log.warning(f"Could not store job_id in XCom: {e}")
+
+        # Call parent defer method
+        super().defer(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
 
     def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
         """Execute when the trigger fires - fetch logs first, then check job status."""
@@ -185,21 +189,29 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
 
     def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context):
         """Override resume_execution to handle trigger failures and fetch logs."""
-        self.log.info(f"AWSBatchOperator.resume_execution called with next_method='{next_method}'")
-        self.log.info(f"job_id available: {hasattr(self, 'job_id') and bool(self.job_id)}")
-        self.log.info(f"awslogs_enabled: {getattr(self, 'awslogs_enabled', False)}")
+        # Retrieve job_id from XCom if not available on the instance
+        if not hasattr(self, 'job_id') or not self.job_id:
+            task_instance = context.get('task_instance')
+            if task_instance:
+                try:
+                    stored_job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_id')
+                    if stored_job_id:
+                        self.job_id = stored_job_id
+                        self.log.info(f"Retrieved job_id from XCom: {stored_job_id}")
+                except Exception as e:
+                    self.log.debug(f"Could not retrieve job_id from XCom: {e}")
 
         try:
             return super().resume_execution(next_method, next_kwargs, context)
         except TaskDeferralError as e:
             # When trigger fails, try to fetch logs if job_id is available
             if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled:
-                self.log.info("Trigger failed - attempting to fetch batch job logs...")
+                self.log.info("Batch job trigger failed - fetching CloudWatch logs...")
                 last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
-                # Re-raise with enhanced error message
+                # Re-raise with enhanced error message including logs
                 raise AirflowException(
-                    _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link)
+                    _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link)
                 )
             else:
-                self.log.warning(f"Cannot fetch logs: job_id={getattr(self, 'job_id', None)}, awslogs_enabled={getattr(self, 'awslogs_enabled', False)}")
+                self.log.warning("Cannot fetch logs for failed batch job - job_id or awslogs_enabled not available")
             raise
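Note: the XCom round-trip added in this patch compensates for how deferral works in Airflow: when the trigger fires, the task resumes on a freshly created operator instance, so a self.job_id assigned before defer() no longer exists when resume_execution runs. Reduced to its two halves, the pattern used here (with this patch's XCom key) is:

    # Inside defer(), before handing off to the trigger:
    context["task_instance"].xcom_push(key="batch_job_id", value=self.job_id)

    # Inside resume_execution(), on the new operator instance:
    job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key="batch_job_id")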
From 3e6fe6d61cc5105ed6da9f9fe0b40649c771e3ed Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Fri, 8 Aug 2025 15:12:53 +0200
Subject: [PATCH 8/9] chore: simplify the logic

---
 .../airflow/operators/awsbatch_operator.py | 28 ++++++----------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 9374813..e052724 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -152,22 +152,6 @@ def execute(self, context: Context) -> Union[str, None]:
         # If we reach here without exception, the task completed (didn't defer)
         return result
 
-    def defer(self, *, trigger, method_name: str = "execute_complete", kwargs=None, timeout=None):
-        """Override defer to store job_id in XCom before deferring."""
-        # Store job_id in XCom so it's available when the task resumes
-        if hasattr(self, 'job_id') and self.job_id:
-            # Get task instance from current context
-            from airflow.operators.python import get_current_context
-            try:
-                context = get_current_context()
-                context['task_instance'].xcom_push(key='batch_job_id', value=self.job_id)
-                self.log.info(f"Stored job_id in XCom before deferring: {self.job_id}")
-            except Exception as e:
-                self.log.warning(f"Could not store job_id in XCom: {e}")
-
-        # Call parent defer method
-        super().defer(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
-
     def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
         """Execute when the trigger fires - fetch logs first, then check job status."""
         job_id = event.get("job_id") if event else None
@@ -189,17 +173,17 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
 
     def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context):
         """Override resume_execution to handle trigger failures and fetch logs."""
-        # Retrieve job_id from XCom if not available on the instance
+        # Retrieve job_id from batch_job_details XCom if not available on the instance
         if not hasattr(self, 'job_id') or not self.job_id:
             task_instance = context.get('task_instance')
             if task_instance:
                 try:
-                    stored_job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_id')
-                    if stored_job_id:
-                        self.job_id = stored_job_id
-                        self.log.info(f"Retrieved job_id from XCom: {stored_job_id}")
+                    batch_job_details = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_details')
+                    if batch_job_details and 'job_id' in batch_job_details:
+                        self.job_id = batch_job_details['job_id']
+                        self.log.info(f"Retrieved job_id from batch_job_details XCom: {self.job_id}")
                 except Exception as e:
-                    self.log.debug(f"Could not retrieve job_id from XCom: {e}")
+                    self.log.debug(f"Could not retrieve job_id from batch_job_details XCom: {e}")
 
         try:
             return super().resume_execution(next_method, next_kwargs, context)

From 8db638a24d2fc1f83ecefb2e215402d6f0aa6a1b Mon Sep 17 00:00:00 2001
From: claudiazi
Date: Fri, 8 Aug 2025 15:42:11 +0200
Subject: [PATCH 9/9] chore: clean the logic

---
 .../airflow/operators/awsbatch_operator.py | 71 +++++++------------
 1 file changed, 26 insertions(+), 45 deletions(-)

diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index e052724..3297e85 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 from airflow.exceptions import AirflowException, TaskDeferralError
 from airflow.providers.amazon.aws.links.batch import (
@@ -102,55 +102,36 @@ def monitor_job(self, context: Context):
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
 
-    def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list[str], Optional[str]]:
-        """
-        Fetch CloudWatch logs for the given job_id, log them to Airflow,
-        and return (last_logs, cloudwatch_link).
-        """
+    def _fetch_and_log_cloudwatch(self, job_id: str) -> tuple[list[str], Optional[str]]:
+        """Fetch CloudWatch logs for the given job_id and return (last_logs, cloudwatch_link)."""
         last_logs: list[str] = []
         cloudwatch_link: Optional[str] = None
 
-        if self.awslogs_enabled:
-            # Fetch last 50 log messages
-            try:
-                log_fetcher = self._get_batch_log_fetcher(job_id)
-                if log_fetcher:
-                    self.log.info("Fetching the latest 50 messages from CloudWatch:")
-                    last_logs = log_fetcher.get_last_log_messages(50)
-                    for message in last_logs:
-                        self.log.info(message)
-            except Exception as e:
-                self.log.warning("Could not fetch batch job logs: %s", e)
-
-            # Fetch CloudWatch log link
-            try:
-                awslogs = self.hook.get_job_all_awslogs_info(job_id)
-            except AirflowException as ae:
-                self.log.warning("Cannot determine where to find the AWS logs: %s", ae)
-                awslogs = []
-            else:
-                if awslogs:
-                    cloudwatch_link = self._format_cloudwatch_link(**awslogs[0])
-                    self.log.info("AWS Batch job (%s) CloudWatch Events details found:", job_id)
-                    for log in awslogs:
-                        self.log.info(self._format_cloudwatch_link(**log))
-                    CloudWatchEventsLink.persist(
-                        context=context,
-                        operator=self,
-                        region_name=self.hook.conn_region_name,
-                        aws_partition=self.hook.conn_partition,
-                        **awslogs[0],
-                    )
+        if not self.awslogs_enabled:
+            return last_logs, cloudwatch_link
+
+        # Fetch last log messages
+        try:
+            log_fetcher = self._get_batch_log_fetcher(job_id)
+            if log_fetcher:
+                last_logs = log_fetcher.get_last_log_messages(50)
+                if last_logs:
+                    self.log.info("CloudWatch logs (last 50 messages):")
+                    for message in last_logs:
+                        self.log.info(message)
+        except Exception as e:
+            self.log.warning("Could not fetch batch job logs: %s", e)
+
+        # Get CloudWatch log link
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(job_id)
+            if awslogs:
+                cloudwatch_link = self._format_cloudwatch_link(**awslogs[0])
+                self.log.info("CloudWatch link: %s", cloudwatch_link)
+        except AirflowException as e:
+            self.log.warning("Cannot determine CloudWatch log link: %s", e)
 
         return last_logs, cloudwatch_link
 
-    def execute(self, context: Context) -> Union[str, None]:
-        """Submit and monitor an AWS Batch job, including early failures."""
-        # First call parent execute, which will submit the job and possibly defer
-        result = super().execute(context)
-
-        # If we reach here without exception, the task completed (didn't defer)
-        return result
-
     def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
         """Execute when the trigger fires - fetch logs first, then check job status."""
         job_id = event.get("job_id") if event else None
@@ -161,7 +142,7 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N
         self.job_id = job_id
 
         # Always fetch logs before checking status
-        last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, job_id)
+        last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(job_id)
 
         try:
             self.hook.check_job_success(job_id)
@@ -191,7 +172,7 @@ def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any
             # When trigger fails, try to fetch logs if job_id is available
             if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled:
                 self.log.info("Batch job trigger failed - fetching CloudWatch logs...")
-                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id)
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(self.job_id)
                 # Re-raise with enhanced error message including logs
                 raise AirflowException(
                     _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link)
                 )
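Note: nothing in this series changes how the operator is wired into a DAG; the patches only enrich failure reporting. A minimal usage sketch; the DAG and resource names are made up, and the exact constructor arguments (e.g. container_overrides vs. the older overrides) depend on the Amazon provider version in use:

    from datetime import datetime

    from airflow import DAG

    from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator

    with DAG("batch_logs_example", start_date=datetime(2025, 8, 1), schedule=None) as dag:
        run_job = AWSBatchOperator(
            task_id="run_batch_job",
            job_name="example-job",            # made-up names
            job_definition="example-job-def",
            job_queue="example-queue",
            deferrable=True,                   # defer while the job runs
            awslogs_enabled=True,              # required for the CloudWatch fetching above
        )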