diff --git a/agent/celery.py b/agent/celery.py
index c01ec88..8aa5872 100644
--- a/agent/celery.py
+++ b/agent/celery.py
@@ -4,9 +4,33 @@
 import os
 
 from celery import Celery, shared_task
+from celery.signals import worker_shutting_down, worker_shutdown
 
 logger = logging.getLogger(__name__)
 
+
+@worker_shutting_down.connect
+def handle_worker_shutting_down(sig, how, exitcode, **kwargs):
+    """Mark this process as shutting down when the worker starts to stop.
+
+    Connected to Celery's ``worker_shutting_down`` signal. Logs ``sig`` and
+    ``how``, then sets the shared flag read by ``is_shutting_down()``.
+    ``exitcode`` is accepted because the signal passes it, but is unused.
+    """
+    # Imported at call time rather than at module import.
+    from agent.shutdown import set_shutdown
+    logger.info(f"Celery worker shutting down (signal={sig}, how={how})")
+    set_shutdown()
+
+
+@worker_shutdown.connect
+def handle_worker_shutdown(sender, **kwargs):
+    """Log that the worker has finished shutting down.
+
+    Connected to Celery's ``worker_shutdown`` signal; ``sender`` is unused.
+    """
+    logger.info("Celery worker shutdown complete")
+
 # Set the default Django settings mode for the 'celery' program.
 os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'agent.settings')
 
diff --git a/agent/shutdown.py b/agent/shutdown.py
new file mode 100644
index 0000000..6297d75
--- /dev/null
+++ b/agent/shutdown.py
@@ -0,0 +1,40 @@
+"""
+Graceful shutdown management for Celery workers.
+
+Provides:
+- Global shutdown state tracking
+- Utility for tasks to check shutdown status
+
+NOTE(review): the flag is an in-process ``threading.Event``. If the worker
+runs a process-based pool, a flag set in the process that receives the
+shutdown signal is not visible to the pool's child processes - confirm the
+pool type in use before relying on ``is_shutting_down()`` inside tasks.
+"""
+import logging
+import threading
+
+logger = logging.getLogger(__name__)
+
+# Thread-safe shutdown state
+_shutdown_event = threading.Event()
+# Serialises the check-and-set in set_shutdown() so the message is logged once.
+_shutdown_lock = threading.Lock()
+
+
+def is_shutting_down():
+    """Return True once ``set_shutdown()`` has been called in this process.
+
+    Call this in long-running tasks to stop early at a safe point.
+    """
+    return _shutdown_event.is_set()
+
+
+def set_shutdown():
+    """Mark the worker as shutting down.
+
+    Safe to call more than once: only the first call logs and sets the flag.
+    """
+    with _shutdown_lock:
+        if not _shutdown_event.is_set():
+            logger.info("Graceful shutdown initiated - completing in-flight tasks")
+            _shutdown_event.set()
diff --git a/asset_manager/tasks.py b/asset_manager/tasks.py
index 153709c..5b29da6 100644
--- a/asset_manager/tasks.py
+++ b/asset_manager/tasks.py
@@ -6,6 +6,7 @@
 
 from drdroid_debug_toolkit.core.integrations.source_metadata_extractor import SourceMetadataExtractor
 from drdroid_debug_toolkit.core.integrations.source_metadata_extractor_facade import source_metadata_extractor_facade
+from agent.shutdown import is_shutting_down
 
 logger = logging.getLogger(__name__)
 
@@ -32,6 +33,11 @@ def populate_connector_metadata(request_id, connector_name, connector_type, conn
                          callable(getattr(extractor, method)) and method not in dir(SourceMetadataExtractor)
                          and method.startswith('extract_')]
     for extractor_method in extractor_methods:
