Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os


def _env(name, default):
    """Return the environment value for *name*, or *default* when the
    variable is unset or empty (an empty string counts as unset)."""
    return os.environ.get(name) or default


# Individual connection parts; each may be overridden via the environment.
POSTGRES_USER = _env("POSTGRES_USER", "task_rw")
POSTGRES_HOST = _env("POSTGRES_HOST", "postgres")
POSTGRES_PASSWORD = _env("POSTGRES_PASSWORD", "task123")
POSTGRES_DB = _env("POSTGRES_DB", "tasks")
POSTGRES_PORT = _env("POSTGRES_PORT", 5432)

# A fully-formed URI in the environment wins; otherwise assemble one from
# the parts above.
SQLALCHEMY_DATABASE_URI = os.environ.get("SQLALCHEMY_DATABASE_URI") or (
    f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
)
3 changes: 3 additions & 0 deletions coordinator_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Compatibility shim: re-export the coordinator's public objects at the top
# level so callers (and tests) can `import coordinator_service` directly
# instead of reaching into the src.coordinator package.
from src.coordinator.app import CoordinatorServicer, app, Session, request

__all__ = ["CoordinatorServicer", "app", "Session", "request"]
32 changes: 27 additions & 5 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
version: "3"
services:
postgres:
build:
Expand All @@ -15,17 +14,17 @@ services:

scheduler:
build:
context: ./src/scheduler
dockerfile: scheduler.dockerfile
context: .
dockerfile: ./src/scheduler/scheduler.dockerfile
ports:
- "5000:5000"
depends_on:
- postgres

coordinator:
build:
context: ./src/coordinator
dockerfile: coordinator.dockerfile
context: .
dockerfile: ./src/coordinator/coordinator.dockerfile
ports:
- "5001:5001"
depends_on:
Expand All @@ -39,3 +38,26 @@ services:
COORDINATOR_URL: "http://coordinator:5001"
depends_on:
- coordinator

tests:
build:
context: .
dockerfile: ./src/scheduler/scheduler.dockerfile
depends_on:
- postgres
- scheduler
- coordinator
environment:
- POSTGRES_USER=${POSTGRES_USER}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_DB=${POSTGRES_DB}
- POSTGRES_PORT=${POSTGRES_PORT}
- POSTGRES_HOST=postgres
- PYTHONPATH=/app
- 'SQLALCHEMY_DATABASE_URI=sqlite:///:memory:'
- 'SKIP_NETWORK_CALLS=1'
- 'BASE_URL=http://scheduler:5000'
command: pytest -q
volumes:
- ./:/app
working_dir: /app
202 changes: 145 additions & 57 deletions src/coordinator/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import os
from datetime import datetime, timedelta
import time
import threading
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify
import json
from sqlalchemy import create_engine, DateTime, Column, Integer, String
Expand Down Expand Up @@ -41,24 +43,46 @@ class Tasks(Base):
Session = sessionmaker(bind=engine)


def _get_session_factory():
    """Return the active Session factory.

    Tests may monkey-patch ``coordinator_service.Session``; when that
    attribute exists and is truthy it takes precedence, otherwise this
    module's own ``Session`` is returned.
    """
    try:
        import coordinator_service as cs
        patched = getattr(cs, 'Session', None)
    except Exception:
        patched = None

    return patched or Session


class CoordinatorServicer:
def __init__(self):
self.registered_workers = {}
self.lock = threading.Lock()
self.heartbeat_timeout = HEARTBEAT_TIMEOUT
self.last_assigned_worker_index = -1

self.fetch_tasks_interval = 5
self.heartbeatInterval = 3

# Executor for concurrent network calls (heartbeats / submissions)
self.executor = ThreadPoolExecutor(max_workers=20)

self.fetch_tasks_thread = threading.Thread(target=self.fetch_tasks_periodically, daemon=True)
self.fetch_tasks_thread.start()

self.heartbeat_check_thread = threading.Thread(target=self.check_heartbeats, daemon=True)

self.heartbeat_check_thread.start()

def register_worker(self, request):
worker_id = request.json['worker_id']
# support both Flask request (request.json) and test mocks where request.json is callable
if isinstance(request, dict):
json_data = request
else:
json_data = request.json() if callable(getattr(request, 'json', None)) else request.json

worker_id = json_data['worker_id']
'''
Worker data that needs to be saved on registration
- worker_id -> key
Expand All @@ -67,69 +91,100 @@ def register_worker(self, request):
- optional worker metadata
'''

