Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions agent/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,24 @@
import os

from celery import Celery, shared_task
from celery.signals import worker_shutting_down, worker_shutdown

logger = logging.getLogger(__name__)


@worker_shutting_down.connect
def handle_worker_shutting_down(sig, how, exitcode, **kwargs):
"""Called when worker receives shutdown signal."""
from agent.shutdown import set_shutdown
logger.info(f"Celery worker shutting down (signal={sig}, how={how})")
set_shutdown()


@worker_shutdown.connect
def handle_worker_shutdown(sender, **kwargs):
"""Called after worker has shut down."""
logger.info("Celery worker shutdown complete")

# Set the default Django settings mode for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'agent.settings')

Expand Down
28 changes: 28 additions & 0 deletions agent/shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Graceful shutdown management for Celery workers.

Provides:
- Global shutdown state tracking
- Utility for tasks to check shutdown status
"""
import logging
import threading

logger = logging.getLogger(__name__)

# Thread-safe shutdown state
_shutdown_event = threading.Event()
_shutdown_lock = threading.Lock()


def is_shutting_down():
"""Check if worker is in shutdown mode. Call this in long-running tasks."""
return _shutdown_event.is_set()


def set_shutdown():
"""Mark the worker as shutting down."""
with _shutdown_lock:
if not _shutdown_event.is_set():
logger.info("Graceful shutdown initiated - completing in-flight tasks")
_shutdown_event.set()
6 changes: 6 additions & 0 deletions asset_manager/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from drdroid_debug_toolkit.core.integrations.source_metadata_extractor import SourceMetadataExtractor
from drdroid_debug_toolkit.core.integrations.source_metadata_extractor_facade import source_metadata_extractor_facade
from agent.shutdown import is_shutting_down

logger = logging.getLogger(__name__)

Expand All @@ -32,6 +33,11 @@ def populate_connector_metadata(request_id, connector_name, connector_type, conn
callable(getattr(extractor, method)) and method not in dir(SourceMetadataExtractor)
and method.startswith('extract_')]
for extractor_method in extractor_methods:
# Check for shutdown between extraction methods
if is_shutting_down():
logger.info(f"Shutdown in progress - stopping asset extraction for {connector_name} at method {extractor_method}")
break

logger.info(f"Running method: {extractor_method} for connector: {connector_name}")
try:
extractor_async_method_call(request_id, connector_name, connector_type, connector_credentials_dict,
Expand Down
5 changes: 5 additions & 0 deletions helm/charts/celery_beat/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ spec:
tolerations:
{{- toYaml .Values.global.tolerations | nindent 8 }}
{{- end }}
terminationGracePeriodSeconds: 30
initContainers:
- name: wait-for-redis
image: busybox:1.36
Expand Down Expand Up @@ -100,6 +101,10 @@ spec:
periodSeconds: 5
timeoutSeconds: 3
failureThreshold: 12
lifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 5"]
volumes:
- name: credentials-volume
configMap:
Expand Down
13 changes: 13 additions & 0 deletions helm/charts/celery_worker/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ spec:
tolerations:
{{- toYaml .Values.global.tolerations | nindent 8 }}
{{- end }}
terminationGracePeriodSeconds: 30
initContainers:
- name: wait-for-redis
image: busybox:1.36
Expand Down Expand Up @@ -104,6 +105,10 @@ spec:
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 3
lifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 5"]

- name: celery-worker-task-executor # Task executor for high-priority tasks
image: {{ .Values.image.repository }}:{{ .Values.image.tag }}
Expand Down Expand Up @@ -163,6 +168,10 @@ spec:
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 3
lifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 5"]

- name: celery-worker-asset-extractor # Task executor for asset extraction tasks, which run rarely and are long-running
image: {{ .Values.image.repository }}:{{ .Values.image.tag }}
Expand Down Expand Up @@ -222,6 +231,10 @@ spec:
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 3
lifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 5"]

volumes:
- name: credentials-volume
Expand Down
7 changes: 7 additions & 0 deletions playbooks_engine/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from utils.credentilal_utils import credential_yaml_to_connector_proto
from drdroid_debug_toolkit.core.integrations.utils.executor_utils import check_multiple_task_results
from utils.credentilal_utils import credential_yaml_to_connector_proto, generate_credentials_dict
from agent.shutdown import is_shutting_down

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -230,6 +231,12 @@ def fetch_playbook_execution_tasks():

@shared_task(max_retries=3, default_retry_delay=10)
def execute_task_and_send_result(playbook_task_execution_log):
# Check if worker is shutting down
if is_shutting_down():
request_id = playbook_task_execution_log.get('proxy_execution_request_id', 'unknown')
logger.info(f"Worker shutting down - task {request_id} will be requeued")
return False # Task will be requeued due to CELERY_TASK_REJECT_ON_WORKER_LOST

try:
# Check if this is an asset refresh task
task = playbook_task_execution_log.get('task', {})
Expand Down