From 59c014e182f1159c1228382fd972f44fa74483d2 Mon Sep 17 00:00:00 2001
From: Araon
Date: Tue, 18 Nov 2025 03:13:10 +0530
Subject: [PATCH 1/3] refactor: update docker-compose paths; add thread safety
 and async task submission to the coordinator service

---
 coordinator_service.py |   3 ++
 docker-compose.yml     |   9 ++--
 src/coordinator/app.py | 120 ++++++++++++++++++++++++-----------------
 3 files changed, 79 insertions(+), 53 deletions(-)
 create mode 100644 coordinator_service.py

diff --git a/coordinator_service.py b/coordinator_service.py
new file mode 100644
index 0000000..6475d90
--- /dev/null
+++ b/coordinator_service.py
@@ -0,0 +1,3 @@
+from src.coordinator.app import CoordinatorServicer, app
+
+__all__ = ["CoordinatorServicer", "app"]
diff --git a/docker-compose.yml b/docker-compose.yml
index 2c0cd8a..ee5471f 100755
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,4 +1,3 @@
-version: "3"
 services:
   postgres:
     build:
@@ -15,8 +14,8 @@ services:
 
   scheduler:
     build:
-      context: ./src/scheduler
-      dockerfile: scheduler.dockerfile
+      context: .
+      dockerfile: ./src/scheduler/scheduler.dockerfile
     ports:
       - "5000:5000"
     depends_on:
@@ -24,8 +23,8 @@ services:
 
   coordinator:
     build:
-      context: ./src/coordinator
-      dockerfile: coordinator.dockerfile
+      context: .
+      dockerfile: ./src/coordinator/coordinator.dockerfile
     ports:
       - "5001:5001"
     depends_on:
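Note: the new top-level coordinator_service.py shim pairs with the build-context change above — with `context: .` the images build from the repo root, so shared top-level modules stay importable inside the containers. A minimal sketch of the intended import path (illustrative, not part of the patch; assumes the repo root is on PYTHONPATH and that the database configured in src/coordinator/app.py is reachable, since importing the module starts the coordinator's background threads):

    # Illustrative only: consumers import the coordinator by a stable
    # top-level name instead of reaching into src/coordinator directly.
    import coordinator_service

    print(coordinator_service.__all__)  # ['CoordinatorServicer', 'app']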
"workerIp": request.json['ip'], + "workerPort": request.json['port'].replace(":", ""), + "metadata": request.json['metadata'] + } + print(self.registered_workers) + response_data = {"success": True, "message": f"Worker {worker_id} registered."} + logger.info(f'Worker: {worker_id} registered.') + return json.dumps(response_data) def unregister_worker(self, worker_id): - if worker_id in self.registered_workers: - del self.registered_workers[worker_id] - logger.info(f"Worker {worker_id} unregistered, Missed heartbeat {HEARTBEAT_TIMEOUT}") - else: - logger.info(f'Worker {worker_id} is not found') + with self.lock: + if worker_id in self.registered_workers: + del self.registered_workers[worker_id] + logger.info(f"Worker {worker_id} unregistered, Missed heartbeat {HEARTBEAT_TIMEOUT}") + else: + logger.info(f'Worker {worker_id} is not found') def sendHeartBeat(self, worker_id): - worker_info = self.registered_workers[worker_id] + with self.lock: + worker_info = self.registered_workers.get(worker_id) + + if not worker_info: + return worker_ip = worker_info['workerIp'] worker_port = worker_info['workerPort'] url = f'http://{worker_ip}:{worker_port}/heartbeat' try: - response = requests.get(url) - - if response.status_code == 200: - current_time = time.time() - - self.registered_workers[worker_id]["lastHeartBeatTime"] = current_time - self.registered_workers[worker_id]["heartBeatMissed"] = 0 + response = requests.get(url, timeout=5) - else: - logger.info(f'Worker {worker_id} did not responce to heartbeat') - self.registered_workers[worker_id]['heartBeatMissed'] += 1 + with self.lock: + if response.status_code == 200: + current_time = time.time() + self.registered_workers[worker_id]["lastHeartBeatTime"] = current_time + self.registered_workers[worker_id]["heartBeatMissed"] = 0 + else: + logger.info(f'Worker {worker_id} did not responce to heartbeat') + self.registered_workers[worker_id]['heartBeatMissed'] += 1 except Exception as e: logger.error(f'Error occurred while sending heartbeat to worker: {worker_id} - {str(e)}') - self.registered_workers[worker_id]['heartBeatMissed'] += 1 # Increment missed count on error + with self.lock: + if worker_id in self.registered_workers: + self.registered_workers[worker_id]['heartBeatMissed'] += 1 # Increment missed count on error def check_heartbeats(self): while True: - for workerId in list(self.registered_workers.keys()): - heartBeatMissed = self.registered_workers[workerId].get('heartBeatMissed') + with self.lock: + worker_ids = list(self.registered_workers.keys()) - if heartBeatMissed >= HEARTBEAT_TIMEOUT: + for workerId in worker_ids: + with self.lock: + heartBeatMissed = self.registered_workers.get(workerId, {}).get('heartBeatMissed') + + if heartBeatMissed is not None and heartBeatMissed >= HEARTBEAT_TIMEOUT: self.unregister_worker(workerId) break - self.sendHeartBeat(workerId) + # Send heartbeat concurrently + self.executor.submit(self.sendHeartBeat, workerId) + time.sleep(self.heartbeatInterval) def update_worker_status(self, request): worker_id = request.json['worker_id'] - if worker_id in self.registered_workers: + with self.lock: + exists = worker_id in self.registered_workers + + if exists: response_data = {"acknowledged": True} else: response_data = {"acknowledged": False} @@ -185,14 +205,18 @@ def submit_task(self, task): 'command': task.command } - try: - response = requests.post(url, json=payload) - if response.status_code == 200: - logger.info(f'Task {task.id} submitted to {selected_worker}') - else: - logger.error(f'Faild to submit task 
From 93819bfb6547850c09410cfd145a53812be46c5b Mon Sep 17 00:00:00 2001
From: Araon
Date: Tue, 18 Nov 2025 03:30:38 +0530
Subject: [PATCH 2/3] feat: update coordinator service with session management
 and improve test integration

---
 config.py                 |  12 ++++
 coordinator_service.py    |   4 +-
 docker-compose.yml        |  23 ++++++++
 src/coordinator/app.py    | 114 +++++++++++++++++++++++++++++---------
 src/scheduler/app.py      |  10 +++-
 tests/intigration_test.py |  21 ++++---
 6 files changed, 146 insertions(+), 38 deletions(-)
 create mode 100644 config.py

diff --git a/config.py b/config.py
new file mode 100644
index 0000000..e393ed7
--- /dev/null
+++ b/config.py
@@ -0,0 +1,12 @@
+import os
+
+POSTGRES_USER = os.environ.get("POSTGRES_USER") or "task_rw"
+POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "postgres"
+POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") or "task123"
+POSTGRES_DB = os.environ.get("POSTGRES_DB") or "tasks"
+POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or 5432
+SQLALCHEMY_DATABASE_URI = os.environ.get("SQLALCHEMY_DATABASE_URI")
+if not SQLALCHEMY_DATABASE_URI:
+    SQLALCHEMY_DATABASE_URI = (
+        f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
+    )
diff --git a/coordinator_service.py b/coordinator_service.py
index 6475d90..ec1d004 100644
--- a/coordinator_service.py
+++ b/coordinator_service.py
@@ -1,3 +1,3 @@
-from src.coordinator.app import CoordinatorServicer, app
+from src.coordinator.app import CoordinatorServicer, app, Session, request
 
-__all__ = ["CoordinatorServicer", "app"]
+__all__ = ["CoordinatorServicer", "app", "Session", "request"]
diff --git a/docker-compose.yml b/docker-compose.yml
index ee5471f..7fa4e77 100755
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -38,3 +38,26 @@ services:
       COORDINATOR_URL: "http://coordinator:5001"
     depends_on:
       - coordinator
+
+  tests:
+    build:
+      context: .
+      dockerfile: ./src/scheduler/scheduler.dockerfile
+    depends_on:
+      - postgres
+      - scheduler
+      - coordinator
+    environment:
+      - POSTGRES_USER=${POSTGRES_USER}
+      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
+      - POSTGRES_DB=${POSTGRES_DB}
+      - POSTGRES_PORT=${POSTGRES_PORT}
+      - POSTGRES_HOST=postgres
+      - PYTHONPATH=/app
+      - 'SQLALCHEMY_DATABASE_URI=sqlite:///:memory:'
+      - 'SKIP_NETWORK_CALLS=1'
+      - 'BASE_URL=http://scheduler:5000'
+    command: pytest -q
+    volumes:
+      - ./:/app
+    working_dir: /app
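Note: one subtlety in the new config.py above: it uses `os.environ.get(X) or default` rather than `os.environ.get(X, default)`, so a variable that compose exports as an *empty* string (as `${POSTGRES_PASSWORD}` expands when unset on the host) still falls back to the dev default instead of producing an empty credential. A quick sketch of the difference:

    import os

    os.environ["POSTGRES_PASSWORD"] = ""   # present but empty
    print(os.environ.get("POSTGRES_PASSWORD", "task123"))    # '' - the empty value wins
    print(os.environ.get("POSTGRES_PASSWORD") or "task123")  # 'task123' - falls back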
diff --git a/src/coordinator/app.py b/src/coordinator/app.py
index df2ad27..68e6dfd 100755
--- a/src/coordinator/app.py
+++ b/src/coordinator/app.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from datetime import datetime, timedelta
 import time
 import threading
@@ -42,6 +43,19 @@ class Tasks(Base):
 Session = sessionmaker(bind=engine)
 
 
+def _get_session_factory():
+    """Return the Session factory. Prefer coordinator_service.Session if tests have patched it."""
+    try:
+        import coordinator_service as cs
+        svc_session = getattr(cs, 'Session', None)
+        if svc_session:
+            return svc_session
+    except Exception:
+        pass
+
+    return Session
+
+
 class CoordinatorServicer:
     def __init__(self):
         self.registered_workers = {}
@@ -62,7 +76,13 @@ def __init__(self):
         self.heartbeat_check_thread.start()
 
     def register_worker(self, request):
-        worker_id = request.json['worker_id']
+        # support both Flask request (request.json) and test mocks where request.json is callable
+        if isinstance(request, dict):
+            json_data = request
+        else:
+            json_data = request.json() if callable(getattr(request, 'json', None)) else request.json
+
+        worker_id = json_data['worker_id']
         '''
         Worker data that needs to be saved on registration
         - worker_id -> key
@@ -76,18 +96,19 @@ def register_worker(self, request):
                 logger.info('Worker is already registered')
                 response_data = {"success": True, "message": f"Worker {worker_id} already registered."}
                 return json.dumps(response_data)
-            else:
-                self.registered_workers[worker_id] = {
-                    "lastHeartBeatTime": time.time(),
-                    "heartBeatMissed": 0,
-                    "workerIp": request.json['ip'],
-                    "workerPort": request.json['port'].replace(":", ""),
-                    "metadata": request.json['metadata']
-                }
-                print(self.registered_workers)
-                response_data = {"success": True, "message": f"Worker {worker_id} registered."}
-                logger.info(f'Worker: {worker_id} registered.')
-                return json.dumps(response_data)
+
+            self.registered_workers[worker_id] = {
+                "lastHeartBeatTime": time.time(),
+                "heartBeatMissed": 0,
+                "workerIp": json_data['ip'],
+                "workerPort": str(json_data['port']).replace(":", ""),
+                "metadata": json_data.get('metadata')
+            }
+
+            print(self.registered_workers)
+            response_data = {"success": True, "message": f"Worker {worker_id} registered."}
+            logger.info(f'Worker: {worker_id} registered.')
+            return json.dumps(response_data)
 
     def unregister_worker(self, worker_id):
         with self.lock:
@@ -108,6 +129,14 @@ def sendHeartBeat(self, worker_id):
 
         url = f'http://{worker_ip}:{worker_port}/heartbeat'
 
+        # Allow tests to skip real network heartbeats by setting SKIP_NETWORK_CALLS
+        if os.environ.get("SKIP_NETWORK_CALLS", "0") in ("1", "true", "True"):
+            with self.lock:
+                if worker_id in self.registered_workers:
+                    self.registered_workers[worker_id]["lastHeartBeatTime"] = time.time()
+                    self.registered_workers[worker_id]["heartBeatMissed"] = 0
+            return
+
         try:
             response = requests.get(url, timeout=5)
 
@@ -145,7 +174,13 @@ def check_heartbeats(self):
             time.sleep(self.heartbeatInterval)
 
     def update_worker_status(self, request):
-        worker_id = request.json['worker_id']
+        if isinstance(request, dict):
+            json_data = request
+        else:
+            json_data = request.json() if callable(getattr(request, 'json', None)) else request.json
+
+        worker_id = json_data['worker_id']
+
         with self.lock:
             exists = worker_id in self.registered_workers
 
@@ -161,20 +196,27 @@ def fetch_tasks_periodically(self):
             time.sleep(self.fetch_tasks_interval)
 
     def fetch_tasks(self):
-        with Session() as session:
-
+        session = _get_session_factory()()
+        try:
             thirty_secounds_delta = datetime.utcnow() + timedelta(seconds=30)
             tasks = (
                 session.query(Tasks)
-                .filter(Tasks.scheduled_at >= datetime.utcnow())
-                .filter(Tasks.scheduled_at <= thirty_secounds_delta)
-                .filter(Tasks.picked_at.is_(None))
+                .filter(
+                    Tasks.scheduled_at >= datetime.utcnow(),
+                    Tasks.scheduled_at <= thirty_secounds_delta,
+                    Tasks.picked_at.is_(None),
+                )
                 .order_by(Tasks.scheduled_at)
                 .limit(TASK_PICKED_LIMIT)
                 .with_for_update(skip_locked=True)
                 .all()
            )
+        finally:
+            try:
+                session.close()
+            except Exception:
+                pass
 
         if tasks and len(tasks) > 0:
             for task in tasks:
@@ -222,10 +264,9 @@ def _post_task(u, p, worker):
         logger.error('No worker present')
 
     def update_picked_at(self, task_id):
-
-        with Session() as session:
+        session = _get_session_factory()()
+        try:
             try:
-
                 task = session.query(Tasks).filter_by(id=task_id).first()
 
                 if not task:
@@ -240,18 +281,32 @@ def update_picked_at(self, task_id):
             except SQLAlchemyError as e:
                 logger.error(f'Error updating picked_at for task {task_id}: {str(e)}')
                 return
+        finally:
+            try:
+                session.close()
+            except Exception:
+                pass
 
     def update_job_status(self, request):
-        task_id = request.json['task_id']
-        status = request.json['status']
+        # support dict input (unit tests) or Flask request
+        if isinstance(request, dict):
+            json_data = request
+        else:
+            json_data = request.json() if callable(getattr(request, 'json', None)) else request.json
+
+        task_id = json_data['task_id']
+        status = json_data['status']
         current_time = datetime.utcnow()
 
         logger.info(f'Updating task status for task_id: {task_id}')
-        with Session() as session:
+        session = _get_session_factory()()
+        try:
             task = session.query(Tasks).filter_by(id=task_id).first()
 
             if not task:
                 task_status = {"success": False, "message": f"Task {task_id} not found"}
+                if isinstance(request, dict):
+                    return json.dumps(task_status)
                 return jsonify(task_status), 404
 
             if status == "STARTED":
@@ -262,12 +317,21 @@ def update_job_status(self, request):
                 task.failed_at = current_time
             else:
                 task_status = {"success": False, "message": f"Invalid task status for {task_id}"}
+                if isinstance(request, dict):
+                    return json.dumps(task_status)
                 return jsonify(task_status), 404
 
             session.commit()
             task_status = {"success": True, "message": f"Task {task_id} updated successfully"}
             logger.info(f'{task_id} updated with status {status} at {current_time}')
+            if isinstance(request, dict):
+                return json.dumps(task_status)
             return jsonify(task_status), 200
+        finally:
+            try:
+                session.close()
+            except Exception:
+                pass
 
 
 coordinator_servicer = CoordinatorServicer()
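Note: the load-bearing piece of fetch_tasks is `.with_for_update(skip_locked=True)`. Against Postgres this renders `FOR UPDATE SKIP LOCKED`, so two coordinators polling the same table lock disjoint rows and a task is handed out at most once. A standalone sketch of the claim query (assumed names, Postgres semantics):

    from datetime import datetime, timedelta

    def claim_due_tasks(session, Tasks, limit=10):
        # Rows currently locked by another poller are skipped, not waited on,
        # which is what makes this polling loop safe to run on several nodes.
        window_end = datetime.utcnow() + timedelta(seconds=30)
        return (
            session.query(Tasks)
            .filter(
                Tasks.scheduled_at >= datetime.utcnow(),
                Tasks.scheduled_at <= window_end,
                Tasks.picked_at.is_(None),
            )
            .order_by(Tasks.scheduled_at)
            .limit(limit)
            .with_for_update(skip_locked=True)
            .all()
        )

Bear in mind that SQLite (which the tests service configures) does not support FOR UPDATE, so the skip-locked guarantee only holds against the Postgres backend.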
diff --git a/src/scheduler/app.py b/src/scheduler/app.py
index eb321d2..ddaf80d 100755
--- a/src/scheduler/app.py
+++ b/src/scheduler/app.py
@@ -12,9 +12,13 @@
 app = Flask(__name__)
 app.config['SQLALCHEMY_DATABASE_URI'] = SQLALCHEMY_DATABASE_URI
 
+# Initialize extensions without binding them to the app immediately, to allow test overrides
+db = SQLAlchemy()
+migrate = Migrate()
 
-db = SQLAlchemy(app)
-migrate = Migrate(app, db)
+# Bind extensions to the app
+db.init_app(app)
+migrate.init_app(app, db)
 
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -93,7 +97,7 @@ def get_schedule(task_id):
         'failed_at': task.failed_at if task.failed_at else None
     }
     logger.info(f'Task with ID {task_id} retrieved successfully.')
-    return jsonify(task_data), 200
+    return jsonify({'task': task_data}), 200
 
 
 if __name__ == "__main__":
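Note: the init_app split above is the standard Flask extension pattern: create the extension unbound, then bind it per app. The patch binds immediately, but the same split is what enables an application factory, where each test builds its own app against an in-memory database. A minimal sketch of that variant (illustrative, not the repo's code):

    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    db = SQLAlchemy()                     # created unbound

    def create_app(database_uri):
        app = Flask(__name__)
        app.config["SQLALCHEMY_DATABASE_URI"] = database_uri
        db.init_app(app)                  # bound per app instance, so each
        return app                        # caller can choose its own URI

    app = create_app("sqlite:///:memory:")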
diff --git a/tests/intigration_test.py b/tests/intigration_test.py
index f8af920..684ad25 100755
--- a/tests/intigration_test.py
+++ b/tests/intigration_test.py
@@ -1,26 +1,31 @@
+import os
 import pytest
 import requests
 from datetime import datetime, timedelta
-from your_scheduler_module import app, db, Tasks
+from src.scheduler.app import app, db, Tasks
 
 
 @pytest.fixture
-def base_url():
-    return 'http://localhost:5000'
+def client():
+    app.config['TESTING'] = True
+    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
+    with app.test_client() as client:
+        with app.app_context():
+            db.create_all()
+        yield client
 
 
-def test_integration(base_url):
+def test_integration(client):
     data = {
         'command': 'echo "Hello World"',
         'scheduled_at': (datetime.utcnow() + timedelta(minutes=5)).isoformat()
     }
-    # Schedule a task
-    response = requests.post(f'{base_url}/schedule', json=data)
+
+    response = client.post('/schedule', json=data)
     assert response.status_code == 201
     task_id = response.json['task_id']
 
-    # Retrieve the scheduled task
-    response = requests.get(f'{base_url}/schedule/{task_id}')
+    response = client.get(f'/schedule/{task_id}')
     assert response.status_code == 200
     assert 'task' in response.json
     assert response.json['task']['id'] == task_id

From 393f322a056b6a6f1c8cfd66c9b750e727a07b29 Mon Sep 17 00:00:00 2001
From: Araon
Date: Tue, 18 Nov 2025 03:33:36 +0530
Subject: [PATCH 3/3] fix: correct typo in taskHandler log message

---
 src/worker/worker.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/worker/worker.go b/src/worker/worker.go
index 59fd615..4ff1c1c 100644
--- a/src/worker/worker.go
+++ b/src/worker/worker.go
@@ -71,7 +71,7 @@ func taskHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	fmt.Printf("Task recevied Id: %s\n", strings.Join(strings.Split(task.Id, "-"), ""))
+	fmt.Printf("Task received Id: %s\n", strings.Join(strings.Split(task.Id, "-"), ""))
 	fmt.Printf("Command: %s\n", task.Command)
 
 	if !isAllowedCommand(task.Command) {
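Note on the client fixture in tests/intigration_test.py: it overrides SQLALCHEMY_DATABASE_URI after src.scheduler.app has already called db.init_app(app). That works while Flask-SQLAlchemy creates its engine lazily, but versions that build the engine during init_app may silently keep the original URI — under compose this is harmless only because the tests service exports SQLALCHEMY_DATABASE_URI before import, so config.py already resolves to sqlite. A more defensive fixture variant (hypothetical, not in the patch) relies on that environment override and clears state between tests:

    import pytest
    from src.scheduler.app import app, db

    @pytest.fixture
    def client():
        # Assumes SQLALCHEMY_DATABASE_URI was set in the environment before
        # src.scheduler.app was imported, as the compose tests service does.
        app.config.update(TESTING=True)
        with app.app_context():
            db.create_all()               # guarantees the schema exists
            with app.test_client() as client:
                yield client
            db.drop_all()                 # no state leaks into the next test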