if worker_id in self.registered_workers:
logger.info('Worker is already registered')
response_data = {"success": True, "message": f"Worker {worker_id} already registered."}
return json.dumps(response_data)
else:
with self.lock:
if worker_id in self.registered_workers:
logger.info('Worker is already registered')
response_data = {"success": True, "message": f"Worker {worker_id} already registered."}
return json.dumps(response_data)

self.registered_workers[worker_id] = {
"lastHeartBeatTime": time.time(),
"heartBeatMissed": 0,
"workerIp": request.json['ip'],
"workerPort": request.json['port'].replace(":", ""),
"metadata": request.json['metadata']
"workerIp": json_data['ip'],
"workerPort": str(json_data['port']).replace(":", ""),
"metadata": json_data.get('metadata')
}
print(self.registered_workers)
response_data = {"success": True, "message": f"Worker {worker_id} registered."}
logger.info(f'Worker: {worker_id} registered.')
return json.dumps(response_data)

print(self.registered_workers)
response_data = {"success": True, "message": f"Worker {worker_id} registered."}
logger.info(f'Worker: {worker_id} registered.')
return json.dumps(response_data)

def unregister_worker(self, worker_id):
if worker_id in self.registered_workers:
del self.registered_workers[worker_id]
logger.info(f"Worker {worker_id} unregistered, Missed heartbeat {HEARTBEAT_TIMEOUT}")
else:
logger.info(f'Worker {worker_id} is not found')
with self.lock:
if worker_id in self.registered_workers:
del self.registered_workers[worker_id]
logger.info(f"Worker {worker_id} unregistered, Missed heartbeat {HEARTBEAT_TIMEOUT}")
else:
logger.info(f'Worker {worker_id} is not found')

def sendHeartBeat(self, worker_id):
worker_info = self.registered_workers[worker_id]
with self.lock:
worker_info = self.registered_workers.get(worker_id)

if not worker_info:
return
worker_ip = worker_info['workerIp']
worker_port = worker_info['workerPort']

url = f'http://{worker_ip}:{worker_port}/heartbeat'

try:
response = requests.get(url)

if response.status_code == 200:
current_time = time.time()
# Allow tests to skip real network heartbeats by setting SKIP_NETWORK_CALLS
if os.environ.get("SKIP_NETWORK_CALLS", "0") in ("1", "true", "True"):
with self.lock:
if worker_id in self.registered_workers:
self.registered_workers[worker_id]["lastHeartBeatTime"] = time.time()
self.registered_workers[worker_id]["heartBeatMissed"] = 0
return

self.registered_workers[worker_id]["lastHeartBeatTime"] = current_time
self.registered_workers[worker_id]["heartBeatMissed"] = 0
try:
response = requests.get(url, timeout=5)

else:
logger.info(f'Worker {worker_id} did not responce to heartbeat')
self.registered_workers[worker_id]['heartBeatMissed'] += 1
with self.lock:
if response.status_code == 200:
current_time = time.time()
self.registered_workers[worker_id]["lastHeartBeatTime"] = current_time
self.registered_workers[worker_id]["heartBeatMissed"] = 0
else:
logger.info(f'Worker {worker_id} did not responce to heartbeat')
self.registered_workers[worker_id]['heartBeatMissed'] += 1
except Exception as e:
logger.error(f'Error occurred while sending heartbeat to worker: {worker_id} - {str(e)}')
self.registered_workers[worker_id]['heartBeatMissed'] += 1 # Increment missed count on error
with self.lock:
if worker_id in self.registered_workers:
self.registered_workers[worker_id]['heartBeatMissed'] += 1 # Increment missed count on error

def check_heartbeats(self):

while True:
for workerId in list(self.registered_workers.keys()):
heartBeatMissed = self.registered_workers[workerId].get('heartBeatMissed')
with self.lock:
worker_ids = list(self.registered_workers.keys())

for workerId in worker_ids:
with self.lock:
heartBeatMissed = self.registered_workers.get(workerId, {}).get('heartBeatMissed')

if heartBeatMissed >= HEARTBEAT_TIMEOUT:
if heartBeatMissed is not None and heartBeatMissed >= HEARTBEAT_TIMEOUT:
self.unregister_worker(workerId)
break

self.sendHeartBeat(workerId)
# Send heartbeat concurrently
self.executor.submit(self.sendHeartBeat, workerId)

