Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions examples/pydantic_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def image_task(params: dict):
# Simulate image processing
return "Image processed successfully"

# job = add(a=3, b=4) # ✨ validated on the spot
job = add(a=3, b=4) # ✨ validated on the spot

# job2 = sub(a=10, b=5) # ✨ no schema validation, just a simple task
job2 = sub(a=10, b=5) # ✨ no schema validation, just a simple task

# task = image_task({"image": "example.png"}) # ✨ no schema validation, just a simple task
# task2 = image_task(params={"image": "example.png"})
task = image_task({"image": "example.png"}) # ✨ no schema validation, just a simple task
task2 = image_task(params={"image": "example.png"})
import time

if __name__ == "__main__":
Expand All @@ -44,17 +44,17 @@ def image_task(params: dict):
# Keep the worker running indefinitely
try:
while True:
# output = job.get_result(mq.redis_client,returns=AddOut)
output = job.get_result(mq.redis_client,returns=AddOut)

# print(f"Result of addition: {output}")
# print(type(output))
# print(f"Result of addition (total): {output.total}")
print(f"Result of addition: {output}")
print(type(output))
print(f"Result of addition (total): {output.total}")

# output2 = job2.get_result(mq.redis_client)
# print(f"Result of subtraction: {output2}")
output2 = job2.get_result(mq.redis_client)
print(f"Result of subtraction: {output2}")

# output3 = task.get_result(mq.redis_client)
# print(f"Result of image task: {output3}")
output3 = task.get_result(mq.redis_client)
print(f"Result of image task: {output3}")
time.sleep(1)
except KeyboardInterrupt:
print("\nGracefully shutting down...")
14 changes: 13 additions & 1 deletion modelq/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from modelq.app.tasks import Task
from modelq.exceptions import TaskProcessingError, TaskTimeoutError,RetryTaskException
from modelq.app.middleware import Middleware
from modelq.app.redis_retry import _RedisWithRetry

from pydantic import BaseModel, ValidationError
from typing import Optional, Dict, Any, Type
Expand Down Expand Up @@ -47,10 +48,21 @@ def __init__(
webhook_url: Optional[str] = None, # Optional webhook for error logging
requeue_threshold : Optional[int] = None ,
delay_seconds: int = 30,
redis_retry_attempts: int = 5,
redis_retry_base_delay: float = 0.5,
redis_retry_backoff: float = 2.0,
redis_retry_jitter: float = 0.3,
**kwargs,
):
if redis_client:
self.redis_client = redis_client
self.redis_client = _RedisWithRetry(
redis_client,
max_attempts=redis_retry_attempts,
base_delay=redis_retry_base_delay,
backoff=redis_retry_backoff,
jitter=redis_retry_jitter,
)

else:
self.redis_client = self._connect_to_redis(
host=host,
Expand Down
56 changes: 56 additions & 0 deletions modelq/app/redis_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import time
import redis
from redis.exceptions import ConnectionError, TimeoutError
import logging

logger = logging.getLogger(__name__)
import random

class _RedisWithRetry:
"""Lightweight proxy that wraps a redis.Redis instance.

Any callable attribute (e.g. get, set, blpop, xadd …) is executed with a
retry loop that catches *ConnectionError* and *TimeoutError* from redis‑py
and re‑issues the call after an exponential back‑off (base_delay × backoff^n)
plus a small random jitter.
"""

RETRYABLE = (ConnectionError, TimeoutError)

def __init__(self, client: redis.Redis, *,
max_attempts: int = 5,
base_delay: float = 0.5,
backoff: float = 2.0,
jitter: float = 0.3):
self._client = client
self._max_attempts = max_attempts
self._base_delay = base_delay
self._backoff = backoff
self._jitter = jitter

# Forward non‑callable attrs (e.g. "connection_pool") directly ──────────
def __getattr__(self, name):
attr = getattr(self._client, name)
if not callable(attr):
return attr

# Wrap callable with retry loop
def _wrapped(*args, **kwargs):
attempt = 0
delay = self._base_delay
while True:
try:
return attr(*args, **kwargs)
except self.RETRYABLE as exc:
attempt += 1
if attempt >= self._max_attempts:
logger.error(
f"Redis command '{name}' failed after {attempt} attempts: {exc}")
raise
sleep_for = delay + random.uniform(0, self._jitter)
logger.warning(
f"Redis '{name}' failed ({exc.__class__.__name__}: {exc}). "
f"Retrying in {sleep_for:.2f}s (attempt {attempt}/{self._max_attempts})")
time.sleep(sleep_for)
delay *= self._backoff
return _wrapped