+        # Check for shutdown between extraction methods
+        if is_shutting_down():
+            logger.info(f"Shutdown in progress - stopping asset extraction for {connector_name} at method {extractor_method}")
+            break
+
         logger.info(f"Running method: {extractor_method} for connector: {connector_name}")
         try:
             extractor_async_method_call(request_id, connector_name, connector_type, connector_credentials_dict,
diff --git a/helm/charts/celery_beat/templates/deployment.yaml b/helm/charts/celery_beat/templates/deployment.yaml
index 7aa9094..92a3709 100644
--- a/helm/charts/celery_beat/templates/deployment.yaml
+++ b/helm/charts/celery_beat/templates/deployment.yaml
@@ -25,6 +25,7 @@ spec:
       tolerations:
         {{- toYaml .Values.global.tolerations | nindent 8 }}
       {{- end }}
+      terminationGracePeriodSeconds: {{ .Values.terminationGracePeriodSeconds | default 30 }}
       initContainers:
         - name: wait-for-redis
           image: busybox:1.36
@@ -100,6 +101,10 @@ spec:
             periodSeconds: 5
             timeoutSeconds: 3
             failureThreshold: 12
+          lifecycle:
+            preStop:
+              exec:
+                command: ["/bin/sh", "-c", "sleep 5"]
       volumes:
         - name: credentials-volume
           configMap:
diff --git a/helm/charts/celery_worker/templates/deployment.yaml b/helm/charts/celery_worker/templates/deployment.yaml
index 32e84c3..a03c092 100644
--- a/helm/charts/celery_worker/templates/deployment.yaml
+++ b/helm/charts/celery_worker/templates/deployment.yaml
@@ -25,6 +25,7 @@ spec:
       tolerations:
         {{- toYaml .Values.global.tolerations | nindent 8 }}
       {{- end }}
+      terminationGracePeriodSeconds: {{ .Values.terminationGracePeriodSeconds | default 30 }}
       initContainers:
         - name: wait-for-redis
           image: busybox:1.36
@@ -104,6 +105,10 @@ spec:
             periodSeconds: 30
             timeoutSeconds: 10
             failureThreshold: 3
+          lifecycle:
+            preStop:
+              exec:
+                command: ["/bin/sh", "-c", "sleep 5"]
         - name: celery-worker-task-executor
           # Task executor for high-priority tasks
           image: {{ .Values.image.repository }}:{{ .Values.image.tag }}
@@ -163,6 +168,10 @@ spec:
             periodSeconds: 30
             timeoutSeconds: 10
             failureThreshold: 3
+          lifecycle:
+            preStop:
+              exec:
+                command: ["/bin/sh", "-c", "sleep 5"]
         - name: celery-worker-asset-extractor
           # Task executor for asset extraction tasks, which run rarely and are long-running
           image: {{ .Values.image.repository }}:{{ .Values.image.tag }}
@@ -222,6 +231,10 @@ spec:
             periodSeconds: 30
             timeoutSeconds: 10
             failureThreshold: 3
+          lifecycle:
+            preStop:
+              exec:
+                command: ["/bin/sh", "-c", "sleep 5"]
 
       volumes:
         - name: credentials-volume
diff --git a/playbooks_engine/tasks.py b/playbooks_engine/tasks.py
index 82a73cf..300f73b 100644
--- a/playbooks_engine/tasks.py
+++ b/playbooks_engine/tasks.py
@@ -15,6 +15,7 @@
 from utils.credentilal_utils import credential_yaml_to_connector_proto
 from drdroid_debug_toolkit.core.integrations.utils.executor_utils import check_multiple_task_results
 from utils.credentilal_utils import credential_yaml_to_connector_proto, generate_credentials_dict
+from agent.shutdown import is_shutting_down
 
 logger = logging.getLogger(__name__)
 
@@ -230,6 +231,18 @@ def fetch_playbook_execution_tasks():
 
-@shared_task(max_retries=3, default_retry_delay=10)
+@shared_task(max_retries=3, default_retry_delay=10, acks_late=True)
 def execute_task_and_send_result(playbook_task_execution_log):
+    # Refuse new work once shutdown has started. Raising Reject(requeue=True)
+    # returns the message to the broker for another worker; that needs
+    # acks_late=True so the message is still un-acknowledged at this point.
+    # Simply returning would ack the message and record the task as succeeded,
+    # so it would never be retried.
+    if is_shutting_down():
+        from celery.exceptions import Reject
+
+        request_id = playbook_task_execution_log.get('proxy_execution_request_id', 'unknown')
+        logger.info(f"Worker shutting down - task {request_id} will be requeued")
+        raise Reject("worker shutting down", requeue=True)
+
     try:
         # Check if this is an asset refresh task
         task = playbook_task_execution_log.get('task', {})