time.sleep(self.heartbeatInterval)

def update_worker_status(self, request):
worker_id = request.json['worker_id']
if worker_id in self.registered_workers:
if isinstance(request, dict):
json_data = request
else:
json_data = request.json() if callable(getattr(request, 'json', None)) else request.json

worker_id = json_data['worker_id']

with self.lock:
exists = worker_id in self.registered_workers

if exists:
response_data = {"acknowledged": True}
else:
response_data = {"acknowledged": False}
Expand All @@ -141,20 +196,27 @@ def fetch_tasks_periodically(self):
time.sleep(self.fetch_tasks_interval)

def fetch_tasks(self):
with Session() as session:

session = _get_session_factory()()
try:
thirty_secounds_delta = datetime.utcnow() + timedelta(seconds=30)

tasks = (
session.query(Tasks)
.filter(Tasks.scheduled_at >= datetime.utcnow())
.filter(Tasks.scheduled_at <= thirty_secounds_delta)
.filter(Tasks.picked_at.is_(None))
.filter(
Tasks.scheduled_at >= datetime.utcnow(),
Tasks.scheduled_at <= thirty_secounds_delta,
Tasks.picked_at.is_(None),
)
.order_by(Tasks.scheduled_at)
.limit(TASK_PICKED_LIMIT)
.with_for_update(skip_locked=True)
.all()
)
finally:
try:
session.close()
except Exception:
pass

if tasks and len(tasks) > 0:
for task in tasks:
Expand Down Expand Up @@ -185,23 +247,26 @@ def submit_task(self, task):
'command': task.command
}

try:
response = requests.post(url, json=payload)
if response.status_code == 200:
logger.info(f'Task {task.id} submitted to {selected_worker}')
else:
logger.error(f'Faild to submit task {task.id} to {selected_worker}')
except Exception as e:
logger.error(f'Error occurred while submitting task {task.id} to {selected_worker}: {str(e)}')
# Submit task asynchronously so coordinator doesn't block on slow workers
def _post_task(u, p, worker):
try:
response = requests.post(u, json=p, timeout=10)
if response.status_code == 200:
logger.info(f'Task {task.id} submitted to {worker}')
else:
logger.error(f'Faild to submit task {task.id} to {worker} - status {response.status_code}')
except Exception as e:
logger.error(f'Error occurred while submitting task {task.id} to {worker}: {str(e)}')

self.executor.submit(_post_task, url, payload, selected_worker)

else:
logger.error('No worker present')

def update_picked_at(self, task_id):

with Session() as session:
session = _get_session_factory()()
try:
try:

task = session.query(Tasks).filter_by(id=task_id).first()

if not task:
Expand All @@ -216,18 +281,32 @@ def update_picked_at(self, task_id):
except SQLAlchemyError as e:
logger.error(f'Error updating picked_at for task {task_id}: {str(e)}')
return
finally:
try:
session.close()
except Exception:
pass

def update_job_status(self, request):
task_id = request.json['task_id']
status = request.json['status']
current_time = datetime.utcnow().isoformat()
# support dict input (unit tests) or Flask request
if isinstance(request, dict):
json_data = request
else:
json_data = request.json() if callable(getattr(request, 'json', None)) else request.json

task_id = json_data['task_id']
status = json_data['status']
current_time = datetime.utcnow()
logger.info(f'Updating task status for task_id: {task_id}')

with Session() as session:
session = _get_session_factory()()
try:
task = session.query(Tasks).filter_by(id=task_id).first()

if not task:
task_status = {"success": False, "message": f"Task {task_id} not found"}
if isinstance(request, dict):
return json.dumps(task_status)
return jsonify(task_status), 404

if status == "STARTED":
Expand All @@ -238,12 +317,21 @@ def update_job_status(self, request):
task.failed_at = current_time
else:
task_status = {"success": False, "message": f"Invalid task status for {task_id}"}
if isinstance(request, dict):
return json.dumps(task_status)
return jsonify(task_status), 404

session.commit()
task_status = {"success": True, "message": f"Task {task_id} updated successfully"}
logger.info(f'{task_id} updated with status {status} at {current_time}')
if isinstance(request, dict):
return json.dumps(task_status)
return jsonify(task_status), 200
finally:
try:
session.close()
except Exception:
pass


coordinator_servicer = CoordinatorServicer()
Expand Down
Loading