131 changes: 87 additions & 44 deletions dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferralError
 from airflow.providers.amazon.aws.links.batch import (
     BatchJobDefinitionLink,
     BatchJobQueueLink,
@@ -10,6 +18,18 @@
 from airflow.utils.context import Context


+def _format_extra_info(error_msg: str, last_logs: list[str], cloudwatch_link: Optional[str]) -> str:
+    """Format the enhanced error message with logs and link."""
+    extra_info = []
+    if cloudwatch_link:
+        extra_info.append(f"CloudWatch Logs: {cloudwatch_link}")
+    if last_logs:
+        extra_info.append("Last log lines:\n" + "\n".join(last_logs[-5:]))
+    if extra_info:
+        return f"{error_msg}\n\n" + "\n".join(extra_info)
+    return error_msg
+
+
 class AWSBatchOperator(BatchOperator):
     @staticmethod
     def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
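
For reference, a sketch of the message `_format_extra_info` assembles when a failed job has both log lines and a link available; the job id, log lines, and URL below are made-up placeholders, not values from this PR:

    msg = _format_extra_info(
        "Job (abc-123) failed with status FAILED",
        last_logs=["step 1 done", "step 2 failed"],
        cloudwatch_link="https://console.aws.amazon.com/cloudwatch/home#logsV2:log-groups",
    )
    print(msg)
    # Job (abc-123) failed with status FAILED
    #
    # CloudWatch Logs: https://console.aws.amazon.com/cloudwatch/home#logsV2:log-groups
    # Last log lines:
    # step 1 done
    # step 2 failed

Only the last five lines are kept (`last_logs[-5:]`), so the error shown in the Airflow UI stays short.
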
@@ -90,50 +102,81 @@ def monitor_job(self, context: Context):
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)

-    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
-        """Execute when the trigger fires - fetch logs and complete the task."""
-        # Call parent's execute_complete first
-        job_id = super().execute_complete(context, event)
-
-        # Only fetch logs if we're in deferrable mode and awslogs are enabled
-        # In non-deferrable mode, logs are already fetched by monitor_job()
-        if self.deferrable and self.awslogs_enabled and job_id:
-            # Set job_id for our log fetching methods
-            self.job_id = job_id
-
-            # Get job logs and display them
-            try:
-                # Use the log fetcher to display container logs
-                log_fetcher = self._get_batch_log_fetcher(job_id)
-                if log_fetcher:
-                    # Get the last 50 log messages
-                    self.log.info("Fetch the latest 50 messages from cloudwatch:")
-                    log_messages = log_fetcher.get_last_log_messages(50)
-                    for message in log_messages:
-                        self.log.info(message)
-            except Exception as e:
-                self.log.warning("Could not fetch batch job logs: %s", e)
-
-            # Get CloudWatch log links
-            awslogs = []
-            try:
-                awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
-            except AirflowException as ae:
-                self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
-            except Exception as e:
-                self.log.warning("Could not fetch batch job logs: %s", e)
-
-            if awslogs:
-                self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-                for log in awslogs:
-                    self.log.info(self._format_cloudwatch_link(**log))
-
-                CloudWatchEventsLink.persist(
-                    context=context,
-                    operator=self,
-                    region_name=self.hook.conn_region_name,
-                    aws_partition=self.hook.conn_partition,
-                    **awslogs[0],
-                )
-
-        self.log.info("AWS Batch job (%s) succeeded", self.job_id)
-        return job_id
-
+    def _fetch_and_log_cloudwatch(self, job_id: str) -> tuple[list[str], Optional[str]]:
+        """Fetch CloudWatch logs for the given job_id and return (last_logs, cloudwatch_link)."""
+        last_logs: list[str] = []
+        cloudwatch_link: Optional[str] = None
+
+        if not self.awslogs_enabled:
+            return last_logs, cloudwatch_link
+
+        # Fetch last log messages
+        try:
+            log_fetcher = self._get_batch_log_fetcher(job_id)
+            if log_fetcher:
+                last_logs = log_fetcher.get_last_log_messages(50)
+                if last_logs:
+                    self.log.info("CloudWatch logs (last 50 messages):")
+                    for message in last_logs:
+                        self.log.info(message)
+        except Exception as e:
+            self.log.warning("Could not fetch batch job logs: %s", e)
+
+        # Get CloudWatch log link
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(job_id)
+            if awslogs:
+                cloudwatch_link = self._format_cloudwatch_link(**awslogs[0])
+                self.log.info("CloudWatch link: %s", cloudwatch_link)
+        except AirflowException as e:
+            self.log.warning("Cannot determine CloudWatch log link: %s", e)
+
+        return last_logs, cloudwatch_link
+
+    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
+        """Execute when the trigger fires - fetch logs first, then check job status."""
+        job_id = event.get("job_id") if event else None
+        if not job_id:
+            raise AirflowException("No job_id found in event data from trigger.")
+
+        self.job_id = job_id
+
+        # Always fetch logs before checking status
+        last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(job_id)
+
+        try:
+            self.hook.check_job_success(job_id)
+        except AirflowException as e:
+            raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link))
+
+        self.log.info("AWS Batch job (%s) succeeded", job_id)
+        return job_id
+
+    def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context):
+        """Override resume_execution to handle trigger failures and fetch logs."""
+        # Retrieve job_id from the batch_job_details XCom if it is not set on the instance
+        if not hasattr(self, "job_id") or not self.job_id:
+            task_instance = context.get("task_instance")
+            if task_instance:
+                try:
+                    batch_job_details = task_instance.xcom_pull(task_ids=task_instance.task_id, key="batch_job_details")
+                    if batch_job_details and "job_id" in batch_job_details:
+                        self.job_id = batch_job_details["job_id"]
+                        self.log.info("Retrieved job_id from batch_job_details XCom: %s", self.job_id)
+                except Exception as e:
+                    self.log.debug("Could not retrieve job_id from batch_job_details XCom: %s", e)
+
+        try:
+            return super().resume_execution(next_method, next_kwargs, context)
+        except TaskDeferralError as e:
+            # When the trigger fails, try to fetch logs if a job_id is available
+            if hasattr(self, "job_id") and self.job_id and self.awslogs_enabled:
+                self.log.info("Batch job trigger failed - fetching CloudWatch logs...")
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(self.job_id)
+                # Re-raise with an enhanced error message that includes the logs
+                raise AirflowException(
+                    _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link)
+                )
+            else:
+                self.log.warning("Cannot fetch logs for failed batch job - job_id or awslogs_enabled not available")
+                raise
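
With these pieces in place, the deferrable path surfaces CloudWatch output on both success and failure: execute_complete fetches logs before calling check_job_success, and resume_execution attaches the same logs when the trigger itself fails. A minimal usage sketch, assuming a recent apache-airflow-providers-amazon where BatchOperator accepts deferrable and container_overrides; the DAG id, queue, and job definition names are placeholders, not values from this repo:

    from datetime import datetime

    from airflow import DAG

    from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator

    with DAG("batch_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        run_job = AWSBatchOperator(
            task_id="run_batch_job",
            job_name="example-job",
            job_queue="example-queue",
            job_definition="example-job-def",
            container_overrides={},
            deferrable=True,       # hand the wait to the triggerer instead of polling
            awslogs_enabled=True,  # required for the CloudWatch fetching added above
        )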