Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion modelq/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from modelq.exceptions import TaskProcessingError, TaskTimeoutError,RetryTaskException
from modelq.app.middleware import Middleware
from modelq.app.redis_retry import _RedisWithRetry
from modelq.app.utils.compression import compress_base64, decompress_base64

from pydantic import BaseModel, ValidationError
from typing import Optional, Dict, Any, Type
Expand All @@ -33,6 +34,9 @@ class ModelQ:
PRUNE_CHECK_INTERVAL = 60 # seconds: how often to check for stale servers
TASK_RESULT_RETENTION = 86400

# Thread-local storage for tracking current task context
_current_task = threading.local()

def __init__(
self,
host: str = "localhost",
Expand Down Expand Up @@ -609,6 +613,9 @@ def process_task(self, task: Task) -> None:
f"with args: {call_args}, kwargs: {call_kwargs}"
)

# Set current task context for this thread
self._current_task.task_id = task.task_id

if stream:
# Stream results
for result in task_function(*call_args, **call_kwargs):
Expand Down Expand Up @@ -683,17 +690,35 @@ def process_task(self, task: Task) -> None:

finally:
self.redis_client.srem("processing_tasks", task.task_id)
# Clear current task context for this thread
if hasattr(self._current_task, 'task_id'):
delattr(self._current_task, 'task_id')


def _store_final_task_state(self, task: Task, success: bool):
"""
Persists the final status/result of the task in Redis, adding finished_at.
Preserves any base64_output that was stored during task execution.
"""
# Get existing task data to preserve base64_output if it exists
existing_data = self.redis_client.get(f"task:{task.task_id}")
existing_base64_output = None
if existing_data:
try:
existing_dict = json.loads(existing_data)
existing_base64_output = existing_dict.get("base64_output")
except:
pass

task_dict = task.to_dict()

# Preserve base64_output if it was stored during task execution
if existing_base64_output is not None:
task_dict["base64_output"] = existing_base64_output

# Mark finished_at
task_dict["finished_at"] = time.time()

self.redis_client.set(
f"task_result:{task.task_id}",
json.dumps(task_dict),
Expand Down Expand Up @@ -740,6 +765,148 @@ def get_task_status(self, task_id: str) -> Optional[str]:
return json.loads(task_data).get("status")
return None

def store_base64_output(
self,
base64_output: str,
task_id: Optional[str] = None,
compress: bool = True,
compression_method: str = "zlib",
compression_level: int = 6
) -> bool:
"""
Store base64 output for a task with optional compression.
Automatically detects the current task_id if called from within a task function.

Args:
base64_output: The base64 encoded output (image, video, etc.)
task_id: Optional task ID. If not provided, uses the current task being processed
compress: Whether to compress the base64 output (default: True)
compression_method: Compression algorithm to use (default: "zlib")
Options: "zlib", "gzip", "bz2", "brotli", "lz4"
compression_level: Compression level 0-9 (default: 6)
Higher = better compression but slower

Returns:
True if storage was successful, False otherwise

Example:
@modelq.task()
def generate_image(params):
# ... generate image and encode to base64
base64_image = "data:image/png;base64,..."

# Store with default compression (zlib, level 6)
modelq.store_base64_output(base64_image, compress=True)

# Store with maximum compression using brotli
modelq.store_base64_output(base64_image, compress=True,
compression_method="brotli",
compression_level=11)

# Return regular result
return {"status": "success"}
"""
try:
# Auto-detect task_id from current thread context if not provided
if task_id is None:
if hasattr(self._current_task, 'task_id'):
task_id = self._current_task.task_id
else:
logger.error("store_base64_output called without task_id and no task context found")
return False

# Get the existing task data
task_data = self.redis_client.get(f"task:{task_id}")
if not task_data:
logger.warning(f"Task {task_id} not found when trying to store base64 output")
return False

task_dict = json.loads(task_data)

# Compress if requested
if compress:
stored_output = compress_base64(
base64_output,
compression_level=compression_level,
method=compression_method
)
else:
stored_output = base64_output

# Update the task dict with base64_output
task_dict["base64_output"] = stored_output

# Store back to Redis with same expiry times
self.redis_client.set(
f"task:{task_id}",
json.dumps(task_dict),
ex=86400 # 24 hours
)

# Also update task_result if it exists
task_result_data = self.redis_client.get(f"task_result:{task_id}")
if task_result_data:
result_dict = json.loads(task_result_data)
result_dict["base64_output"] = stored_output
self.redis_client.set(
f"task_result:{task_id}",
json.dumps(result_dict),
ex=3600 # 1 hour
)

logger.info(f"Stored base64 output for task {task_id} (compressed: {compress}, method: {compression_method})")
return True

except Exception as e:
logger.error(f"Failed to store base64 output for task {task_id}: {e}")
return False

def get_task_base64(self, task_id: str, decompress: bool = True) -> Optional[str]:
"""
Retrieve the base64 output for a task.

Args:
task_id: The task ID to retrieve the output for
decompress: Whether to decompress the output if it was compressed (default: True)

Returns:
The base64 output string, or None if not found

Example:
# Get the base64 output for a completed task
base64_image = modelq.get_task_base64(task.task_id)
if base64_image:
# Use the base64 image
pass
"""
try:
# Try to get from task_result first (most recent)
task_data = self.redis_client.get(f"task_result:{task_id}")

# Fall back to task key if not in result
if not task_data:
task_data = self.redis_client.get(f"task:{task_id}")

if not task_data:
logger.warning(f"Task {task_id} not found when trying to retrieve base64 output")
return None

task_dict = json.loads(task_data)
base64_output = task_dict.get("base64_output")

if base64_output is None:
return None

# Decompress if requested
if decompress:
return decompress_base64(base64_output)
else:
return base64_output

except Exception as e:
logger.error(f"Failed to retrieve base64 output for task {task_id}: {e}")
return None

def log_task_error_to_file(self, task: Task, exc: Exception, file_path="modelq_errors.log"):
"""
Logs detailed error info to a specified file, with dashes before and after.
Expand Down
5 changes: 4 additions & 1 deletion modelq/app/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(self, task_name: str, payload: dict, timeout: int = 15):
self.original_payload = copy.deepcopy(payload)
self.status = "queued"
self.result = None

self.base64_output = None # New field for storing compressed base64 outputs

# New timestamps:
self.created_at = time.time() # When Task object is instantiated
self.queued_at = None # When task is enqueued in Redis
Expand All @@ -36,6 +37,7 @@ def to_dict(self):
"payload": self.payload,
"status": self.status,
"result": self.result,
"base64_output": self.base64_output,
"created_at": self.created_at,
"queued_at": self.queued_at,
"started_at": self.started_at,
Expand All @@ -49,6 +51,7 @@ def from_dict(data: dict) -> "Task":
task.task_id = data["task_id"]
task.status = data["status"]
task.result = data.get("result")
task.base64_output = data.get("base64_output")

# Load timestamps if present
task.created_at = data.get("created_at")
Expand Down
Loading