diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d8b1b4f..d0041f4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,7 +48,6 @@ jobs: run: mypy py_src/taskito/ rust-test: - needs: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 @@ -63,6 +62,8 @@ jobs: - name: Rust cache uses: Swatinem/rust-cache@v2.8.2 + with: + save-if: false - name: Run Rust tests run: cargo test --workspace @@ -70,7 +71,7 @@ jobs: LD_LIBRARY_PATH: ${{ env.pythonLocation }}/lib test: - needs: [lint, rust-test] + needs: lint runs-on: ${{ matrix.os }} strategy: matrix: @@ -98,6 +99,8 @@ jobs: - name: Rust cache uses: Swatinem/rust-cache@v2.8.2 + with: + save-if: ${{ matrix.os != 'ubuntu-latest' }} - name: Create virtualenv (Unix) if: runner.os != 'Windows' diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml new file mode 100644 index 0000000..4f883d0 --- /dev/null +++ b/.github/workflows/cleanup.yml @@ -0,0 +1,23 @@ +name: Cleanup PR caches + +on: + pull_request: + types: [closed] + +jobs: + cleanup: + runs-on: ubuntu-latest + permissions: + actions: write + steps: + - uses: actions/checkout@v6 + + - name: Delete PR branch caches + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BRANCH: refs/pull/${{ github.event.pull_request.number }}/merge + run: | + echo "Deleting caches for branch: $BRANCH" + gh cache list --ref "$BRANCH" --json id --jq '.[].id' | + xargs -I {} gh cache delete {} --repo "${{ github.repository }}" || true + echo "Done" diff --git a/crates/taskito-async/Cargo.toml b/crates/taskito-async/Cargo.toml index 7a20210..e6782b0 100644 --- a/crates/taskito-async/Cargo.toml +++ b/crates/taskito-async/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-async" -version = "0.5.0" +version = "0.6.0" edition = "2021" [dependencies] diff --git a/crates/taskito-async/src/pool.rs b/crates/taskito-async/src/pool.rs index 1f566c2..7eb3aa5 100644 --- a/crates/taskito-async/src/pool.rs +++ 
b/crates/taskito-async/src/pool.rs @@ -96,6 +96,7 @@ impl WorkerDispatcher for NativeAsyncPool { task_name: job.task_name.clone(), wall_time_ns: 0, should_retry: true, + timed_out: false, }); } }); diff --git a/crates/taskito-async/src/result_sender.rs b/crates/taskito-async/src/result_sender.rs index 874dc42..7275356 100644 --- a/crates/taskito-async/src/result_sender.rs +++ b/crates/taskito-async/src/result_sender.rs @@ -59,6 +59,7 @@ impl PyResultSender { task_name, wall_time_ns, should_retry, + timed_out: false, }); } diff --git a/crates/taskito-async/src/task_executor.rs b/crates/taskito-async/src/task_executor.rs index 6cea5b2..b07fac6 100644 --- a/crates/taskito-async/src/task_executor.rs +++ b/crates/taskito-async/src/task_executor.rs @@ -68,6 +68,7 @@ pub fn execute_sync_task( task_name, wall_time_ns, should_retry, + timed_out: false, } } } @@ -115,8 +116,14 @@ fn run_task(py: Python<'_>, task_registry: &PyObject, job: &Job) -> PyResult PyResult> { + // Deserialize arguments using per-task or queue-level serializer let payload_bytes = PyBytes::new_bound(py, &job.payload); - let unpickled = cloudpickle.call_method1("loads", (payload_bytes,))?; + let queue_ref = context_mod.getattr("_queue_ref")?; + let unpickled = if !queue_ref.is_none() { + queue_ref.call_method1("_deserialize_payload", (&job.task_name, &payload_bytes))? + } else { + cloudpickle.call_method1("loads", (&payload_bytes,))? 
+ }; let args_tuple: Bound<'_, PyTuple> = unpickled.downcast_into()?; if args_tuple.len() != 2 { diff --git a/crates/taskito-core/Cargo.toml b/crates/taskito-core/Cargo.toml index 5a67ce6..6ec526f 100644 --- a/crates/taskito-core/Cargo.toml +++ b/crates/taskito-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-core" -version = "0.5.0" +version = "0.6.0" edition = "2021" [features] diff --git a/crates/taskito-core/src/scheduler/maintenance.rs b/crates/taskito-core/src/scheduler/maintenance.rs index 647af94..94725d1 100644 --- a/crates/taskito-core/src/scheduler/maintenance.rs +++ b/crates/taskito-core/src/scheduler/maintenance.rs @@ -33,7 +33,7 @@ impl Scheduler { for job in stale_jobs { let error = format!("job timed out after {}ms", job.timeout_ms); - self.handle_result(JobResult::Failure { + let _ = self.handle_result(JobResult::Failure { job_id: job.id.clone(), error, retry_count: job.retry_count, @@ -41,6 +41,7 @@ impl Scheduler { task_name: job.task_name.clone(), wall_time_ns: 0, should_retry: true, + timed_out: true, })?; } diff --git a/crates/taskito-core/src/scheduler/mod.rs b/crates/taskito-core/src/scheduler/mod.rs index 9f06afc..9a5dfed 100644 --- a/crates/taskito-core/src/scheduler/mod.rs +++ b/crates/taskito-core/src/scheduler/mod.rs @@ -60,6 +60,7 @@ pub enum JobResult { task_name: String, wall_time_ns: i64, should_retry: bool, + timed_out: bool, }, Cancelled { job_id: String, @@ -68,12 +69,45 @@ pub enum JobResult { }, } +/// Outcome of processing a job result, returned to the caller for +/// Python-side middleware hook dispatch. +#[derive(Debug, Clone)] +pub enum ResultOutcome { + /// Task completed successfully. + Success { job_id: String, task_name: String }, + /// Task failed and will be retried. + Retry { + job_id: String, + task_name: String, + error: String, + retry_count: i32, + timed_out: bool, + }, + /// Task exhausted retries and moved to the dead-letter queue. 
+ DeadLettered { + job_id: String, + task_name: String, + error: String, + timed_out: bool, + }, + /// Task was cancelled during execution. + Cancelled { job_id: String, task_name: String }, +} + /// Per-task configuration for retry, rate limiting, and circuit breaker. #[derive(Debug, Clone)] pub struct TaskConfig { pub retry_policy: RetryPolicy, pub rate_limit: Option, pub circuit_breaker: Option, + pub max_concurrent: Option, +} + +/// Per-queue configuration for rate limiting and concurrency caps. +#[derive(Debug, Clone)] +pub struct QueueConfig { + pub rate_limit: Option, + pub max_concurrent: Option, } /// The central scheduler that coordinates job dispatch, retries, rate limiting, and circuit breakers. @@ -83,6 +117,7 @@ pub struct Scheduler { dlq: DeadLetterQueue, circuit_breaker: CircuitBreaker, task_configs: HashMap, + queue_configs: HashMap, queues: Vec, config: SchedulerConfig, shutdown: Arc, @@ -109,6 +144,7 @@ impl Scheduler { dlq, circuit_breaker, task_configs: HashMap::new(), + queue_configs: HashMap::new(), queues, config, shutdown: Arc::new(Notify::new()), @@ -124,6 +160,10 @@ impl Scheduler { self.shutdown.clone() } + pub fn register_queue_config(&mut self, queue_name: String, config: QueueConfig) { + self.queue_configs.insert(queue_name, config); + } + pub fn register_task(&mut self, task_name: String, config: TaskConfig) { if let Some(ref cb_config) = config.circuit_breaker { if let Err(e) = self.circuit_breaker.register(&task_name, cb_config) { @@ -276,6 +316,7 @@ mod tests { }, rate_limit: None, circuit_breaker: None, + max_concurrent: None, }, ); @@ -290,6 +331,7 @@ mod tests { task_name: "retry_task".to_string(), wall_time_ns: 500_000, should_retry: true, + timed_out: false, }) .unwrap(); @@ -312,6 +354,7 @@ mod tests { task_name: "exhausted_task".to_string(), wall_time_ns: 100, should_retry: true, + timed_out: false, }) .unwrap(); @@ -337,6 +380,7 @@ mod tests { task_name: "no_retry_task".to_string(), wall_time_ns: 100, should_retry: false, 
+ timed_out: false, }) .unwrap(); @@ -373,6 +417,7 @@ mod tests { refill_rate: 0.0, }), circuit_breaker: None, + max_concurrent: None, }, ); @@ -414,6 +459,7 @@ mod tests { retry_policy: RetryPolicy::default(), rate_limit: None, circuit_breaker: Some(cb_config), + max_concurrent: None, }, ); @@ -452,6 +498,7 @@ mod tests { }, rate_limit: None, circuit_breaker: None, + max_concurrent: None, }, ); diff --git a/crates/taskito-core/src/scheduler/poller.rs b/crates/taskito-core/src/scheduler/poller.rs index 385a977..3463cd7 100644 --- a/crates/taskito-core/src/scheduler/poller.rs +++ b/crates/taskito-core/src/scheduler/poller.rs @@ -14,6 +14,9 @@ const CIRCUIT_BREAKER_RETRY_DELAY_MS: i64 = 5000; /// Delay before re-scheduling a rate-limited job (ms). const RATE_LIMIT_RETRY_DELAY_MS: i64 = 1000; +/// Delay before re-scheduling a concurrency-limited job (ms). +const CONCURRENCY_RETRY_DELAY_MS: i64 = 500; + impl Scheduler { pub(super) fn try_dispatch(&self, job_tx: &tokio::sync::mpsc::Sender) -> Result { let now = now_millis(); @@ -53,6 +56,26 @@ impl Scheduler { None => return Ok(false), }; + // Check queue-level limits + if let Some(qcfg) = self.queue_configs.get(&job.queue) { + if let Some(ref rl_config) = qcfg.rate_limit { + let key = format!("queue:{}", job.queue); + if !self.rate_limiter.try_acquire(&key, rl_config)? { + self.storage + .retry(&job.id, now + RATE_LIMIT_RETRY_DELAY_MS)?; + return Ok(false); + } + } + if let Some(max_conc) = qcfg.max_concurrent { + let stats = self.storage.stats_by_queue(&job.queue)?; + if stats.running >= max_conc as i64 { + self.storage + .retry(&job.id, now + CONCURRENCY_RETRY_DELAY_MS)?; + return Ok(false); + } + } + } + // Check circuit breaker for this task if let Some(config) = self.task_configs.get(&job.task_name) { if config.circuit_breaker.is_some() && !self.circuit_breaker.allow(&job.task_name)? 
{ @@ -68,6 +91,16 @@ impl Scheduler { return Ok(false); } } + + // Check per-task concurrency limit + if let Some(max_conc) = config.max_concurrent { + let running = self.storage.count_running_by_task(&job.task_name)?; + if running >= max_conc as i64 { + self.storage + .retry(&job.id, now + CONCURRENCY_RETRY_DELAY_MS)?; + return Ok(false); + } + } } // Claim exactly-once execution diff --git a/crates/taskito-core/src/scheduler/result_handler.rs b/crates/taskito-core/src/scheduler/result_handler.rs index ad0ac91..daad119 100644 --- a/crates/taskito-core/src/scheduler/result_handler.rs +++ b/crates/taskito-core/src/scheduler/result_handler.rs @@ -3,11 +3,14 @@ use log::{error, warn}; use crate::error::Result; use crate::storage::Storage; -use super::{JobResult, Scheduler}; +use super::{JobResult, ResultOutcome, Scheduler}; impl Scheduler { /// Handle a completed or failed job result from a worker. - pub fn handle_result(&self, result: JobResult) -> Result<()> { + /// + /// Returns a [`ResultOutcome`] describing the action taken, so the + /// caller can dispatch Python-side middleware hooks and events. 
+ pub fn handle_result(&self, result: JobResult) -> Result { match result { JobResult::Success { job_id, @@ -32,6 +35,11 @@ impl Scheduler { if let Err(e) = self.circuit_breaker.record_success(task_name) { error!("circuit breaker error for {task_name}: {e}"); } + + Ok(ResultOutcome::Success { + job_id, + task_name: task_name.clone(), + }) } JobResult::Failure { job_id, @@ -41,6 +49,7 @@ impl Scheduler { task_name, wall_time_ns, should_retry, + timed_out, } => { // Clear execution claim so it can be retried if let Err(e) = self.storage.complete_execution(&job_id) { @@ -68,7 +77,12 @@ impl Scheduler { Some(job) => self.dlq.move_to_dlq(&job, &error, None)?, None => warn!("job {job_id} disappeared before DLQ move"), } - return Ok(()); + return Ok(ResultOutcome::DeadLettered { + job_id, + task_name, + error, + timed_out, + }); } let policy = self @@ -86,12 +100,25 @@ impl Scheduler { if retry_count < effective_max { let next_at = policy.next_retry_at(retry_count); self.storage.retry(&job_id, next_at)?; + Ok(ResultOutcome::Retry { + job_id, + task_name, + error, + retry_count, + timed_out, + }) } else { // Move to DLQ match self.storage.get_job(&job_id)? { Some(job) => self.dlq.move_to_dlq(&job, &error, None)?, None => warn!("job {job_id} disappeared before DLQ move"), } + Ok(ResultOutcome::DeadLettered { + job_id, + task_name, + error, + timed_out, + }) } } JobResult::Cancelled { @@ -113,8 +140,8 @@ impl Scheduler { { error!("failed to record metric for cancelled job {job_id}: {e}"); } + Ok(ResultOutcome::Cancelled { job_id, task_name }) } } - Ok(()) } } diff --git a/crates/taskito-core/src/storage/diesel_common/jobs.rs b/crates/taskito-core/src/storage/diesel_common/jobs.rs index 9541812..e808376 100644 --- a/crates/taskito-core/src/storage/diesel_common/jobs.rs +++ b/crates/taskito-core/src/storage/diesel_common/jobs.rs @@ -916,6 +916,19 @@ macro_rules! 
impl_diesel_job_ops { Ok(affected as u64) } + /// Count running jobs for a specific task name (for per-task concurrency limiting). + pub fn count_running_by_task(&self, task_name: &str) -> Result { + let mut conn = self.conn()?; + + let count: i64 = jobs::table + .filter(jobs::task_name.eq(task_name)) + .filter(jobs::status.eq(JobStatus::Running as i32)) + .count() + .get_result(&mut conn)?; + + Ok(count) + } + /// Purge job errors older than the given timestamp. pub fn purge_job_errors(&self, older_than_ms: i64) -> Result { let mut conn = self.conn()?; diff --git a/crates/taskito-core/src/storage/mod.rs b/crates/taskito-core/src/storage/mod.rs index a5146de..55a81f2 100644 --- a/crates/taskito-core/src/storage/mod.rs +++ b/crates/taskito-core/src/storage/mod.rs @@ -444,6 +444,12 @@ macro_rules! impl_storage { ) -> $crate::error::Result { self.purge_execution_claims(older_than_ms) } + fn count_running_by_task( + &self, + task_name: &str, + ) -> $crate::error::Result { + self.count_running_by_task(task_name) + } fn stats_by_queue( &self, queue_name: &str, @@ -792,6 +798,9 @@ impl Storage for StorageBackend { fn purge_execution_claims(&self, older_than_ms: i64) -> Result { delegate!(self, purge_execution_claims, older_than_ms) } + fn count_running_by_task(&self, task_name: &str) -> Result { + delegate!(self, count_running_by_task, task_name) + } fn stats_by_queue(&self, queue_name: &str) -> Result { delegate!(self, stats_by_queue, queue_name) } diff --git a/crates/taskito-core/src/storage/redis_backend/jobs.rs b/crates/taskito-core/src/storage/redis_backend/jobs.rs index 3377bf4..3d989f1 100644 --- a/crates/taskito-core/src/storage/redis_backend/jobs.rs +++ b/crates/taskito-core/src/storage/redis_backend/jobs.rs @@ -725,6 +725,24 @@ impl RedisStorage { Ok(stats) } + /// Count running jobs for a specific task name (for per-task concurrency limiting). 
+ pub fn count_running_by_task(&self, task_name: &str) -> Result { + let mut conn = self.conn()?; + let by_task_key = self.key(&["jobs", "by_task", task_name]); + let job_ids: Vec = conn.smembers(&by_task_key).map_err(map_err)?; + + let mut count: i64 = 0; + for id in &job_ids { + if let Some(job) = self.load_job(&mut conn, id)? { + if job.status == JobStatus::Running { + count += 1; + } + } + } + + Ok(count) + } + pub fn stats_by_queue(&self, queue_name: &str) -> Result { let mut conn = self.conn()?; let by_queue_key = self.key(&["jobs", "by_queue", queue_name]); diff --git a/crates/taskito-core/src/storage/sqlite/tests.rs b/crates/taskito-core/src/storage/sqlite/tests.rs index 9c77ce3..5ae4cc3 100644 --- a/crates/taskito-core/src/storage/sqlite/tests.rs +++ b/crates/taskito-core/src/storage/sqlite/tests.rs @@ -362,6 +362,31 @@ fn test_cascade_cancel_on_dlq() { assert!(b.error.unwrap().contains("dependency failed")); } +#[test] +fn test_count_running_by_task() { + let storage = test_storage(); + storage.enqueue(make_job("task_a")).unwrap(); + storage.enqueue(make_job("task_a")).unwrap(); + storage.enqueue(make_job("task_b")).unwrap(); + + // No running jobs yet + assert_eq!(storage.count_running_by_task("task_a").unwrap(), 0); + + let now = now_millis() + 1000; + // Dequeue one task_a (becomes running) + storage.dequeue("default", now).unwrap().unwrap(); + + assert_eq!(storage.count_running_by_task("task_a").unwrap(), 1); + assert_eq!(storage.count_running_by_task("task_b").unwrap(), 0); + + // Dequeue second task_a + storage.dequeue("default", now).unwrap().unwrap(); + assert_eq!(storage.count_running_by_task("task_a").unwrap(), 2); + + // Nonexistent task should return 0 + assert_eq!(storage.count_running_by_task("no_such_task").unwrap(), 0); +} + #[test] fn test_enqueue_rejects_missing_dependency() { let storage = test_storage(); diff --git a/crates/taskito-core/src/storage/traits.rs b/crates/taskito-core/src/storage/traits.rs index 407ce66..0fff94d 100644 --- 
a/crates/taskito-core/src/storage/traits.rs +++ b/crates/taskito-core/src/storage/traits.rs @@ -161,6 +161,10 @@ pub trait Storage: Send + Sync + Clone { fn complete_execution(&self, job_id: &str) -> Result<()>; fn purge_execution_claims(&self, older_than_ms: i64) -> Result; + // ── Per-task concurrency ────────────────────────────────────── + + fn count_running_by_task(&self, task_name: &str) -> Result; + // ── Per-queue stats ────────────────────────────────────────── fn stats_by_queue(&self, queue_name: &str) -> Result; diff --git a/crates/taskito-python/Cargo.toml b/crates/taskito-python/Cargo.toml index 39e14ca..bfeaf46 100644 --- a/crates/taskito-python/Cargo.toml +++ b/crates/taskito-python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-python" -version = "0.5.0" +version = "0.6.0" edition = "2021" [features] @@ -22,3 +22,4 @@ crossbeam-channel = { workspace = true } uuid = { workspace = true } async-trait = { workspace = true } taskito-async = { path = "../taskito-async", optional = true } +serde_json = { workspace = true } diff --git a/crates/taskito-python/src/async_worker.rs b/crates/taskito-python/src/async_worker.rs index 0c6443f..19bce48 100644 --- a/crates/taskito-python/src/async_worker.rs +++ b/crates/taskito-python/src/async_worker.rs @@ -118,6 +118,7 @@ impl WorkerDispatcher for AsyncWorkerPool { task_name, wall_time_ns, should_retry, + timed_out: false, } } } diff --git a/crates/taskito-python/src/py_config.rs b/crates/taskito-python/src/py_config.rs index 7b6f994..ec911f3 100644 --- a/crates/taskito-python/src/py_config.rs +++ b/crates/taskito-python/src/py_config.rs @@ -26,13 +26,17 @@ pub struct PyTaskConfig { pub circuit_breaker_cooldown: Option, #[pyo3(get, set)] pub retry_delays: Option>, + #[pyo3(get, set)] + pub max_retry_delay: Option, + #[pyo3(get, set)] + pub max_concurrent: Option, } #[pymethods] #[allow(clippy::too_many_arguments)] impl PyTaskConfig { #[new] - #[pyo3(signature = (name, max_retries=3, retry_backoff=1.0, 
timeout=300, priority=0, rate_limit=None, queue="default".to_string(), circuit_breaker_threshold=None, circuit_breaker_window=None, circuit_breaker_cooldown=None, retry_delays=None))] + #[pyo3(signature = (name, max_retries=3, retry_backoff=1.0, timeout=300, priority=0, rate_limit=None, queue="default".to_string(), circuit_breaker_threshold=None, circuit_breaker_window=None, circuit_breaker_cooldown=None, retry_delays=None, max_retry_delay=None, max_concurrent=None))] pub fn new( name: String, max_retries: i32, @@ -45,6 +49,8 @@ impl PyTaskConfig { circuit_breaker_window: Option, circuit_breaker_cooldown: Option, retry_delays: Option>, + max_retry_delay: Option, + max_concurrent: Option, ) -> Self { Self { name, @@ -58,6 +64,8 @@ impl PyTaskConfig { circuit_breaker_window, circuit_breaker_cooldown, retry_delays, + max_retry_delay, + max_concurrent, } } } diff --git a/crates/taskito-python/src/py_queue/mod.rs b/crates/taskito-python/src/py_queue/mod.rs index 43660e2..b84bfda 100644 --- a/crates/taskito-python/src/py_queue/mod.rs +++ b/crates/taskito-python/src/py_queue/mod.rs @@ -33,6 +33,9 @@ pub struct PyQueue { pub(crate) default_priority: i32, pub(crate) shutdown_flag: Arc, pub(crate) result_ttl_ms: Option, + pub(crate) scheduler_poll_interval_ms: u64, + pub(crate) scheduler_reap_interval: u32, + pub(crate) scheduler_cleanup_interval: u32, } #[pymethods] @@ -43,7 +46,8 @@ pub struct PyQueue { )] impl PyQueue { #[new] - #[pyo3(signature = (db_path=".taskito/taskito.db", workers=0, default_retry=3, default_timeout=300, default_priority=0, result_ttl=None, backend="sqlite", db_url=None, schema="taskito", pool_size=None))] + #[pyo3(signature = (db_path=".taskito/taskito.db", workers=0, default_retry=3, default_timeout=300, default_priority=0, result_ttl=None, backend="sqlite", db_url=None, schema="taskito", pool_size=None, scheduler_poll_interval_ms=50, scheduler_reap_interval=100, scheduler_cleanup_interval=1200))] + #[allow(clippy::too_many_arguments)] pub fn new( 
db_path: &str, workers: usize, @@ -55,6 +59,9 @@ impl PyQueue { db_url: Option<&str>, schema: &str, pool_size: Option, + scheduler_poll_interval_ms: u64, + scheduler_reap_interval: u32, + scheduler_cleanup_interval: u32, ) -> PyResult { let storage = match backend { "sqlite" => { @@ -117,6 +124,9 @@ impl PyQueue { default_priority, shutdown_flag: Arc::new(AtomicBool::new(false)), result_ttl_ms: result_ttl.map(|s| s * 1000), + scheduler_poll_interval_ms, + scheduler_reap_interval, + scheduler_cleanup_interval, }) } diff --git a/crates/taskito-python/src/py_queue/worker.rs b/crates/taskito-python/src/py_queue/worker.rs index ecc511a..2e8c066 100644 --- a/crates/taskito-python/src/py_queue/worker.rs +++ b/crates/taskito-python/src/py_queue/worker.rs @@ -7,7 +7,7 @@ use pyo3::types::PyDict; use taskito_core::resilience::circuit_breaker::CircuitBreakerConfig; use taskito_core::resilience::rate_limiter::RateLimitConfig; use taskito_core::resilience::retry::RetryPolicy; -use taskito_core::scheduler::{JobResult, Scheduler, SchedulerConfig, TaskConfig}; +use taskito_core::scheduler::{JobResult, ResultOutcome, Scheduler, SchedulerConfig, TaskConfig}; use taskito_core::storage::Storage; use super::PyQueue; @@ -15,6 +15,149 @@ use super::PyQueue; use crate::async_worker::AsyncWorkerPool; use crate::py_config::PyTaskConfig; +/// Dispatch a ResultOutcome to Python middleware hooks and events. +/// +/// Called with the GIL held after `handle_result()` returns. 
+fn dispatch_outcome(py: Python<'_>, outcome: &ResultOutcome) { + let result = (|| -> PyResult<()> { + let context_mod = py.import_bound("taskito.context")?; + let queue_ref = context_mod.getattr("_queue_ref")?; + if queue_ref.is_none() { + return Ok(()); + } + + match outcome { + ResultOutcome::Retry { + job_id, + task_name, + error, + retry_count, + timed_out, + } => { + // Emit JOB_RETRYING event + let events_mod = py.import_bound("taskito.events")?; + let event_type = events_mod.getattr("EventType")?.getattr("JOB_RETRYING")?; + let payload = PyDict::new_bound(py); + payload.set_item("job_id", job_id)?; + payload.set_item("task_name", task_name)?; + payload.set_item("error", error)?; + payload.set_item("retry_count", retry_count)?; + queue_ref.call_method1("_emit_event", (event_type, payload))?; + + // Call on_timeout middleware if this was a timeout + if *timed_out { + let ctx = build_lightweight_ctx(py, job_id, task_name)?; + call_middleware_hook(py, &queue_ref, task_name, "on_timeout", (ctx,))?; + } + + // Call on_retry middleware + let ctx = build_lightweight_ctx(py, job_id, task_name)?; + let error_obj = + pyo3::exceptions::PyRuntimeError::new_err(error.clone()).into_py(py); + call_middleware_hook( + py, + &queue_ref, + task_name, + "on_retry", + (ctx, error_obj, *retry_count), + )?; + } + ResultOutcome::DeadLettered { + job_id, + task_name, + error, + timed_out, + } => { + // Emit JOB_DEAD event + let events_mod = py.import_bound("taskito.events")?; + let event_type = events_mod.getattr("EventType")?.getattr("JOB_DEAD")?; + let payload = PyDict::new_bound(py); + payload.set_item("job_id", job_id)?; + payload.set_item("task_name", task_name)?; + payload.set_item("error", error)?; + queue_ref.call_method1("_emit_event", (event_type, payload))?; + + // Call on_timeout middleware if this was a timeout + if *timed_out { + let ctx = build_lightweight_ctx(py, job_id, task_name)?; + call_middleware_hook(py, &queue_ref, task_name, "on_timeout", (ctx,))?; + } + + // 
Call on_dead_letter middleware + let ctx = build_lightweight_ctx(py, job_id, task_name)?; + let error_obj = + pyo3::exceptions::PyRuntimeError::new_err(error.clone()).into_py(py); + call_middleware_hook( + py, + &queue_ref, + task_name, + "on_dead_letter", + (ctx, error_obj), + )?; + } + ResultOutcome::Cancelled { job_id, task_name } => { + // Emit JOB_CANCELLED event + let events_mod = py.import_bound("taskito.events")?; + let event_type = events_mod.getattr("EventType")?.getattr("JOB_CANCELLED")?; + let payload = PyDict::new_bound(py); + payload.set_item("job_id", job_id)?; + payload.set_item("task_name", task_name)?; + queue_ref.call_method1("_emit_event", (event_type, payload))?; + + // Call on_cancel middleware + let ctx = build_lightweight_ctx(py, job_id, task_name)?; + call_middleware_hook(py, &queue_ref, task_name, "on_cancel", (ctx,))?; + } + ResultOutcome::Success { .. } => { + // Success events are already emitted in _wrap_task + } + } + Ok(()) + })(); + + if let Err(e) = result { + eprintln!("[taskito] middleware dispatch error: {e}"); + } +} + +/// Build a lightweight JobContext-like object for middleware hooks called +/// outside of task execution (retry/dlq/cancel outcomes). +fn build_lightweight_ctx<'py>( + py: Python<'py>, + job_id: &str, + task_name: &str, +) -> PyResult> { + let types_mod = py.import_bound("types")?; + let ns = types_mod.call_method1("SimpleNamespace", ())?; + ns.setattr("id", job_id)?; + ns.setattr("task_name", task_name)?; + ns.setattr("queue_name", "unknown")?; + ns.setattr("retry_count", 0)?; + Ok(ns) +} + +/// Call a middleware hook on all middleware in the chain for a given task. +fn call_middleware_hook( + py: Python<'_>, + queue_ref: &Bound<'_, pyo3::PyAny>, + task_name: &str, + hook_name: &str, + args: impl pyo3::IntoPy>, +) -> PyResult<()> { + let chain = queue_ref.call_method1("_get_middleware_chain", (task_name,))?; + let args_tuple = args.into_py(py); + let args_bound = args_tuple.bind(py); + for mw in chain.iter()? 
{ + let mw = mw?; + if let Err(e) = mw.call_method(hook_name, args_bound, None) { + let logging = py.import_bound("logging")?; + let logger = logging.call_method1("getLogger", ("taskito",))?; + logger.call_method1("warning", (format!("middleware {hook_name}() error: {e}"),))?; + } + } + Ok(()) +} + #[pymethods] #[allow(clippy::useless_conversion)] impl PyQueue { @@ -32,6 +175,7 @@ impl PyQueue { resources=None, threads=1, async_concurrency=100, + queue_configs=None, ))] #[allow(clippy::too_many_arguments)] pub fn run_worker( @@ -46,6 +190,7 @@ impl PyQueue { resources: Option, threads: i32, #[allow(unused_variables)] async_concurrency: i32, + queue_configs: Option, ) -> PyResult<()> { // Reset shutdown flag for this run self.shutdown_flag.store(false, Ordering::SeqCst); @@ -54,6 +199,9 @@ impl PyQueue { let queues_str = queues.join(","); let scheduler_config = SchedulerConfig { + poll_interval: std::time::Duration::from_millis(self.scheduler_poll_interval_ms), + reap_interval: self.scheduler_reap_interval, + cleanup_interval: self.scheduler_cleanup_interval, result_ttl_ms: self.result_ttl_ms, ..SchedulerConfig::default() }; @@ -94,10 +242,14 @@ impl PyQueue { } else { (tc.retry_backoff.min(i64::MAX as f64 / 1000.0) * 1000.0) as i64 }; + let max_delay_ms = tc + .max_retry_delay + .map(|s| s.saturating_mul(1000)) + .unwrap_or(300_000); let retry_policy = RetryPolicy { max_retries: tc.max_retries, base_delay_ms, - max_delay_ms: 300_000, + max_delay_ms, custom_delays_ms, }; let rate_limit = tc @@ -117,10 +269,37 @@ impl PyQueue { retry_policy, rate_limit, circuit_breaker, + max_concurrent: tc.max_concurrent, }, ); } + // Register queue-level rate limits and concurrency caps + if let Some(ref qc_json) = queue_configs { + if let Ok(map) = serde_json::from_str::< + std::collections::HashMap, + >(qc_json) + { + for (queue_name, cfg) in map { + let rate_limit = cfg + .get("rate_limit") + .and_then(|v| v.as_str()) + .and_then(RateLimitConfig::parse); + let max_concurrent = 
cfg + .get("max_concurrent") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + scheduler.register_queue_config( + queue_name, + taskito_core::scheduler::QueueConfig { + rate_limit, + max_concurrent, + }, + ); + } + } + } + let shutdown = scheduler.shutdown_handle(); let (job_tx, job_rx) = tokio::sync::mpsc::channel(self.num_workers * 2); @@ -266,11 +445,14 @@ impl PyQueue { match drain_action { PollAction::Result(result) => { - py.allow_threads(|| { - if let Err(e) = scheduler_for_results.handle_result(result) { - eprintln!("[taskito] result handling error: {e}"); + let outcome = py + .allow_threads(|| scheduler_for_results.handle_result(result)); + match outcome { + Ok(ref o) => dispatch_outcome(py, o), + Err(e) => { + eprintln!("[taskito] result handling error: {e}") } - }); + } } PollAction::Continue => continue, PollAction::Done => break, @@ -280,11 +462,11 @@ impl PyQueue { break; } PollAction::Result(result) => { - py.allow_threads(|| { - if let Err(e) = scheduler_for_results.handle_result(result) { - eprintln!("[taskito] result handling error: {e}"); - } - }); + let outcome = py.allow_threads(|| scheduler_for_results.handle_result(result)); + match outcome { + Ok(ref o) => dispatch_outcome(py, o), + Err(e) => eprintln!("[taskito] result handling error: {e}"), + } } PollAction::Continue => continue, PollAction::Done => break, diff --git a/crates/taskito-python/src/py_worker.rs b/crates/taskito-python/src/py_worker.rs index d6ac3ad..465e031 100644 --- a/crates/taskito-python/src/py_worker.rs +++ b/crates/taskito-python/src/py_worker.rs @@ -113,6 +113,7 @@ fn worker_loop( task_name, wall_time_ns, should_retry, + timed_out: false, } } } @@ -168,9 +169,14 @@ pub fn execute_task( // Wrap deserialization + call so _clear_context is always called let result = (|| -> PyResult> { - // Deserialize arguments: (args, kwargs) + // Deserialize arguments using per-task or queue-level serializer let payload_bytes = PyBytes::new_bound(py, &job.payload); - let unpickled = 
cloudpickle.call_method1("loads", (payload_bytes,))?; + let queue_ref = context_mod.getattr("_queue_ref")?; + let unpickled = if !queue_ref.is_none() { + queue_ref.call_method1("_deserialize_payload", (&job.task_name, &payload_bytes))? + } else { + cloudpickle.call_method1("loads", (&payload_bytes,))? + }; let args_tuple: Bound<'_, PyTuple> = unpickled.downcast_into()?; if args_tuple.len() != 2 { diff --git a/docs/api/queue.md b/docs/api/queue.md index af10c6f..a4872d8 100644 --- a/docs/api/queue.md +++ b/docs/api/queue.md @@ -23,6 +23,10 @@ Queue( file_path_allowlist: list[str] | None = None, disabled_proxies: list[str] | None = None, async_concurrency: int = 100, + event_workers: int = 4, + scheduler_poll_interval_ms: int = 50, + scheduler_reap_interval: int = 100, + scheduler_cleanup_interval: int = 1200, ) ``` @@ -43,6 +47,10 @@ Queue( | `file_path_allowlist` | `list[str] \| None` | `None` | Allowed file path prefixes for the file proxy handler. | | `disabled_proxies` | `list[str] \| None` | `None` | Handler names to skip when registering built-in proxy handlers. | | `async_concurrency` | `int` | `100` | Maximum number of `async def` tasks running concurrently on the native async executor. | +| `event_workers` | `int` | `4` | Thread pool size for the event bus. Increase for high event volume. | +| `scheduler_poll_interval_ms` | `int` | `50` | Milliseconds between scheduler poll cycles. Lower values improve scheduling precision at the cost of CPU. | +| `scheduler_reap_interval` | `int` | `100` | Reap stale/timed-out jobs every N poll cycles. | +| `scheduler_cleanup_interval` | `int` | `1200` | Clean up old completed jobs every N poll cycles. 
| ## Task Registration @@ -54,6 +62,7 @@ Queue( max_retries: int = 3, retry_backoff: float = 1.0, retry_delays: list[float] | None = None, + max_retry_delay: int | None = None, timeout: int = 300, soft_timeout: float | None = None, priority: int = 0, @@ -63,6 +72,8 @@ Queue( middleware: list[TaskMiddleware] | None = None, expires: float | None = None, inject: list[str] | None = None, + serializer: Serializer | None = None, + max_concurrent: int | None = None, ) -> TaskWrapper ``` @@ -74,6 +85,7 @@ Register a function as a background task. Returns a [`TaskWrapper`](task.md). | `max_retries` | `int` | `3` | Max retry attempts before moving to DLQ. | | `retry_backoff` | `float` | `1.0` | Base delay in seconds for exponential backoff. | | `retry_delays` | `list[float] \| None` | `None` | Per-attempt delays in seconds, overrides backoff. e.g. `[1, 5, 30]`. | +| `max_retry_delay` | `int \| None` | `None` | Cap on backoff delay in seconds. Defaults to 300 s. | | `timeout` | `int` | `300` | Hard execution time limit in seconds. | | `soft_timeout` | `float \| None` | `None` | Cooperative time limit checked via `current_job.check_timeout()`. | | `priority` | `int` | `0` | Default priority (higher = more urgent). | @@ -83,6 +95,8 @@ Register a function as a background task. Returns a [`TaskWrapper`](task.md). | `middleware` | `list[TaskMiddleware] \| None` | `None` | Per-task middleware, applied in addition to queue-level middleware. | | `expires` | `float \| None` | `None` | Seconds until the job expires if not started. | | `inject` | `list[str] \| None` | `None` | Resource names to inject as keyword arguments. See [Resource System](../guide/resources.md). | +| `serializer` | `Serializer \| None` | `None` | Per-task serializer override. Falls back to queue-level serializer. | +| `max_concurrent` | `int \| None` | `None` | Max concurrent running instances. `None` = no limit. 
| ### `@queue.periodic()` @@ -234,6 +248,32 @@ Return a dependency graph for a job, including all ancestors and descendants. Us ## Queue Management +### `queue.set_queue_rate_limit()` + +```python +queue.set_queue_rate_limit(queue_name: str, rate_limit: str) -> None +``` + +Set a rate limit for all jobs in a queue. Checked by the scheduler before per-task rate limits. + +| Parameter | Type | Description | +|---|---|---| +| `queue_name` | `str` | Queue name (e.g. `"default"`). | +| `rate_limit` | `str` | Rate limit string: `"N/s"`, `"N/m"`, or `"N/h"`. | + +### `queue.set_queue_concurrency()` + +```python +queue.set_queue_concurrency(queue_name: str, max_concurrent: int) -> None +``` + +Set a maximum number of concurrently running jobs for a queue across all workers. Checked by the scheduler before per-task `max_concurrent` limits. + +| Parameter | Type | Description | +|---|---|---| +| `queue_name` | `str` | Queue name (e.g. `"default"`). | +| `max_concurrent` | `int` | Maximum simultaneous running jobs from this queue. | + ### `queue.pause()` ```python @@ -646,13 +686,26 @@ def handle_failure(job_id: str, task_name: str, error: str) -> None: ```python queue.add_webhook( url: str, - events: list[str], + events: list[EventType] | None = None, headers: dict[str, str] | None = None, secret: str | None = None, -) -> str + max_retries: int = 3, + timeout: float = 10.0, + retry_backoff: float = 2.0, +) -> None ``` -Register a webhook URL for one or more events. Returns the webhook ID. 4xx responses are not retried; 5xx responses are retried with backoff. +Register a webhook URL for one or more events. 4xx responses are not retried; 5xx responses are retried with exponential backoff. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `url` | `str` | — | URL to POST to. Must be `http://` or `https://`. | +| `events` | `list[EventType] \| None` | `None` | Event types to subscribe to. `None` means all events. 
| +| `headers` | `dict[str, str] \| None` | `None` | Extra HTTP headers to include. | +| `secret` | `str \| None` | `None` | HMAC-SHA256 signing secret for `X-Taskito-Signature`. | +| `max_retries` | `int` | `3` | Maximum delivery attempts. | +| `timeout` | `float` | `10.0` | HTTP request timeout in seconds. | +| `retry_backoff` | `float` | `2.0` | Base for exponential backoff between retries. | ## Worker diff --git a/docs/changelog.md b/docs/changelog.md index 5cd264a..4d57905 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,35 @@ All notable changes to taskito are documented here. +## 0.6.0 + +### Features + +- **Middleware lifecycle hooks wired** -- `on_retry(ctx, error, retry_count)`, `on_dead_letter(ctx, error)`, and `on_cancel(ctx)` are now dispatched from the Rust result handler; they fire for every matching outcome across all registered middleware +- **Expanded middleware hooks** -- `TaskMiddleware` gains four new hooks: `on_enqueue`, `on_dead_letter`, `on_timeout`, `on_cancel`; `on_enqueue` receives a mutable `options` dict that can modify priority, delay, queue, and other enqueue parameters before the job is written +- **`JOB_RETRYING`, `JOB_DEAD`, `JOB_CANCELLED` events now emitted** -- these three event types were previously defined but never fired; they are now emitted from the Rust result handler with payloads `{job_id, task_name, error, retry_count}`, `{job_id, task_name, error}`, and `{job_id, task_name}` respectively +- **Queue-level rate limits** -- `queue.set_queue_rate_limit("name", "100/m")` applies a token-bucket rate limit to an entire queue, checked in the scheduler before per-task limits +- **Queue-level concurrency caps** -- `queue.set_queue_concurrency("name", 10)` limits how many jobs from a queue run simultaneously across all workers, checked before per-task `max_concurrent` +- **Worker lifecycle events** -- `EventType.WORKER_STARTED` and `EventType.WORKER_STOPPED` fired when a worker thread comes online or exits; 
subscribe via `queue.on_event(EventType.WORKER_STARTED, cb)` +- **Queue pause/resume events** -- `EventType.QUEUE_PAUSED` and `EventType.QUEUE_RESUMED` fired by `queue.pause()` and `queue.resume()` +- **`event_workers` parameter** -- `Queue(event_workers=N)` configures the event bus thread pool size (default 4); raise for high event volume +- **Per-webhook delivery options** -- `queue.add_webhook()` now accepts `max_retries`, `timeout`, and `retry_backoff` per endpoint, replacing the previous hardcoded values +- **OTel customization** -- `OpenTelemetryMiddleware` adds `span_name_fn`, `attribute_prefix`, `extra_attributes_fn`, and `task_filter` parameters +- **Sentry customization** -- `SentryMiddleware` adds `tag_prefix`, `transaction_name_fn`, `task_filter`, and `extra_tags_fn` parameters +- **Prometheus customization** -- `PrometheusMiddleware` and `PrometheusStatsCollector` add `namespace`, `extra_labels_fn`, and `disabled_metrics` parameters; metrics grouped by category (`"jobs"`, `"queue"`, `"resource"`, `"proxy"`, `"intercept"`) +- **FastAPI route selection** -- `TaskitoRouter` adds `include_routes`/`exclude_routes`, `dependencies`, `sse_poll_interval`, `result_timeout`, `default_page_size`, `max_page_size`, and `result_serializer` parameters; new endpoints: `/health`, `/readiness`, `/resources`, `/stats/queues` +- **Flask CLI group** -- `Taskito(app, cli_group="tasks")` renames the CLI command group; `flask taskito info --format json` outputs machine-readable stats +- **Django settings** -- `TASKITO_AUTODISCOVER_MODULE`, `TASKITO_ADMIN_PER_PAGE`, `TASKITO_ADMIN_TITLE`, `TASKITO_ADMIN_HEADER`, `TASKITO_DASHBOARD_HOST`, `TASKITO_DASHBOARD_PORT` control autodiscovery, admin pagination, branding, and dashboard bind address +- **`max_retry_delay` on `@queue.task()`** -- caps exponential backoff at a configurable ceiling in seconds (defaults to 300 s) +- **`max_concurrent` on `@queue.task()`** -- limits how many instances of a task run simultaneously across all 
workers +- **`serializer` on `@queue.task()`** -- per-task serializer override; falls back to queue-level serializer +- **Per-task serializer full round-trip** -- deserialization now also uses the per-task serializer; previously only enqueue (serialization) did; both the sync and native-async worker paths call `_deserialize_payload(task_name, payload)` instead of cloudpickle directly +- **`on_timeout` middleware hook wired** -- `on_timeout(ctx)` now fires when the Rust maintenance reaper detects a stale job that exceeded its hard timeout; fires before `on_retry` (if retrying) or `on_dead_letter` (if retries exhausted); previously the hook existed in `TaskMiddleware` but was never called +- **`QUEUE_PAUSED` / `QUEUE_RESUMED` events emitted** -- `queue.pause()` and `queue.resume()` now emit these events with payload `{"queue": "..."}` after updating storage; previously the event types were defined but never fired +- **Scheduler tuning** -- `Queue(scheduler_poll_interval_ms=N, scheduler_reap_interval=N, scheduler_cleanup_interval=N)` exposes the three Rust scheduler timing knobs to Python + +--- + ## 0.5.0 ### New Features diff --git a/docs/guide/events-webhooks.md b/docs/guide/events-webhooks.md index 2e50f40..dbd9605 100644 --- a/docs/guide/events-webhooks.md +++ b/docs/guide/events-webhooks.md @@ -6,14 +6,22 @@ taskito includes an in-process event bus and webhook delivery system for reactin The `EventType` enum defines all available lifecycle events: -| Event | Fired when | -|-------|------------| -| `JOB_ENQUEUED` | A job is added to the queue | -| `JOB_COMPLETED` | A job finishes successfully | -| `JOB_FAILED` | A job raises an exception (before retry) | -| `JOB_RETRYING` | A failed job is being retried | -| `JOB_DEAD` | A job exhausts all retries and enters the DLQ | -| `JOB_CANCELLED` | A job is cancelled | +| Event | Fired when | Payload fields | +|-------|------------|----------------| +| `JOB_ENQUEUED` | A job is added to the queue | `job_id`, `task_name`, 
`queue` | +| `JOB_COMPLETED` | A job finishes successfully | `job_id`, `task_name`, `queue` | +| `JOB_FAILED` | A job raises an exception (before retry) | `job_id`, `task_name`, `queue`, `error` | +| `JOB_RETRYING` | A failed job will be retried | `job_id`, `task_name`, `error`, `retry_count` | +| `JOB_DEAD` | A job exhausts all retries and enters the DLQ | `job_id`, `task_name`, `error` | +| `JOB_CANCELLED` | A job is cancelled | `job_id`, `task_name` | +| `WORKER_STARTED` | A worker process/thread comes online | `worker_id`, `hostname` | +| `WORKER_STOPPED` | A worker process/thread shuts down | `worker_id`, `hostname` | +| `QUEUE_PAUSED` | A named queue is paused | `queue` | +| `QUEUE_RESUMED` | A paused queue is resumed | `queue` | + +`JOB_RETRYING`, `JOB_DEAD`, and `JOB_CANCELLED` are emitted by the Rust result handler immediately after the scheduler records the outcome. Middleware hooks (`on_retry`, `on_dead_letter`, `on_cancel`) are called in the same result-handling pass, after the event fires. + +`QUEUE_PAUSED` and `QUEUE_RESUMED` are emitted synchronously by `queue.pause()` and `queue.resume()` after the queue state is written to storage. ## Registering Listeners @@ -40,7 +48,7 @@ All callbacks receive two arguments: ### Async Delivery -Callbacks are dispatched asynchronously in a `ThreadPoolExecutor` (4 threads by default). This means: +Callbacks are dispatched asynchronously in a `ThreadPoolExecutor`. The thread pool size defaults to 4 and can be configured via `Queue(event_workers=N)`. This means: - Callbacks never block the worker - Exceptions in callbacks are logged but do not affect job processing @@ -61,12 +69,15 @@ queue.add_webhook( ) ``` -| Parameter | Type | Description | -|-----------|------|-------------| -| `url` | `str` | URL to POST event payloads to (must be `http://` or `https://`) | -| `events` | `list[EventType] | None` | Event types to subscribe to. 
`None` means all events | -| `headers` | `dict[str, str] | None` | Extra HTTP headers to include in requests | -| `secret` | `str | None` | HMAC-SHA256 signing secret | +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `url` | `str` | — | URL to POST event payloads to (must be `http://` or `https://`) | +| `events` | `list[EventType] \| None` | `None` | Event types to subscribe to. `None` means all events | +| `headers` | `dict[str, str] \| None` | `None` | Extra HTTP headers to include in requests | +| `secret` | `str \| None` | `None` | HMAC-SHA256 signing secret | +| `max_retries` | `int` | `3` | Maximum delivery attempts | +| `timeout` | `float` | `10.0` | HTTP request timeout in seconds | +| `retry_backoff` | `float` | `2.0` | Base for exponential backoff between retries | ### HMAC Signing @@ -89,15 +100,15 @@ def verify_signature(body: bytes, signature: str, secret: str) -> bool: ### Retry Behavior -Failed webhook deliveries are retried up to 3 times with exponential backoff: +Failed webhook deliveries are retried with exponential backoff. The number of attempts, request timeout, and backoff base are configurable per webhook via `max_retries`, `timeout`, and `retry_backoff`. With the defaults (`max_retries=3`, `retry_backoff=2.0`): -| Attempt | Delay | -|---------|-------| -| 1st retry | 1 second | -| 2nd retry | 2 seconds | -| 3rd retry | 4 seconds | +| Attempt | Delay before next retry | +|---------|------------------------| +| 1st retry | 1 second (`2.0 ** 0`) | +| 2nd retry | 2 seconds (`2.0 ** 1`) | +| 3rd retry | — (final) | -If all attempts fail, a warning is logged and the event is dropped. +4xx responses are not retried. If all attempts fail, a warning is logged and the event is dropped. 
### Event Filtering diff --git a/docs/guide/middleware.md b/docs/guide/middleware.md index 70abff5..656c850 100644 --- a/docs/guide/middleware.md +++ b/docs/guide/middleware.md @@ -1,6 +1,6 @@ # Per-Task Middleware -taskito supports a middleware system that lets you run code before, after, and on retry of task executions. Middleware can be applied globally (to all tasks) or per-task. +taskito supports a middleware system that lets you run code at key points in the task lifecycle. Middleware can be applied globally (to all tasks) or per-task. ## TaskMiddleware Base Class @@ -23,14 +23,61 @@ class LoggingMiddleware(TaskMiddleware): ### Hook Signatures -| Hook | Signature | Called when | -|---|---|---| -| `before(ctx)` | `ctx: JobContext` | Before task execution | -| `after(ctx, result, error)` | `ctx: JobContext`, `result: Any`, `error: Exception \| None` | After task execution (success or failure) | -| `on_retry(ctx, error, retry_count)` | `ctx: JobContext`, `error: Exception`, `retry_count: int` | When a task is about to be retried | +| Hook | Called when | +|---|---| +| `before(ctx)` | Before task execution | +| `after(ctx, result, error)` | After task execution (success or failure) | +| `on_retry(ctx, error, retry_count)` | A job fails and will be retried | +| `on_enqueue(task_name, args, kwargs, options)` | A job is about to be enqueued | +| `on_dead_letter(ctx, error)` | A job exhausts all retries and moves to the DLQ | +| `on_timeout(ctx)` | A job hits its timeout limit | +| `on_cancel(ctx)` | A job is cancelled during execution | The `ctx` parameter is a `JobContext` — the same object as `current_job` — providing `ctx.id`, `ctx.task_name`, `ctx.retry_count`, and `ctx.queue_name`. +!!! note "Lifecycle hooks dispatched from Rust" + `on_retry`, `on_dead_letter`, `on_timeout`, and `on_cancel` are called by the Rust result handler after the scheduler records the outcome. They fire after `after()` and after the corresponding event is emitted on the event bus. 
Exceptions raised inside these hooks are logged and do not affect job processing. + +### `on_timeout` — Handling Timed-Out Jobs + +`on_timeout` fires when the Rust scheduler detects a stale job that exceeded its hard `timeout`. Detection happens in the maintenance reaper, which periodically scans for jobs still marked as running past their deadline. + +When a job times out, `on_timeout` is called **before** `on_retry` (if the job will be retried) or `on_dead_letter` (if retries are exhausted). This lets you react to the timeout itself independently of whether the job will be retried: + +```python +class TimeoutAlerter(TaskMiddleware): + def on_timeout(self, ctx): + # Fires for every timed-out job, regardless of retry/DLQ outcome + logger.error("Job %s (%s) timed out", ctx.id, ctx.task_name) + + def on_retry(self, ctx, error, retry_count): + # Fires after on_timeout when the job will be retried + logger.warning("Retrying %s (attempt %d)", ctx.task_name, retry_count) + + def on_dead_letter(self, ctx, error): + # Fires after on_timeout when retries are exhausted + logger.critical("Job %s exhausted retries after timeout", ctx.id) +``` + +!!! tip + Use `on_timeout` to update dashboards, fire alerts, or record SLA violations. Combine with `on_retry` and `on_dead_letter` for full visibility into the job's fate after the timeout. + +### `on_enqueue` — Modifying Enqueue Parameters + +`on_enqueue` is unique: it fires before the job is written to the database, and the `options` dict is **mutable**. 
Modify it to change how the job is enqueued: + +```python +class PriorityBoostMiddleware(TaskMiddleware): + def on_enqueue(self, task_name, args, kwargs, options): + # Bump priority for urgent tasks during business hours + import datetime + hour = datetime.datetime.now().hour + if 9 <= hour < 18 and task_name.startswith("alerts."): + options["priority"] = max(options.get("priority", 0), 50) +``` + +Keys present in `options`: `priority`, `delay`, `queue`, `max_retries`, `timeout`, `unique_key`, `metadata`. + ## Queue-Level Middleware Apply middleware to **all tasks** by passing it to the `Queue` constructor: @@ -79,9 +126,11 @@ taskito has two systems for running code around tasks: | | Hooks (`@queue.on_failure`, etc.) | Middleware (`TaskMiddleware`) | |---|---|---| | **Scope** | Queue-level only | Queue-level or per-task | -| **Interface** | Decorated functions | Class with `before`/`after`/`on_retry` | +| **Interface** | Decorated functions | Class with up to 7 hooks | | **Context** | Receives `task_name, args, kwargs` | Receives `JobContext` | +| **Enqueue hook** | No | Yes (`on_enqueue`, can mutate options) | | **Retry hook** | No | Yes (`on_retry`) | +| **DLQ / timeout / cancel hooks** | No | Yes | | **Execution order** | After middleware | Before hooks | Middleware runs **inside** the task wrapper (closer to the task function), while hooks run **outside**. In practice, middleware `before()` fires first, then `before_task` hooks. On completion, `on_success`/`on_failure` hooks fire, then middleware `after()`. diff --git a/docs/guide/queues.md b/docs/guide/queues.md index 4b61b5d..bf8699e 100644 --- a/docs/guide/queues.md +++ b/docs/guide/queues.md @@ -84,6 +84,33 @@ high = task.apply_async(args=(3,), priority=10) # Processing order: high (10), mid (5), low (1) ``` +## Queue-Level Limits + +Apply a rate limit or concurrency cap to an entire queue, independently of per-task settings. These limits are checked in the scheduler before any per-task limits. 
+ +### Rate limiting a queue + +```python +queue.set_queue_rate_limit("default", "100/m") # Max 100 jobs per minute +queue.set_queue_rate_limit("emails", "20/s") # Max 20 emails per second +``` + +The format is the same as `rate_limit` on `@queue.task()`: `"N/s"`, `"N/m"`, or `"N/h"`. + +### Capping concurrency per queue + +```python +queue.set_queue_concurrency("default", 10) # Max 10 jobs running at once +queue.set_queue_concurrency("reports", 2) # Heavy tasks: max 2 at a time +``` + +`set_queue_concurrency` limits how many jobs from that queue run simultaneously across all workers. + +!!! tip "Queue limits vs task limits" + Queue-level limits apply to all tasks in the queue regardless of their individual settings. Per-task `rate_limit` and `max_concurrent` are checked afterwards and may impose stricter caps. Set queue limits to protect shared downstream resources (APIs, databases) and per-task limits to manage individual task capacity. + +Both methods can be called at any point before or after `run_worker()` starts. + ## Default Queue Settings Configure defaults at the Queue level: diff --git a/docs/guide/tasks.md b/docs/guide/tasks.md index d0c790d..371ef46 100644 --- a/docs/guide/tasks.md +++ b/docs/guide/tasks.md @@ -23,6 +23,7 @@ def process_data(data: dict) -> str: | `max_retries` | `int` | `3` | Max retry attempts before moving to DLQ. | | `retry_backoff` | `float` | `1.0` | Base delay in seconds for exponential backoff. | | `retry_delays` | `list[float] \| None` | `None` | Per-attempt delays in seconds, overrides backoff. e.g. `[1, 5, 30]`. | +| `max_retry_delay` | `int \| None` | `None` | Cap on backoff delay in seconds (default 300 s). | | `timeout` | `int` | `300` | Max execution time in seconds (hard timeout). | | `soft_timeout` | `float \| None` | `None` | Cooperative time limit in seconds; checked via `current_job.check_timeout()`. | | `priority` | `int` | `0` | Default priority (higher = more urgent). 
| @@ -32,16 +33,20 @@ def process_data(data: dict) -> str: | `middleware` | `list[TaskMiddleware] \| None` | `None` | Per-task middleware, applied in addition to queue-level middleware. | | `expires` | `float \| None` | `None` | Seconds until the job expires if not started. | | `inject` | `list[str] \| None` | `None` | Worker resource names to inject as keyword arguments. See [Resource System](resources.md). | +| `serializer` | `Serializer \| None` | `None` | Per-task serializer override. Falls back to the queue-level serializer. | +| `max_concurrent` | `int \| None` | `None` | Max concurrent running instances of this task. `None` means no limit. | ```python @queue.task( name="emails.send", max_retries=5, retry_backoff=2.0, + max_retry_delay=60, # cap backoff at 60 s timeout=60, priority=10, rate_limit="100/m", queue="emails", + max_concurrent=10, ) def send_email(to: str, subject: str, body: str): ... @@ -107,6 +112,44 @@ def time_sensitive(): ... ``` +### Max Retry Delay + +Cap the exponential backoff so waits don't grow unbounded: + +```python +@queue.task(retry_backoff=2.0, max_retries=10, max_retry_delay=120) +def flaky_service(): + ... +# Delays: 2, 4, 8, 16, 32, 64, 120, 120, 120, 120 s (one per retry, capped at 2 min) +``` + +### Per-Task Concurrency Limit + +Prevent a single task type from consuming all workers: + +```python +@queue.task(max_concurrent=3) +def expensive_render(): + ... +# At most 3 instances of expensive_render run simultaneously across all workers. +``` + +### Per-Task Serializer + +Override the queue-level serializer for a specific task: + +```python +from taskito.serializers import JSONSerializer + +@queue.task(serializer=JSONSerializer()) +def api_event(payload: dict) -> dict: + ... +``` + +The per-task serializer is used for the full round-trip: arguments are serialized with it at enqueue time and deserialized with it on the worker before the task function is called.
Both the sync worker and the native async worker honour the per-task serializer, falling back to the queue-level serializer for tasks that have none registered. + +Useful when a task needs a different format (e.g., human-readable JSON for audit tasks) or when the payload is not picklable. + ## Task Naming By default, tasks are named using `module.qualname`: diff --git a/docs/integrations/django.md b/docs/integrations/django.md index ffed512..bf51075 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -32,9 +32,35 @@ from taskito.contrib.django.admin import TaskitoAdminSite admin_site = TaskitoAdminSite(name="taskito_admin") ``` -## Configuration +## Django Settings -Create a `taskito` queue configuration in your Django settings or a dedicated module. The `get_queue()` function in `taskito.contrib.django.settings` is used to retrieve the queue instance. +The following settings can be defined in your Django `settings.py`: + +| Setting | Default | Description | +|---------|---------|-------------| +| `TASKITO_AUTODISCOVER_MODULE` | `"tasks"` | Module name auto-discovered in each installed app on startup. | +| `TASKITO_ADMIN_PER_PAGE` | `50` | Rows per page in the admin jobs and dead letters views. | +| `TASKITO_ADMIN_TITLE` | `"Taskito"` | Browser tab title for `TaskitoAdminSite`. | +| `TASKITO_ADMIN_HEADER` | `"Taskito Admin"` | Site header shown in `TaskitoAdminSite`. | +| `TASKITO_WATCH_INTERVAL` | `2` | Polling interval in seconds for `manage.py taskito_info --watch`. | +| `TASKITO_DASHBOARD_HOST` | `"127.0.0.1"` | Default bind host for `manage.py taskito_dashboard`. | +| `TASKITO_DASHBOARD_PORT` | `8080` | Default bind port for `manage.py taskito_dashboard`. 
| + +Example: + +```python +# settings.py +TASKITO_AUTODISCOVER_MODULE = "jobs" # import myapp.jobs instead of myapp.tasks +TASKITO_ADMIN_PER_PAGE = 25 +TASKITO_ADMIN_TITLE = "MyApp Tasks" +TASKITO_ADMIN_HEADER = "MyApp Task Queue" +TASKITO_DASHBOARD_HOST = "0.0.0.0" +TASKITO_DASHBOARD_PORT = 9000 +``` + +## Queue Configuration + +Create a `taskito` queue instance in your Django project. The `get_queue()` function in `taskito.contrib.django.settings` is used to retrieve the queue instance. ```python # myproject/tasks.py diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 00308db..bffa52e 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -38,6 +38,7 @@ uvicorn myapp:app --reload | Method | Path | Description | |--------|------|-------------| | `GET` | `/stats` | Queue statistics | +| `GET` | `/stats/queues` | Per-queue statistics | | `GET` | `/jobs/{job_id}` | Job status, progress, and metadata | | `GET` | `/jobs/{job_id}/errors` | Error history for a job | | `GET` | `/jobs/{job_id}/result` | Job result (optional `?timeout=N` for blocking) | @@ -45,5 +46,48 @@ uvicorn myapp:app --reload | `POST` | `/jobs/{job_id}/cancel` | Cancel a pending job | | `GET` | `/dead-letters` | List dead letter entries (paginated) | | `POST` | `/dead-letters/{dead_id}/retry` | Re-enqueue a dead letter | +| `GET` | `/health` | Liveness check | +| `GET` | `/readiness` | Readiness check | +| `GET` | `/resources` | Worker resource status | + +## Configuration + +`TaskitoRouter` accepts options to control which routes are registered, how results are serialized, and page sizes: + +```python +from fastapi import Depends, Header, HTTPException +from taskito.contrib.fastapi import TaskitoRouter + +def require_api_key(x_api_key: str = Header(...)): +    if x_api_key != "secret": +        raise HTTPException(status_code=403) + +app.include_router( +    TaskitoRouter( +        queue, +        include_routes={"stats", "jobs", "dead-letters", "retry-dead"}, + 
dependencies=[Depends(require_api_key)], + sse_poll_interval=1.0, + result_timeout=5.0, + default_page_size=25, + max_page_size=200, + result_serializer=lambda v: v if isinstance(v, (str, int, float, bool, type(None))) else str(v), + ), + prefix="/tasks", +) +``` + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `include_routes` | `set[str] \| None` | `None` | If set, only register these route names. Cannot be combined with `exclude_routes`. | +| `exclude_routes` | `set[str] \| None` | `None` | If set, skip these route names. Cannot be combined with `include_routes`. | +| `dependencies` | `Sequence[Depends] \| None` | `None` | FastAPI dependencies applied to every route (e.g. auth). | +| `sse_poll_interval` | `float` | `0.5` | Seconds between SSE progress polls. | +| `result_timeout` | `float` | `1.0` | Default timeout for non-blocking result fetch. | +| `default_page_size` | `int` | `20` | Default page size for paginated endpoints. | +| `max_page_size` | `int` | `100` | Maximum allowed page size. | +| `result_serializer` | `Callable[[Any], Any] \| None` | `None` | Custom result serializer. Receives any value, must return a JSON-serializable value. | + +Valid route names: `stats`, `jobs`, `job-errors`, `job-result`, `job-progress`, `cancel`, `dead-letters`, `retry-dead`, `health`, `readiness`, `resources`, `queue-stats`. For full details on SSE streaming, blocking result fetch, Pydantic response models, and authentication, see the [Advanced guide](../guide/advanced.md#fastapi-integration). 
diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 3690699..4454ae9 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -55,9 +55,18 @@ All configuration is read from `app.config`: | `TASKITO_RESULT_TTL` | `None` | Auto-purge completed jobs after N seconds | | `TASKITO_DRAIN_TIMEOUT` | `30` | Seconds to wait for running tasks on shutdown | +## Extension Options + +The `Taskito` constructor accepts a `cli_group` parameter to rename the CLI command group: + +```python +# Commands will be under `flask tasks worker`, `flask tasks info`, etc. +taskito = Taskito(app, cli_group="tasks") +``` + ## CLI Commands -The extension registers commands under the `flask taskito` group: +The extension registers commands under the `flask taskito` group (configurable via `cli_group`): ### `flask taskito worker` @@ -70,10 +79,11 @@ flask taskito worker --queues default,emails ### `flask taskito info` -Show queue statistics: +Show queue statistics. Supports `--format table` (default) and `--format json`: ```bash flask taskito info +flask taskito info --format json ``` ``` diff --git a/docs/integrations/otel.md b/docs/integrations/otel.md index ee5a779..3a56ec0 100644 --- a/docs/integrations/otel.md +++ b/docs/integrations/otel.md @@ -27,7 +27,7 @@ queue = Queue(middleware=[OpenTelemetryMiddleware()]) Each task execution produces a span with: -- **Span name**: `taskito.execute.` +- **Span name**: `taskito.execute.` (customizable) - **Attributes**: - `taskito.job_id` — the job ID - `taskito.task_name` — the registered task name @@ -58,14 +58,28 @@ from taskito.contrib.otel import OpenTelemetryMiddleware queue = Queue(middleware=[OpenTelemetryMiddleware()]) ``` -### Custom Tracer Name +## Configuration -By default, spans are created under the `"taskito"` tracer. 
Override with: +`OpenTelemetryMiddleware` accepts several options to customize how spans are created: ```python -OpenTelemetryMiddleware(tracer_name="my-service") +OpenTelemetryMiddleware( + tracer_name="my-service", + span_name_fn=lambda ctx: f"task/{ctx.task_name}", + attribute_prefix="myapp", + extra_attributes_fn=lambda ctx: {"deployment.env": "prod"}, + task_filter=lambda name: not name.startswith("internal."), +) ``` +| Parameter | Type | Default | Description | +|---|---|---|---| +| `tracer_name` | `str` | `"taskito"` | OpenTelemetry tracer name. | +| `span_name_fn` | `Callable[[JobContext], str] \| None` | `None` | Custom span name builder. Receives `JobContext`, returns a string. Defaults to `.execute.`. | +| `attribute_prefix` | `str` | `"taskito"` | Prefix for all span attribute keys. | +| `extra_attributes_fn` | `Callable[[JobContext], dict] \| None` | `None` | Returns extra attributes to add to each span. Receives `JobContext`. | +| `task_filter` | `Callable[[str], bool] \| None` | `None` | Predicate that receives a task name. Return `True` to trace, `False` to skip. `None` traces all tasks. | + ## Combining with Other Middleware `OpenTelemetryMiddleware` is a standard `TaskMiddleware`, so it composes with other middleware: diff --git a/docs/integrations/prometheus.md b/docs/integrations/prometheus.md index fb7d17d..860e7b8 100644 --- a/docs/integrations/prometheus.md +++ b/docs/integrations/prometheus.md @@ -21,6 +21,22 @@ from taskito.contrib.prometheus import PrometheusMiddleware queue = Queue(db_path="myapp.db", middleware=[PrometheusMiddleware()]) ``` +### Configuration + +```python +PrometheusMiddleware( + namespace="myapp", + extra_labels_fn=lambda ctx: {"env": "prod", "region": "us-east-1"}, + disabled_metrics={"resource", "proxy"}, +) +``` + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `namespace` | `str` | `"taskito"` | Prefix for all metric names. 
| +| `extra_labels_fn` | `Callable[[JobContext], dict[str, str]] \| None` | `None` | Returns extra labels to add to job metrics. Receives `JobContext`. | +| `disabled_metrics` | `set[str] \| None` | `None` | Metric groups or individual names to skip. Groups: `"jobs"`, `"queue"`, `"resource"`, `"proxy"`, `"intercept"`. | + ### Metrics Tracked | Metric | Type | Labels | Description | @@ -41,6 +57,24 @@ collector = PrometheusStatsCollector(queue, interval=10) collector.start() ``` +### Configuration + +```python +PrometheusStatsCollector( + queue, + interval=10, + namespace="myapp", + disabled_metrics={"intercept"}, +) +``` + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `queue` | `Queue` | — | The Queue instance to poll. | +| `interval` | `float` | `10.0` | Seconds between polls. | +| `namespace` | `str` | `"taskito"` | Prefix for metric names. Must match `PrometheusMiddleware` namespace to share metric objects. | +| `disabled_metrics` | `set[str] \| None` | `None` | Metric groups or names to skip. Same groups as `PrometheusMiddleware`. | + ### Metrics Tracked | Metric | Type | Labels | Description | diff --git a/docs/integrations/sentry.md b/docs/integrations/sentry.md index 1b34022..fdcc8c7 100644 --- a/docs/integrations/sentry.md +++ b/docs/integrations/sentry.md @@ -28,7 +28,7 @@ queue = Queue(db_path="myapp.db", middleware=[SentryMiddleware()]) ### Scope Tags -Each task execution gets a Sentry scope with the following tags: +Each task execution gets a Sentry scope with the following tags (prefix customizable via `tag_prefix`): | Tag | Value | |-----|-------| @@ -39,7 +39,7 @@ Each task execution gets a Sentry scope with the following tags: ### Transaction Name -The Sentry transaction is set to `taskito:`, making it easy to filter and group task performance data in the Sentry dashboard. +The Sentry transaction is set to `taskito:` by default. Customizable via `transaction_name_fn`. 
### Automatic Error Capture @@ -49,12 +49,30 @@ When a task raises an exception, `SentryMiddleware` calls `sentry_sdk.capture_ex When a task is retried, a breadcrumb is added with: -- **Category**: `taskito` +- **Category**: `taskito` (matches `tag_prefix`) - **Level**: `warning` - **Message**: `Retrying (attempt ): ` This gives you a trail of retry attempts leading up to a final failure. +## Configuration + +```python +SentryMiddleware( + tag_prefix="myapp", + transaction_name_fn=lambda ctx: f"task-{ctx.task_name}", + task_filter=lambda name: not name.startswith("internal."), + extra_tags_fn=lambda ctx: {"worker.host": socket.gethostname()}, +) +``` + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `tag_prefix` | `str` | `"taskito"` | Prefix for Sentry tag keys and breadcrumb category. | +| `transaction_name_fn` | `Callable[[JobContext], str] \| None` | `None` | Custom transaction name builder. Receives `JobContext`. Defaults to `:`. | +| `task_filter` | `Callable[[str], bool] \| None` | `None` | Predicate on task name. Return `True` to report, `False` to skip. `None` reports all tasks. | +| `extra_tags_fn` | `Callable[[JobContext], dict[str, str]] \| None` | `None` | Returns extra Sentry tags to set. Receives `JobContext`. 
| + ## Combining with Other Middleware `SentryMiddleware` composes with other observability middleware: diff --git a/py_src/taskito/__init__.py b/py_src/taskito/__init__.py index be173d5..28ca42b 100644 --- a/py_src/taskito/__init__.py +++ b/py_src/taskito/__init__.py @@ -88,4 +88,4 @@ __version__ = _get_version("taskito") except PackageNotFoundError: - __version__ = "0.3.0" + __version__ = "0.6.0" diff --git a/py_src/taskito/_taskito.pyi b/py_src/taskito/_taskito.pyi index ac8e0dd..3b9bd1f 100644 --- a/py_src/taskito/_taskito.pyi +++ b/py_src/taskito/_taskito.pyi @@ -19,6 +19,10 @@ class PyTaskConfig: circuit_breaker_window: int | None circuit_breaker_cooldown: int | None + retry_delays: list[float] | None + max_retry_delay: int | None + max_concurrent: int | None + def __init__( self, name: str, @@ -32,6 +36,8 @@ class PyTaskConfig: circuit_breaker_window: int | None = None, circuit_breaker_cooldown: int | None = None, retry_delays: list[float] | None = None, + max_retry_delay: int | None = None, + max_concurrent: int | None = None, ) -> None: ... class PyJob: @@ -74,6 +80,9 @@ class PyQueue: db_url: str | None = None, schema: str = "taskito", pool_size: int | None = None, + scheduler_poll_interval_ms: int = 50, + scheduler_reap_interval: int = 100, + scheduler_cleanup_interval: int = 1200, ) -> None: ... def request_shutdown(self) -> None: ... def enqueue( @@ -156,6 +165,7 @@ class PyQueue: resources: str | None = None, threads: int = 1, async_concurrency: int = 100, + queue_configs: str | None = None, ) -> None: ... 
def worker_heartbeat( self, diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 3eed4ea..6b4a2ee 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -104,6 +104,10 @@ def __init__( file_path_allowlist: list[str] | None = None, disabled_proxies: list[str] | None = None, async_concurrency: int = 100, + event_workers: int = 4, + scheduler_poll_interval_ms: int = 50, + scheduler_reap_interval: int = 100, + scheduler_cleanup_interval: int = 1200, ): """Initialize a new task queue. @@ -142,6 +146,13 @@ def __init__( proxy handlers. async_concurrency: Maximum number of async tasks running concurrently on the native async executor. Defaults to 100. + event_workers: Thread pool size for the event bus (default 4). + scheduler_poll_interval_ms: Milliseconds between scheduler poll + cycles (default 50). + scheduler_reap_interval: Reap stale jobs every N poll iterations + (default 100). + scheduler_cleanup_interval: Cleanup old jobs every N poll iterations + (default 1200). 
""" if backend == "sqlite": # Ensure parent directory exists for SQLite @@ -160,6 +171,9 @@ def __init__( db_url=db_url, schema=schema, pool_size=pool_size, + scheduler_poll_interval_ms=scheduler_poll_interval_ms, + scheduler_reap_interval=scheduler_reap_interval, + scheduler_cleanup_interval=scheduler_cleanup_interval, ) self._backend = backend self._db_url = db_url @@ -177,11 +191,13 @@ def __init__( "on_failure": [], } self._serializer: Serializer = serializer or CloudpickleSerializer() + self._task_serializers: dict[str, Serializer] = {} self._global_middleware: list[TaskMiddleware] = middleware or [] self._task_middleware: dict[str, list[TaskMiddleware]] = {} self._task_retry_filters: dict[str, dict[str, list[type[Exception]]]] = {} self._drain_timeout = drain_timeout - self._event_bus = EventBus() + self._queue_configs: dict[str, dict[str, Any]] = {} + self._event_bus = EventBus(max_workers=event_workers) self._webhook_manager = WebhookManager() # Proxy handlers @@ -239,6 +255,9 @@ def task( middleware: list[TaskMiddleware] | None = None, retry_delays: list[float] | None = None, inject: list[str] | None = None, + serializer: Serializer | None = None, + max_retry_delay: int | None = None, + max_concurrent: int | None = None, ) -> Callable[[Callable[..., Any]], TaskWrapper]: """Decorator to register a function as a background task. @@ -258,6 +277,11 @@ def task( soft_timeout: Soft timeout in seconds. Checked via ``current_job.check_timeout()``. middleware: Per-task middleware instances (in addition to global middleware). inject: List of resource names to inject as keyword arguments. + serializer: Per-task serializer. Falls back to the queue-level serializer. + max_retry_delay: Maximum backoff delay in seconds. Defaults to 300 + (5 minutes) if not set. + max_concurrent: Maximum number of concurrent running instances of + this task. ``None`` means no limit. 
""" def decorator(fn: Callable) -> TaskWrapper: @@ -303,6 +327,10 @@ def decorator(fn: Callable) -> TaskWrapper: if middleware: self._task_middleware[task_name] = middleware + # Store per-task serializer + if serializer is not None: + self._task_serializers[task_name] = serializer + # Store inject map for resource injection if final_inject: self._task_inject_map[task_name] = final_inject @@ -332,6 +360,8 @@ def decorator(fn: Callable) -> TaskWrapper: circuit_breaker_window=cb_window, circuit_breaker_cooldown=cb_cooldown, retry_delays=retry_delays, + max_retry_delay=max_retry_delay, + max_concurrent=max_concurrent, ) self._task_configs.append(config) @@ -385,7 +415,7 @@ def decorator(fn: Callable) -> TaskWrapper: wrapper = self.task(name=name, queue=queue)(fn) # Store periodic config for registration at worker startup - payload = self._serializer.dumps((args, kwargs or {})) + payload = self._get_serializer(wrapper.name).dumps((args, kwargs or {})) self._periodic_configs.append( { "name": name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}", @@ -610,6 +640,33 @@ def register_type( proxy_handler=proxy_handler, ) + def set_queue_rate_limit(self, queue_name: str, rate_limit: str) -> None: + """Set a rate limit for an entire queue. + + Args: + queue_name: Queue name (e.g. ``"default"``). + rate_limit: Rate limit string, e.g. ``"100/m"``, ``"10/s"``. + """ + self._queue_configs.setdefault(queue_name, {})["rate_limit"] = rate_limit + + def set_queue_concurrency(self, queue_name: str, max_concurrent: int) -> None: + """Set a maximum number of concurrent jobs for a queue. + + Args: + queue_name: Queue name (e.g. ``"default"``). + max_concurrent: Maximum number of jobs running simultaneously + from this queue. 
+ """ + self._queue_configs.setdefault(queue_name, {})["max_concurrent"] = max_concurrent + + def _get_serializer(self, task_name: str) -> Serializer: + """Get the serializer for a task (per-task or queue-level fallback).""" + return self._task_serializers.get(task_name, self._serializer) + + def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple: + """Deserialize a job payload using the per-task or queue-level serializer.""" + return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return] + def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: """Get the combined global + per-task middleware list.""" per_task = self._task_middleware.get(task_name, []) @@ -761,9 +818,42 @@ def enqueue( """ final_args = args final_kwargs = kwargs or {} + + # Run on_enqueue middleware hook (options dict is mutable) + enqueue_options: dict[str, Any] = { + "priority": priority, + "delay": delay, + "queue": queue, + "max_retries": max_retries, + "timeout": timeout, + "unique_key": unique_key, + "metadata": metadata, + "depends_on": depends_on, + "expires": expires, + "result_ttl": result_ttl, + } + for mw in self._global_middleware: + try: + mw.on_enqueue(task_name, final_args, final_kwargs, enqueue_options) + except Exception: + logger.exception("middleware on_enqueue() error") + + # Apply any middleware mutations back + priority = enqueue_options.get("priority") + delay = enqueue_options.get("delay") + queue = enqueue_options.get("queue") + max_retries = enqueue_options.get("max_retries") + timeout = enqueue_options.get("timeout") + unique_key = enqueue_options.get("unique_key") + metadata = enqueue_options.get("metadata") + depends_on = enqueue_options.get("depends_on") + expires = enqueue_options.get("expires") + result_ttl = enqueue_options.get("result_ttl") + if self._interceptor is not None and not self._test_mode_active: final_args, final_kwargs = self._interceptor.intercept(final_args, final_kwargs) - payload = 
self._serializer.dumps((final_args, final_kwargs)) + task_serializer = self._get_serializer(task_name) + payload = task_serializer.dumps((final_args, final_kwargs)) dep_ids = None if depends_on is not None: @@ -826,11 +916,12 @@ def enqueue_many( f"args_list length ({len(args_list)})" ) kw_list = kwargs_list or [{}] * count + task_serializer = self._get_serializer(task_name) if self._interceptor is not None: pairs = [self._interceptor.intercept(a, kw) for a, kw in zip(args_list, kw_list)] - payloads = [self._serializer.dumps((a, kw)) for a, kw in pairs] + payloads = [task_serializer.dumps((a, kw)) for a, kw in pairs] else: - payloads = [self._serializer.dumps((a, kw)) for a, kw in zip(args_list, kw_list)] + payloads = [task_serializer.dumps((a, kw)) for a, kw in zip(args_list, kw_list)] task_names = [task_name] * count queues_list = [queue or "default"] * count if queue else None @@ -871,6 +962,9 @@ def add_webhook( events: list[EventType] | None = None, headers: dict[str, str] | None = None, secret: str | None = None, + max_retries: int = 3, + timeout: float = 10.0, + retry_backoff: float = 2.0, ) -> None: """Register a webhook endpoint for job events. @@ -879,8 +973,19 @@ def add_webhook( events: Event types to subscribe to (None = all). headers: Extra HTTP headers. secret: HMAC-SHA256 signing secret. + max_retries: Maximum delivery attempts (default 3). + timeout: HTTP request timeout in seconds (default 10.0). + retry_backoff: Base for exponential backoff between retries (default 2.0). 
""" - self._webhook_manager.add_webhook(url, events, headers, secret) + self._webhook_manager.add_webhook( + url, + events, + headers, + secret, + max_retries=max_retries, + timeout=timeout, + retry_backoff=retry_backoff, + ) # -- Worker startup -- @@ -1043,7 +1148,13 @@ def sighup_handler(signum: int, frame: Any) -> None: ) heartbeat_thread.start() + self._emit_event( + EventType.WORKER_STARTED, + {"worker_id": worker_id, "queues": worker_queues}, + ) + try: + queue_configs_json = json.dumps(self._queue_configs) if self._queue_configs else None self._inner.run_worker( task_registry=self._task_registry, task_configs=self._task_configs, @@ -1054,10 +1165,15 @@ def sighup_handler(signum: int, frame: Any) -> None: resources=resources_json, threads=self._workers, async_concurrency=self._async_concurrency, + queue_configs=queue_configs_json, ) except KeyboardInterrupt: logger.info("Cold shutdown (terminating immediately)") finally: + self._emit_event( + EventType.WORKER_STOPPED, + {"worker_id": worker_id}, + ) stop_heartbeat.set() heartbeat_thread.join(timeout=6) # Tear down resources before stopping async loop diff --git a/py_src/taskito/contrib/django/admin.py b/py_src/taskito/contrib/django/admin.py index ce0f8b3..6d0bc0b 100644 --- a/py_src/taskito/contrib/django/admin.py +++ b/py_src/taskito/contrib/django/admin.py @@ -40,7 +40,9 @@ def _jobs_view(request: HttpRequest, site: Any) -> HttpResponse: except (ValueError, TypeError): page = 1 page = max(page, 1) - per_page = 50 + from django.conf import settings as django_settings + + per_page = getattr(django_settings, "TASKITO_ADMIN_PER_PAGE", 50) try: jobs = queue.list_jobs( @@ -96,7 +98,9 @@ def _dead_letters_view(request: HttpRequest, site: Any) -> HttpResponse: except (ValueError, TypeError): page = 1 page = max(page, 1) - per_page = 50 + from django.conf import settings as django_settings + + per_page = getattr(django_settings, "TASKITO_ADMIN_PER_PAGE", 50) dead = queue.dead_letters(limit=per_page, offset=(page - 
1) * per_page) context = { **site.each_context(request), @@ -107,11 +111,26 @@ def _dead_letters_view(request: HttpRequest, site: Any) -> HttpResponse: return TemplateResponse(request, "taskito/admin/dead_letters.html", context) +def _get_admin_setting(name: str, default: str) -> str: + from django.conf import settings as django_settings + + return str(getattr(django_settings, name, default)) + + class TaskitoAdminSite(admin.AdminSite): - """Custom admin site with taskito queue views.""" + """Custom admin site with taskito queue views. + + Reads ``TASKITO_ADMIN_TITLE`` and ``TASKITO_ADMIN_HEADER`` from Django + settings to customize the admin site branding. + """ + + @property + def site_header(self) -> str: + return _get_admin_setting("TASKITO_ADMIN_HEADER", "Taskito Admin") - site_header = "Taskito Admin" - site_title = "Taskito" + @property + def site_title(self) -> str: + return _get_admin_setting("TASKITO_ADMIN_TITLE", "Taskito") def get_urls(self) -> list: urls = super().get_urls() diff --git a/py_src/taskito/contrib/django/apps.py b/py_src/taskito/contrib/django/apps.py index 0de7b16..7b71178 100644 --- a/py_src/taskito/contrib/django/apps.py +++ b/py_src/taskito/contrib/django/apps.py @@ -18,7 +18,13 @@ class TaskitoConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" def ready(self) -> None: - """Auto-discover ``tasks.py`` modules in all installed apps.""" + """Auto-discover task modules in all installed apps. + + The module name defaults to ``"tasks"`` but can be overridden via the + ``TASKITO_AUTODISCOVER_MODULE`` Django setting. 
+ """ + from django.conf import settings from django.utils.module_loading import autodiscover_modules - autodiscover_modules("tasks") + module_name = getattr(settings, "TASKITO_AUTODISCOVER_MODULE", "tasks") + autodiscover_modules(module_name) diff --git a/py_src/taskito/contrib/django/management/commands/taskito_dashboard.py b/py_src/taskito/contrib/django/management/commands/taskito_dashboard.py index ab70da2..6117f14 100644 --- a/py_src/taskito/contrib/django/management/commands/taskito_dashboard.py +++ b/py_src/taskito/contrib/django/management/commands/taskito_dashboard.py @@ -14,10 +14,22 @@ class Command(BaseCommand): help = "Start the taskito web dashboard" def add_arguments(self, parser): # type: ignore[no-untyped-def] + from django.conf import settings + + default_host = getattr(settings, "TASKITO_DASHBOARD_HOST", "127.0.0.1") + default_port = getattr(settings, "TASKITO_DASHBOARD_PORT", 8080) + + parser.add_argument( + "--host", + default=default_host, + help=f"Bind address (default: {default_host})", + ) parser.add_argument( - "--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1)" + "--port", + type=int, + default=default_port, + help=f"Bind port (default: {default_port})", ) - parser.add_argument("--port", type=int, default=8080, help="Bind port (default: 8080)") def handle(self, **options): # type: ignore[no-untyped-def] from taskito.contrib.django.settings import get_queue diff --git a/py_src/taskito/contrib/django/management/commands/taskito_info.py b/py_src/taskito/contrib/django/management/commands/taskito_info.py index 4f28f19..d3dfce9 100644 --- a/py_src/taskito/contrib/django/management/commands/taskito_info.py +++ b/py_src/taskito/contrib/django/management/commands/taskito_info.py @@ -44,19 +44,23 @@ def _print(self, queue): # type: ignore[no-untyped-def] def _watch(self, queue): # type: ignore[no-untyped-def] import time + from django.conf import settings + + interval = getattr(settings, "TASKITO_WATCH_INTERVAL", 2) + 
prev_completed = 0 try: while True: self.stdout.write("\033[2J\033[H", ending="") stats = queue.stats() completed = stats.get("completed", 0) - throughput = (completed - prev_completed) / 2.0 + throughput = (completed - prev_completed) / float(interval) prev_completed = completed self._print(queue) if throughput > 0: self.stdout.write(f"\n throughput {throughput:.1f} jobs/s") - self.stdout.write("\nRefreshing every 2s... (Ctrl+C to stop)") - time.sleep(2) + self.stdout.write(f"\nRefreshing every {interval}s... (Ctrl+C to stop)") + time.sleep(interval) except KeyboardInterrupt: pass diff --git a/py_src/taskito/contrib/fastapi.py b/py_src/taskito/contrib/fastapi.py index 0d0c35b..5d8a147 100644 --- a/py_src/taskito/contrib/fastapi.py +++ b/py_src/taskito/contrib/fastapi.py @@ -23,8 +23,8 @@ import asyncio import json import logging -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any +from collections.abc import AsyncGenerator, Sequence +from typing import TYPE_CHECKING, Any, Callable logger = logging.getLogger(__name__) @@ -134,6 +134,24 @@ class ReadinessResponse(BaseModel): checks: dict[str, Any] +# ── All known route names ──────────────────────────────── + +_ALL_ROUTES: set[str] = { + "stats", + "jobs", + "job-errors", + "job-result", + "job-progress", + "cancel", + "dead-letters", + "retry-dead", + "health", + "readiness", + "resources", + "queue-stats", +} + + # ── Router factory ─────────────────────────────────────── @@ -142,8 +160,19 @@ class TaskitoRouter(APIRouter): Args: queue: The taskito Queue instance to expose. + include_routes: If set, only register these route names. Cannot be + used together with ``exclude_routes``. + exclude_routes: If set, skip these route names when registering. + dependencies: FastAPI dependency list applied to every route. + sse_poll_interval: Seconds between SSE progress polls (default 0.5). + result_timeout: Default timeout for blocking result fetch (default 1.0). 
+ default_page_size: Default page size for paginated endpoints (default 20). + max_page_size: Maximum allowed page size (default 100). + result_serializer: Custom result serializer. Receives any value and + must return a JSON-serializable value. Falls back to + :func:`_safe_serialize`. **kwargs: Passed to ``APIRouter.__init__()`` (e.g. ``prefix``, - ``tags``, ``dependencies``). + ``tags``). Example:: @@ -153,160 +182,234 @@ class TaskitoRouter(APIRouter): ) """ - def __init__(self, queue: Queue, **kwargs: Any) -> None: + def __init__( + self, + queue: Queue, + *, + include_routes: set[str] | None = None, + exclude_routes: set[str] | None = None, + dependencies: Sequence[Any] | None = None, + sse_poll_interval: float = 0.5, + result_timeout: float = 1.0, + default_page_size: int = 20, + max_page_size: int = 100, + result_serializer: Callable[[Any], Any] | None = None, + **kwargs: Any, + ) -> None: + if include_routes is not None and exclude_routes is not None: + raise ValueError("Cannot specify both include_routes and exclude_routes") + kwargs.setdefault("tags", ["taskito"]) + if dependencies is not None: + kwargs.setdefault("dependencies", list(dependencies)) super().__init__(**kwargs) self._queue = queue + self._sse_poll_interval = sse_poll_interval + self._result_timeout = result_timeout + self._default_page_size = default_page_size + self._max_page_size = max_page_size + self._result_serializer = result_serializer or _safe_serialize + + # Compute active route set + if include_routes is not None: + self._active_routes = include_routes & _ALL_ROUTES + elif exclude_routes is not None: + self._active_routes = _ALL_ROUTES - exclude_routes + else: + self._active_routes = _ALL_ROUTES + self._register_routes() + def _should_register(self, name: str) -> bool: + return name in self._active_routes + def _register_routes(self) -> None: queue = self._queue - - @self.get("/stats", response_model=StatsResponse) - async def get_stats() -> StatsResponse: - """Get queue 
statistics.""" - stats = await queue.astats() - return StatsResponse(**stats) - - @self.get("/jobs/{job_id}", response_model=JobResponse) - def get_job(job_id: str) -> JobResponse: - """Get a job by ID.""" - job = queue.get_job(job_id) - if job is None: - raise HTTPException(status_code=404, detail="Job not found") - return JobResponse(**job.to_dict()) - - @self.get("/jobs/{job_id}/errors", response_model=list[JobErrorResponse]) - def get_job_errors(job_id: str) -> list[JobErrorResponse]: - """Get error history for a job.""" - errors = queue.job_errors(job_id) - return [JobErrorResponse(**e) for e in errors] - - @self.get("/jobs/{job_id}/result", response_model=JobResultResponse) - async def get_job_result( - job_id: str, - timeout: float = Query(default=0, ge=0, le=300), - ) -> JobResultResponse: - """Get job result. Set timeout > 0 for blocking wait.""" - job = queue.get_job(job_id) - if job is None: - raise HTTPException(status_code=404, detail="Job not found") - - if timeout > 0 and job.status not in ("complete", "failed", "dead", "cancelled"): - try: - result = await job.aresult(timeout=timeout) - return JobResultResponse( - id=job_id, - status="complete", - result=_safe_serialize(result), - ) - except TimeoutError: - job.refresh() - return JobResultResponse( - id=job_id, - status=job.status, - ) - except RuntimeError as e: - job.refresh() - return JobResultResponse( - id=job_id, - status=job.status, - error=str(e), - ) - - d = job.to_dict() - result = None - if d["status"] == "complete": - try: - result = _safe_serialize(job.result(timeout=1)) - except Exception: - logger.exception("Failed to deserialize result for job %s", job_id) - - return JobResultResponse( - id=job_id, - status=d["status"], - result=result, - error=d.get("error"), - ) - - @self.get("/jobs/{job_id}/progress") - async def stream_progress(job_id: str) -> StreamingResponse: - """SSE stream of progress updates until job reaches terminal state.""" - job = queue.get_job(job_id) - if job is 
None: - raise HTTPException(status_code=404, detail="Job not found") - - async def event_stream() -> AsyncGenerator[str, None]: - terminal = {"complete", "failed", "dead", "cancelled"} - while True: - refreshed = queue.get_job(job_id) - if refreshed is None: - yield f"data: {json.dumps({'status': 'not_found'})}\n\n" - return - - d = refreshed.to_dict() - payload = json.dumps({"status": d["status"], "progress": d["progress"]}) - yield f"data: {payload}\n\n" - - if d["status"] in terminal: - return - - await asyncio.sleep(0.5) - - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, - ) - - @self.post("/jobs/{job_id}/cancel", response_model=CancelResponse) - async def cancel_job(job_id: str) -> CancelResponse: - """Cancel a pending job.""" - ok = await queue.acancel_job(job_id) - return CancelResponse(cancelled=ok) - - @self.get("/dead-letters", response_model=list[DeadLetterResponse]) - async def list_dead_letters( - limit: int = Query(default=20, ge=1, le=100), - offset: int = Query(default=0, ge=0), - ) -> list[DeadLetterResponse]: - """List dead letter queue entries.""" - dead = await queue.adead_letters(limit=limit, offset=offset) - return [DeadLetterResponse(**d) for d in dead] - - @self.post("/dead-letters/{dead_id}/retry", response_model=RetryResponse) - async def retry_dead_letter(dead_id: str) -> RetryResponse: - """Re-enqueue a dead letter job.""" - new_id = await queue.aretry_dead(dead_id) - return RetryResponse(new_job_id=new_id) - - @self.get("/health", response_model=HealthResponse) - async def health() -> HealthResponse: - """Liveness check.""" - from taskito.health import check_health - - return HealthResponse(**check_health()) - - @self.get("/readiness", response_model=ReadinessResponse) - async def readiness() -> ReadinessResponse: - """Readiness check.""" - from taskito.health import check_readiness - - return ReadinessResponse(**check_readiness(queue)) - 
- @self.get("/resources") - async def get_resources() -> list[dict[str, Any]]: - """Get resource status for all registered worker resources.""" - return await queue.aresource_status() - - @self.get("/stats/queues") - async def get_queue_stats( - queue_name: str | None = Query(default=None, alias="queue"), - ) -> dict[str, Any]: - """Get per-queue stats. If queue is specified, returns stats for that queue only.""" - if queue_name: - return await queue.astats_by_queue(queue_name) - return await queue.astats_all_queues() + serialize_result = self._result_serializer + result_timeout = self._result_timeout + sse_interval = self._sse_poll_interval + default_page = self._default_page_size + max_page = self._max_page_size + + if self._should_register("stats"): + + @self.get("/stats", response_model=StatsResponse) + async def get_stats() -> StatsResponse: + """Get queue statistics.""" + stats = await queue.astats() + return StatsResponse(**stats) + + if self._should_register("jobs"): + + @self.get("/jobs/{job_id}", response_model=JobResponse) + def get_job(job_id: str) -> JobResponse: + """Get a job by ID.""" + job = queue.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + return JobResponse(**job.to_dict()) + + if self._should_register("job-errors"): + + @self.get("/jobs/{job_id}/errors", response_model=list[JobErrorResponse]) + def get_job_errors(job_id: str) -> list[JobErrorResponse]: + """Get error history for a job.""" + errors = queue.job_errors(job_id) + return [JobErrorResponse(**e) for e in errors] + + if self._should_register("job-result"): + + @self.get("/jobs/{job_id}/result", response_model=JobResultResponse) + async def get_job_result( + job_id: str, + timeout: float = Query(default=0, ge=0, le=300), + ) -> JobResultResponse: + """Get job result. 
Set timeout > 0 for blocking wait.""" + job = queue.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + + if timeout > 0 and job.status not in ( + "complete", + "failed", + "dead", + "cancelled", + ): + try: + result = await job.aresult(timeout=timeout) + return JobResultResponse( + id=job_id, + status="complete", + result=serialize_result(result), + ) + except TimeoutError: + job.refresh() + return JobResultResponse( + id=job_id, + status=job.status, + ) + except RuntimeError as e: + job.refresh() + return JobResultResponse( + id=job_id, + status=job.status, + error=str(e), + ) + + d = job.to_dict() + result = None + if d["status"] == "complete": + try: + result = serialize_result(job.result(timeout=result_timeout)) + except Exception: + logger.exception("Failed to deserialize result for job %s", job_id) + + return JobResultResponse( + id=job_id, + status=d["status"], + result=result, + error=d.get("error"), + ) + + if self._should_register("job-progress"): + + @self.get("/jobs/{job_id}/progress") + async def stream_progress(job_id: str) -> StreamingResponse: + """SSE stream of progress updates until job reaches terminal state.""" + job = queue.get_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + + poll_interval = sse_interval + + async def event_stream() -> AsyncGenerator[str, None]: + terminal = {"complete", "failed", "dead", "cancelled"} + while True: + refreshed = queue.get_job(job_id) + if refreshed is None: + yield f"data: {json.dumps({'status': 'not_found'})}\n\n" + return + + d = refreshed.to_dict() + payload = json.dumps({"status": d["status"], "progress": d["progress"]}) + yield f"data: {payload}\n\n" + + if d["status"] in terminal: + return + + await asyncio.sleep(poll_interval) + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + if 
self._should_register("cancel"): + + @self.post("/jobs/{job_id}/cancel", response_model=CancelResponse) + async def cancel_job(job_id: str) -> CancelResponse: + """Cancel a pending job.""" + ok = await queue.acancel_job(job_id) + return CancelResponse(cancelled=ok) + + if self._should_register("dead-letters"): + + @self.get("/dead-letters", response_model=list[DeadLetterResponse]) + async def list_dead_letters( + limit: int = Query(default=default_page, ge=1, le=max_page), + offset: int = Query(default=0, ge=0), + ) -> list[DeadLetterResponse]: + """List dead letter queue entries.""" + dead = await queue.adead_letters(limit=limit, offset=offset) + return [DeadLetterResponse(**d) for d in dead] + + if self._should_register("retry-dead"): + + @self.post("/dead-letters/{dead_id}/retry", response_model=RetryResponse) + async def retry_dead_letter(dead_id: str) -> RetryResponse: + """Re-enqueue a dead letter job.""" + new_id = await queue.aretry_dead(dead_id) + return RetryResponse(new_job_id=new_id) + + if self._should_register("health"): + + @self.get("/health", response_model=HealthResponse) + async def health() -> HealthResponse: + """Liveness check.""" + from taskito.health import check_health + + return HealthResponse(**check_health()) + + if self._should_register("readiness"): + + @self.get("/readiness", response_model=ReadinessResponse) + async def readiness() -> ReadinessResponse: + """Readiness check.""" + from taskito.health import check_readiness + + return ReadinessResponse(**check_readiness(queue)) + + if self._should_register("resources"): + + @self.get("/resources") + async def get_resources() -> list[dict[str, Any]]: + """Get resource status for all registered worker resources.""" + return await queue.aresource_status() + + if self._should_register("queue-stats"): + + @self.get("/stats/queues") + async def get_queue_stats( + queue_name: str | None = Query(default=None, alias="queue"), + ) -> dict[str, Any]: + """Get per-queue stats. 
Filter by queue name if provided.""" + if queue_name: + return await queue.astats_by_queue(queue_name) + return await queue.astats_all_queues() def _safe_serialize(value: Any) -> Any: diff --git a/py_src/taskito/contrib/flask.py b/py_src/taskito/contrib/flask.py index 333a44d..f7ec535 100644 --- a/py_src/taskito/contrib/flask.py +++ b/py_src/taskito/contrib/flask.py @@ -44,10 +44,15 @@ class Taskito: - ``TASKITO_DEFAULT_PRIORITY`` — Default priority (default: 0) - ``TASKITO_RESULT_TTL`` — Result TTL in seconds (default: None) - ``TASKITO_DRAIN_TIMEOUT`` — Drain timeout in seconds (default: 30) + + Args: + app: Optional Flask application instance. + cli_group: Name for the CLI command group (default ``"taskito"``). """ - def __init__(self, app: flask.Flask | None = None): + def __init__(self, app: flask.Flask | None = None, cli_group: str = "taskito"): self.queue: Any = None + self._cli_group = cli_group if app is not None: self.init_app(app) @@ -76,7 +81,7 @@ def _register_cli(self, app: flask.Flask) -> None: """Register Flask CLI commands.""" import click - @app.cli.group("taskito") + @app.cli.group(self._cli_group) def taskito_cli() -> None: """Taskito task queue commands.""" @@ -88,13 +93,25 @@ def worker_cmd(queues: str | None) -> None: self.queue.run_worker(queues=queue_list) @taskito_cli.command("info") - def info_cmd() -> None: + @click.option( + "--format", + "output_format", + type=click.Choice(["table", "json"]), + default="table", + help="Output format (default: table)", + ) + def info_cmd(output_format: str) -> None: """Show queue statistics.""" + import json + stats = self.queue.stats() - click.echo("taskito queue statistics") - click.echo("-" * 30) - for key in ("pending", "running", "completed", "failed", "dead", "cancelled"): - click.echo(f" {key:<12} {stats.get(key, 0)}") - total = sum(stats.values()) - click.echo("-" * 30) - click.echo(f" {'total':<12} {total}") + if output_format == "json": + click.echo(json.dumps(stats, indent=2)) + else: + 
click.echo("taskito queue statistics") + click.echo("-" * 30) + for key in ("pending", "running", "completed", "failed", "dead", "cancelled"): + click.echo(f" {key:<12} {stats.get(key, 0)}") + total = sum(stats.values()) + click.echo("-" * 30) + click.echo(f" {'total':<12} {total}") diff --git a/py_src/taskito/contrib/otel.py b/py_src/taskito/contrib/otel.py index 2472368..210cc21 100644 --- a/py_src/taskito/contrib/otel.py +++ b/py_src/taskito/contrib/otel.py @@ -13,7 +13,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from taskito.middleware import TaskMiddleware @@ -34,13 +34,32 @@ class OpenTelemetryMiddleware(TaskMiddleware): """Middleware that creates OpenTelemetry spans for task execution. Each task execution produces a span with: - - Span name: ``taskito.execute.`` + - Span name: ``taskito.execute.`` (customizable via ``span_name_fn``) - Attributes: ``taskito.job_id``, ``taskito.task_name``, - ``taskito.queue``, ``taskito.retry_count`` + ``taskito.queue``, ``taskito.retry_count`` (prefix customizable via + ``attribute_prefix``) - Status: OK on success, ERROR on failure with exception recorded + + Args: + tracer_name: OpenTelemetry tracer name. + span_name_fn: Custom span name builder. Receives a + :class:`~taskito.context.JobContext` and returns a string. + attribute_prefix: Prefix for span attribute keys (default ``"taskito"``). + extra_attributes_fn: Callable that returns extra attributes to add to + each span. Receives a :class:`~taskito.context.JobContext`. + task_filter: Predicate that receives a task name and returns ``True`` + to trace the task. ``None`` traces all tasks. 
""" - def __init__(self, tracer_name: str = _TRACER_NAME): + def __init__( + self, + tracer_name: str = _TRACER_NAME, + *, + span_name_fn: Callable[[JobContext], str] | None = None, + attribute_prefix: str = "taskito", + extra_attributes_fn: Callable[[JobContext], dict[str, Any]] | None = None, + task_filter: Callable[[str], bool] | None = None, + ): if trace is None: raise ImportError( "opentelemetry-api is required for OpenTelemetryMiddleware. " @@ -49,19 +68,36 @@ def __init__(self, tracer_name: str = _TRACER_NAME): import threading self._tracer = trace.get_tracer(tracer_name) + self._span_name_fn = span_name_fn + self._attr_prefix = attribute_prefix + self._extra_attributes_fn = extra_attributes_fn + self._task_filter = task_filter self._spans: dict[str, Any] = {} self._lock = threading.Lock() + def _should_trace(self, task_name: str) -> bool: + return self._task_filter is None or self._task_filter(task_name) + + def _span_name(self, ctx: JobContext) -> str: + if self._span_name_fn is not None: + return self._span_name_fn(ctx) + return f"{self._attr_prefix}.execute.{ctx.task_name}" + def before(self, ctx: JobContext) -> None: - span = self._tracer.start_span( - f"taskito.execute.{ctx.task_name}", - attributes={ - "taskito.job_id": ctx.id, - "taskito.task_name": ctx.task_name, - "taskito.queue": ctx.queue_name, - "taskito.retry_count": ctx.retry_count, - }, - ) + if not self._should_trace(ctx.task_name): + return + + prefix = self._attr_prefix + attributes: dict[str, Any] = { + f"{prefix}.job_id": ctx.id, + f"{prefix}.task_name": ctx.task_name, + f"{prefix}.queue": ctx.queue_name, + f"{prefix}.retry_count": ctx.retry_count, + } + if self._extra_attributes_fn is not None: + attributes.update(self._extra_attributes_fn(ctx)) + + span = self._tracer.start_span(self._span_name(ctx), attributes=attributes) with self._lock: self._spans[ctx.id] = span @@ -84,10 +120,11 @@ def on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: with self._lock: 
span = self._spans.get(ctx.id) if span is not None: + prefix = self._attr_prefix span.add_event( "retry", attributes={ - "taskito.retry_count": retry_count, - "taskito.error": str(error), + f"{prefix}.retry_count": retry_count, + f"{prefix}.error": str(error), }, ) diff --git a/py_src/taskito/contrib/prometheus.py b/py_src/taskito/contrib/prometheus.py index cfa7e52..870bc47 100644 --- a/py_src/taskito/contrib/prometheus.py +++ b/py_src/taskito/contrib/prometheus.py @@ -17,7 +17,7 @@ import logging import threading import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from taskito.middleware import TaskMiddleware @@ -35,199 +35,223 @@ Histogram = None start_http_server = None -# Module-level metric singletons (created once on first middleware init) -_jobs_total: Any = None -_job_duration: Any = None -_active_workers: Any = None -_retries_total: Any = None -_queue_depth: Any = None -_dlq_size: Any = None -_worker_utilization: Any = None -_resource_health: Any = None -_resource_recreations: Any = None -_resource_init_duration: Any = None -_proxy_reconstruct_duration: Any = None -_proxy_reconstruct_total: Any = None -_proxy_reconstruct_errors: Any = None -_intercept_duration: Any = None -_intercept_strategy_total: Any = None -_pool_size: Any = None -_pool_active: Any = None -_pool_idle: Any = None -_pool_timeouts: Any = None -_metrics_initialized = False - - -_init_lock = threading.Lock() - - -def _init_metrics() -> None: - global _jobs_total, _job_duration, _active_workers, _retries_total - global _queue_depth, _dlq_size, _worker_utilization, _metrics_initialized - global _resource_health, _resource_recreations, _resource_init_duration - global _proxy_reconstruct_duration, _proxy_reconstruct_total - global _proxy_reconstruct_errors, _intercept_duration - global _intercept_strategy_total - global _pool_size, _pool_active, _pool_idle, _pool_timeouts - - if _metrics_initialized: - return - - with _init_lock: - if 
_metrics_initialized: - return - - _jobs_total = Counter( - "taskito_jobs_total", - "Total number of jobs processed", - ["task", "status"], +# ── Metric categories ───────────────────────────────────── +# Used by `disabled_metrics` to skip groups of metrics. +_METRIC_GROUPS: dict[str, list[str]] = { + "jobs": ["jobs_total", "job_duration_seconds", "active_workers", "retries_total"], + "queue": ["queue_depth", "dlq_size", "worker_utilization"], + "resource": [ + "resource_health_status", + "resource_recreation_total", + "resource_init_duration_seconds", + "resource_pool_size", + "resource_pool_active", + "resource_pool_idle", + "resource_pool_timeout_total", + ], + "proxy": [ + "proxy_reconstruct_duration_seconds", + "proxy_reconstruct_total", + "proxy_reconstruct_errors_total", + ], + "intercept": ["intercept_duration_seconds", "intercept_strategy_total"], +} + + +# ── Per-namespace metric store ──────────────────────────── + +_store_lock = threading.Lock() +_metric_stores: dict[str, dict[str, Any]] = {} + + +def _get_or_create_metrics( + namespace: str, disabled_metrics: set[str] | None = None +) -> dict[str, Any]: + """Return a metric store for the given namespace, creating if needed.""" + if namespace in _metric_stores: + return _metric_stores[namespace] + + with _store_lock: + if namespace in _metric_stores: + return _metric_stores[namespace] + + disabled_names: set[str] = set() + if disabled_metrics: + for group in disabled_metrics: + disabled_names.update(_METRIC_GROUPS.get(group, [group])) + + store: dict[str, Any] = {} + + def _make(cls: Any, suffix: str, description: str, labels: list[str] | None = None) -> Any: + if suffix in disabled_names: + return None + name = f"{namespace}_{suffix}" + if labels: + return cls(name, description, labels) + return cls(name, description) + + store["jobs_total"] = _make( + Counter, "jobs_total", "Total number of jobs processed", ["task", "status"] ) - _job_duration = Histogram( - "taskito_job_duration_seconds", - "Job 
execution duration in seconds", - ["task"], + store["job_duration"] = _make( + Histogram, "job_duration_seconds", "Job execution duration in seconds", ["task"] ) - _active_workers = Gauge( - "taskito_active_workers", - "Number of currently active workers", + store["active_workers"] = _make( + Gauge, "active_workers", "Number of currently active workers" ) - _retries_total = Counter( - "taskito_retries_total", - "Total number of job retries", - ["task"], + store["retries_total"] = _make( + Counter, "retries_total", "Total number of job retries", ["task"] ) - _queue_depth = Gauge( - "taskito_queue_depth", - "Number of pending jobs per queue", - ["queue"], + store["queue_depth"] = _make( + Gauge, "queue_depth", "Number of pending jobs per queue", ["queue"] ) - _dlq_size = Gauge( - "taskito_dlq_size", - "Number of dead-letter jobs", + store["dlq_size"] = _make(Gauge, "dlq_size", "Number of dead-letter jobs") + store["worker_utilization"] = _make( + Gauge, "worker_utilization", "Worker utilization ratio (0.0-1.0)", ["queue"] ) - _worker_utilization = Gauge( - "taskito_worker_utilization", - "Worker utilization ratio (0.0-1.0)", - ["queue"], - ) - _resource_health = Gauge( - "taskito_resource_health_status", + store["resource_health"] = _make( + Gauge, + "resource_health_status", "Resource health (1=healthy, 0=unhealthy)", ["resource"], ) - _resource_recreations = Gauge( - "taskito_resource_recreation_total", - "Total recreations per resource", - ["resource"], + store["resource_recreations"] = _make( + Gauge, "resource_recreation_total", "Total recreations per resource", ["resource"] ) - _resource_init_duration = Gauge( - "taskito_resource_init_duration_seconds", + store["resource_init_duration"] = _make( + Gauge, + "resource_init_duration_seconds", "Time to initialize each resource", ["resource"], ) - _proxy_reconstruct_duration = Histogram( - "taskito_proxy_reconstruct_duration_seconds", + store["proxy_reconstruct_duration"] = _make( + Histogram, + 
"proxy_reconstruct_duration_seconds", "Proxy reconstruction duration", ["handler"], ) - _proxy_reconstruct_total = Counter( - "taskito_proxy_reconstruct_total", - "Total proxy reconstructions", - ["handler"], + store["proxy_reconstruct_total"] = _make( + Counter, "proxy_reconstruct_total", "Total proxy reconstructions", ["handler"] ) - _proxy_reconstruct_errors = Counter( - "taskito_proxy_reconstruct_errors_total", + store["proxy_reconstruct_errors"] = _make( + Counter, + "proxy_reconstruct_errors_total", "Total proxy reconstruction errors", ["handler"], ) - _intercept_duration = Histogram( - "taskito_intercept_duration_seconds", - "Argument interception duration", + store["intercept_duration"] = _make( + Histogram, "intercept_duration_seconds", "Argument interception duration" ) - _intercept_strategy_total = Counter( - "taskito_intercept_strategy_total", - "Interception strategy counts", - ["strategy"], + store["intercept_strategy_total"] = _make( + Counter, "intercept_strategy_total", "Interception strategy counts", ["strategy"] ) - _pool_size = Gauge( - "taskito_resource_pool_size", - "Resource pool max size", - ["resource"], + store["pool_size"] = _make( + Gauge, "resource_pool_size", "Resource pool max size", ["resource"] ) - _pool_active = Gauge( - "taskito_resource_pool_active", - "Active pool instances", - ["resource"], + store["pool_active"] = _make( + Gauge, "resource_pool_active", "Active pool instances", ["resource"] ) - _pool_idle = Gauge( - "taskito_resource_pool_idle", - "Idle pool instances", - ["resource"], + store["pool_idle"] = _make( + Gauge, "resource_pool_idle", "Idle pool instances", ["resource"] ) - _pool_timeouts = Counter( - "taskito_resource_pool_timeout_total", - "Pool acquisition timeouts", - ["resource"], + store["pool_timeouts"] = _make( + Counter, "resource_pool_timeout_total", "Pool acquisition timeouts", ["resource"] ) - _metrics_initialized = True + + _metric_stores[namespace] = store + return store class 
PrometheusMiddleware(TaskMiddleware): """Middleware that exports Prometheus metrics for task execution. - Tracks: - - ``taskito_jobs_total{task,status}`` — counter of completed/failed jobs - - ``taskito_job_duration_seconds{task}`` — histogram of execution times - - ``taskito_active_workers`` — gauge of currently executing workers - - ``taskito_retries_total{task}`` — counter of retry attempts + Args: + namespace: Prefix for all metric names (default ``"taskito"``). + extra_labels_fn: Callable that returns extra labels to add to job + metrics. Receives a :class:`~taskito.context.JobContext` and + returns a dict of label key-value pairs. + disabled_metrics: Metric groups or individual metric names to skip. + Groups: ``"jobs"``, ``"queue"``, ``"resource"``, ``"proxy"``, + ``"intercept"``. """ - def __init__(self) -> None: + def __init__( + self, + *, + namespace: str = "taskito", + extra_labels_fn: Callable[[JobContext], dict[str, str]] | None = None, + disabled_metrics: set[str] | None = None, + ) -> None: if Counter is None: raise ImportError( "prometheus-client is required for PrometheusMiddleware. 
" "Install it with: pip install taskito[prometheus]" ) - _init_metrics() + self._metrics = _get_or_create_metrics(namespace, disabled_metrics) + self._extra_labels_fn = extra_labels_fn self._start_times: dict[str, float] = {} self._lock = threading.Lock() def before(self, ctx: JobContext) -> None: with self._lock: self._start_times[ctx.id] = time.monotonic() - _active_workers.inc() + m = self._metrics["active_workers"] + if m is not None: + m.inc() def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: - _active_workers.dec() + m = self._metrics["active_workers"] + if m is not None: + m.dec() + status = "failed" if error is not None else "completed" - _jobs_total.labels(task=ctx.task_name, status=status).inc() + m = self._metrics["jobs_total"] + if m is not None: + m.labels(task=ctx.task_name, status=status).inc() with self._lock: start = self._start_times.pop(ctx.id, None) if start is not None: duration = time.monotonic() - start - _job_duration.labels(task=ctx.task_name).observe(duration) + m = self._metrics["job_duration"] + if m is not None: + m.labels(task=ctx.task_name).observe(duration) def on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: - _retries_total.labels(task=ctx.task_name).inc() + m = self._metrics["retries_total"] + if m is not None: + m.labels(task=ctx.task_name).inc() class PrometheusStatsCollector: """Daemon thread that polls queue stats and updates Prometheus gauges. + Args: + queue: The Queue instance to poll. + interval: Seconds between polls (default 10.0). + namespace: Prefix for metric names (default ``"taskito"``). + disabled_metrics: Metric groups or names to skip. 
+ Usage:: collector = PrometheusStatsCollector(queue, interval=10) collector.start() """ - def __init__(self, queue: Queue, interval: float = 10.0): + def __init__( + self, + queue: Queue, + interval: float = 10.0, + *, + namespace: str = "taskito", + disabled_metrics: set[str] | None = None, + ): if Counter is None: raise ImportError( "prometheus-client is required for PrometheusStatsCollector. " "Install it with: pip install taskito[prometheus]" ) - _init_metrics() + self._metrics = _get_or_create_metrics(namespace, disabled_metrics) self._queue = queue self._interval = interval self._thread: threading.Thread | None = None @@ -244,10 +268,12 @@ def stop(self) -> None: self._thread.join(timeout=5) def _poll(self) -> None: + m = self._metrics while not self._stop_event.is_set(): try: stats = self._queue.stats() - _dlq_size.set(stats.get("dead", 0)) + if m["dlq_size"] is not None: + m["dlq_size"].set(stats.get("dead", 0)) running = stats.get("running", 0) total_workers = self._queue._workers @@ -256,14 +282,20 @@ def _poll(self) -> None: try: all_q = self._queue.stats_all_queues() for q_name, q_stats in all_q.items(): - _queue_depth.labels(queue=q_name).set(q_stats.get("pending", 0)) - if total_workers > 0: + if m["queue_depth"] is not None: + m["queue_depth"].labels(queue=q_name).set(q_stats.get("pending", 0)) + if m["worker_utilization"] is not None and total_workers > 0: q_running = q_stats.get("running", 0) - _worker_utilization.labels(queue=q_name).set(q_running / total_workers) + m["worker_utilization"].labels(queue=q_name).set( + q_running / total_workers + ) except Exception: - _queue_depth.labels(queue="default").set(stats.get("pending", 0)) - if total_workers > 0: - _worker_utilization.labels(queue="default").set(running / total_workers) + if m["queue_depth"] is not None: + m["queue_depth"].labels(queue="default").set(stats.get("pending", 0)) + if m["worker_utilization"] is not None and total_workers > 0: + 
m["worker_utilization"].labels(queue="default").set( + running / total_workers + ) except Exception: logger.debug("Stats collection failed", exc_info=True) @@ -271,19 +303,26 @@ def _poll(self) -> None: try: for res in self._queue.resource_status(): name = res["name"] - _resource_health.labels(resource=name).set( - 1.0 if res["health"] == "healthy" else 0.0 - ) - _resource_recreations.labels(resource=name).set(res["recreations"]) - _resource_init_duration.labels(resource=name).set( - res["init_duration_ms"] / 1000.0 - ) + if m["resource_health"] is not None: + m["resource_health"].labels(resource=name).set( + 1.0 if res["health"] == "healthy" else 0.0 + ) + if m["resource_recreations"] is not None: + m["resource_recreations"].labels(resource=name).set(res["recreations"]) + if m["resource_init_duration"] is not None: + m["resource_init_duration"].labels(resource=name).set( + res["init_duration_ms"] / 1000.0 + ) pool = res.get("pool") if pool: - _pool_size.labels(resource=name).set(pool["size"]) - _pool_active.labels(resource=name).set(pool["active"]) - _pool_idle.labels(resource=name).set(pool["idle"]) - _pool_timeouts.labels(resource=name).set(pool["total_timeouts"]) + if m["pool_size"] is not None: + m["pool_size"].labels(resource=name).set(pool["size"]) + if m["pool_active"] is not None: + m["pool_active"].labels(resource=name).set(pool["active"]) + if m["pool_idle"] is not None: + m["pool_idle"].labels(resource=name).set(pool["idle"]) + if m["pool_timeouts"] is not None: + m["pool_timeouts"].labels(resource=name).set(pool["total_timeouts"]) except Exception: logger.debug("Resource metrics collection failed", exc_info=True) @@ -291,21 +330,23 @@ def _poll(self) -> None: try: for pstat in self._queue.proxy_stats(): handler = pstat["handler"] - _proxy_reconstruct_total.labels(handler=handler)._value.set( - pstat["total_reconstructions"] - ) - _proxy_reconstruct_errors.labels(handler=handler)._value.set( - pstat["total_errors"] - ) + if m["proxy_reconstruct_total"] 
is not None: + m["proxy_reconstruct_total"].labels(handler=handler)._value.set( + pstat["total_reconstructions"] + ) + if m["proxy_reconstruct_errors"] is not None: + m["proxy_reconstruct_errors"].labels(handler=handler)._value.set( + pstat["total_errors"] + ) except Exception: logger.debug("Proxy metrics collection failed", exc_info=True) # Interception metrics try: istats = self._queue.interception_stats() - if istats: + if istats and m["intercept_strategy_total"] is not None: for strategy, count in istats.get("strategy_counts", {}).items(): - _intercept_strategy_total.labels(strategy=strategy)._value.set(count) + m["intercept_strategy_total"].labels(strategy=strategy)._value.set(count) except Exception: logger.debug("Interception metrics collection failed", exc_info=True) diff --git a/py_src/taskito/contrib/sentry.py b/py_src/taskito/contrib/sentry.py index ce2363e..50468bc 100644 --- a/py_src/taskito/contrib/sentry.py +++ b/py_src/taskito/contrib/sentry.py @@ -13,7 +13,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from taskito.middleware import TaskMiddleware @@ -30,39 +30,74 @@ class SentryMiddleware(TaskMiddleware): """Middleware that reports task errors to Sentry and sets transaction context. Each task execution gets: - - A Sentry scope with tags: ``task_name``, ``job_id``, ``queue``, ``retry_count`` + - A Sentry scope with tags (prefix customizable via ``tag_prefix``) - Exceptions automatically captured via ``capture_exception`` - Retries recorded as breadcrumbs + + Args: + tag_prefix: Prefix for Sentry tag keys (default ``"taskito"``). + transaction_name_fn: Custom transaction name builder. Receives a + :class:`~taskito.context.JobContext` and returns a string. + task_filter: Predicate that receives a task name and returns ``True`` + to report the task. ``None`` reports all tasks. + extra_tags_fn: Callable that returns extra tags to set on the scope. 
+ Receives a :class:`~taskito.context.JobContext`. """ - def __init__(self) -> None: + def __init__( + self, + *, + tag_prefix: str = "taskito", + transaction_name_fn: Callable[[JobContext], str] | None = None, + task_filter: Callable[[str], bool] | None = None, + extra_tags_fn: Callable[[JobContext], dict[str, str]] | None = None, + ) -> None: if sentry_sdk is None: raise ImportError( "sentry-sdk is required for SentryMiddleware. " "Install it with: pip install taskito[sentry]" ) + self._tag_prefix = tag_prefix + self._transaction_name_fn = transaction_name_fn + self._task_filter = task_filter + self._extra_tags_fn = extra_tags_fn + + def _should_report(self, task_name: str) -> bool: + return self._task_filter is None or self._task_filter(task_name) def before(self, ctx: JobContext) -> None: + if not self._should_report(ctx.task_name): + return + sentry_sdk.push_scope() try: scope = sentry_sdk.get_current_scope() - scope.set_tag("taskito.task_name", ctx.task_name) - scope.set_tag("taskito.job_id", ctx.id) - scope.set_tag("taskito.queue", ctx.queue_name) - scope.set_tag("taskito.retry_count", str(ctx.retry_count)) - scope.set_transaction_name(f"taskito:{ctx.task_name}") + prefix = self._tag_prefix + scope.set_tag(f"{prefix}.task_name", ctx.task_name) + scope.set_tag(f"{prefix}.job_id", ctx.id) + scope.set_tag(f"{prefix}.queue", ctx.queue_name) + scope.set_tag(f"{prefix}.retry_count", str(ctx.retry_count)) + if self._extra_tags_fn is not None: + for key, value in self._extra_tags_fn(ctx).items(): + scope.set_tag(key, value) + if self._transaction_name_fn is not None: + scope.set_transaction_name(self._transaction_name_fn(ctx)) + else: + scope.set_transaction_name(f"{prefix}:{ctx.task_name}") except Exception: sentry_sdk.pop_scope_unsafe() raise def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: + if not self._should_report(ctx.task_name): + return if error is not None: sentry_sdk.capture_exception(error) sentry_sdk.pop_scope_unsafe() def 
on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: sentry_sdk.add_breadcrumb( - category="taskito", + category=self._tag_prefix, message=f"Retrying {ctx.task_name} (attempt {retry_count}): {error}", level="warning", ) diff --git a/py_src/taskito/events.py b/py_src/taskito/events.py index ecd22cb..a6cf4d8 100644 --- a/py_src/taskito/events.py +++ b/py_src/taskito/events.py @@ -12,7 +12,7 @@ class EventType(enum.Enum): - """Types of job lifecycle events.""" + """Types of job and worker lifecycle events.""" JOB_ENQUEUED = "job.enqueued" JOB_COMPLETED = "job.completed" @@ -20,6 +20,10 @@ class EventType(enum.Enum): JOB_RETRYING = "job.retrying" JOB_DEAD = "job.dead" JOB_CANCELLED = "job.cancelled" + WORKER_STARTED = "worker.started" + WORKER_STOPPED = "worker.stopped" + QUEUE_PAUSED = "queue.paused" + QUEUE_RESUMED = "queue.resumed" class EventBus: diff --git a/py_src/taskito/middleware.py b/py_src/taskito/middleware.py index f1111fe..1502fb4 100644 --- a/py_src/taskito/middleware.py +++ b/py_src/taskito/middleware.py @@ -33,3 +33,19 @@ def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: def on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: """Called when a task is about to be retried.""" + + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + """Called when a job is about to be enqueued. + + The ``options`` dict contains enqueue parameters (priority, delay, + queue, etc.) and may be mutated to modify the enqueue call. 
+ """ + + def on_dead_letter(self, ctx: JobContext, error: Exception) -> None: + """Called when a job exhausts retries and moves to the dead-letter queue.""" + + def on_timeout(self, ctx: JobContext) -> None: + """Called when a job hits its timeout limit.""" + + def on_cancel(self, ctx: JobContext) -> None: + """Called when a job is cancelled during execution.""" diff --git a/py_src/taskito/mixins.py b/py_src/taskito/mixins.py index 1c9ea9b..5456981 100644 --- a/py_src/taskito/mixins.py +++ b/py_src/taskito/mixins.py @@ -259,10 +259,18 @@ def workers(self) -> list[dict]: def pause(self, queue_name: str = "default") -> None: """Pause a queue so no new jobs are dispatched from it.""" self._inner.pause_queue(queue_name) + if hasattr(self, "_emit_event"): + from taskito.events import EventType + + self._emit_event(EventType.QUEUE_PAUSED, {"queue": queue_name}) def resume(self, queue_name: str = "default") -> None: """Resume a paused queue.""" self._inner.resume_queue(queue_name) + if hasattr(self, "_emit_event"): + from taskito.events import EventType + + self._emit_event(EventType.QUEUE_RESUMED, {"queue": queue_name}) def paused_queues(self) -> list[str]: """List currently paused queues.""" diff --git a/py_src/taskito/webhooks.py b/py_src/taskito/webhooks.py index ef7c1c6..9eb9109 100644 --- a/py_src/taskito/webhooks.py +++ b/py_src/taskito/webhooks.py @@ -37,6 +37,9 @@ def add_webhook( events: list[EventType] | None = None, headers: dict[str, str] | None = None, secret: str | None = None, + max_retries: int = 3, + timeout: float = 10.0, + retry_backoff: float = 2.0, ) -> None: """Register a webhook endpoint. @@ -45,6 +48,9 @@ def add_webhook( events: List of event types to subscribe to. None means all events. headers: Extra HTTP headers to include. secret: HMAC-SHA256 signing secret for the ``X-Taskito-Signature`` header. + max_retries: Maximum delivery attempts (default 3). + timeout: HTTP request timeout in seconds (default 10.0). 
+ retry_backoff: Base for exponential backoff between retries (default 2.0). """ parsed = urllib.parse.urlparse(url) if parsed.scheme not in ("http", "https"): @@ -56,6 +62,9 @@ def add_webhook( "events": {e.value for e in events} if events else None, "headers": headers or {}, "secret": secret.encode() if secret else None, + "max_retries": max_retries, + "timeout": timeout, + "retry_backoff": retry_backoff, } ) self._ensure_thread() @@ -98,10 +107,14 @@ def _send(self, wh: dict[str, Any], payload: dict[str, Any]) -> None: sig = hmac.new(wh["secret"], body, hashlib.sha256).hexdigest() headers["X-Taskito-Signature"] = f"sha256={sig}" - for attempt in range(3): + max_retries: int = wh.get("max_retries", 3) + timeout: float = wh.get("timeout", 10.0) + retry_backoff: float = wh.get("retry_backoff", 2.0) + + for attempt in range(max_retries): try: req = urllib.request.Request(wh["url"], data=body, headers=headers, method="POST") - with urllib.request.urlopen(req, timeout=10) as resp: + with urllib.request.urlopen(req, timeout=timeout) as resp: if resp.status < 400: return if resp.status < 500: @@ -114,7 +127,9 @@ def _send(self, wh: dict[str, Any], payload: dict[str, Any]) -> None: logger.warning("Webhook %s returned server error %d", wh["url"], resp.status) except Exception: logger.debug("Webhook %s attempt %d failed", wh["url"], attempt + 1, exc_info=True) - if attempt == 2: - logger.warning("Webhook delivery failed after 3 attempts: %s", wh["url"]) + if attempt == max_retries - 1: + logger.warning( + "Webhook delivery failed after %d attempts: %s", max_retries, wh["url"] + ) else: - time.sleep(2**attempt) + time.sleep(retry_backoff**attempt) diff --git a/pyproject.toml b/pyproject.toml index 2b88280..e15dcff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "taskito" -version = "0.5.0" +version = "0.6.0" description = "Rust-powered task queue for Python. No broker required." 
requires-python = ">=3.10" license = { file = "LICENSE" } diff --git a/tests/python/test_contrib.py b/tests/python/test_contrib.py index 8da5c76..bbc607f 100644 --- a/tests/python/test_contrib.py +++ b/tests/python/test_contrib.py @@ -38,6 +38,10 @@ def test_before_starts_span(self) -> None: import threading mw._tracer = mock_tracer + mw._span_name_fn = None + mw._attr_prefix = "taskito" + mw._extra_attributes_fn = None + mw._task_filter = None mw._spans = {} mw._lock = threading.Lock() @@ -57,6 +61,10 @@ def test_after_ends_span_success(self) -> None: import threading mw._tracer = MagicMock() + mw._span_name_fn = None + mw._attr_prefix = "taskito" + mw._extra_attributes_fn = None + mw._task_filter = None mw._spans = {"job-1": mock_span} mw._lock = threading.Lock() @@ -77,6 +85,10 @@ def test_after_records_exception_on_error(self) -> None: import threading mw._tracer = MagicMock() + mw._span_name_fn = None + mw._attr_prefix = "taskito" + mw._extra_attributes_fn = None + mw._task_filter = None mw._spans = {"job-1": mock_span} mw._lock = threading.Lock() @@ -130,6 +142,10 @@ def test_before_pushes_scope(self) -> None: ctx = _make_ctx() mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) + mw._tag_prefix = "taskito" + mw._transaction_name_fn = None + mw._task_filter = None + mw._extra_tags_fn = None mw.before(ctx) mock_sdk.push_scope.assert_called_once() @@ -143,6 +159,10 @@ def test_after_pops_scope(self) -> None: ctx = _make_ctx() mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) + mw._tag_prefix = "taskito" + mw._transaction_name_fn = None + mw._task_filter = None + mw._extra_tags_fn = None mw.after(ctx, result="ok", error=None) mock_sdk.pop_scope_unsafe.assert_called_once() @@ -157,6 +177,10 @@ def test_after_captures_exception_on_error(self) -> None: exc = RuntimeError("oops") mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) + mw._tag_prefix = "taskito" + mw._transaction_name_fn = None + mw._task_filter = 
None + mw._extra_tags_fn = None mw.after(ctx, result=None, error=exc) mock_sdk.capture_exception.assert_called_once_with(exc) @@ -181,22 +205,50 @@ def _try_import_sentry(): # type: ignore[no-untyped-def] # ── Prometheus ─────────────────────────────────────────────────────── +def _make_mock_metrics() -> dict: + """Create a mock metrics dict matching the instance-based store format.""" + return { + "jobs_total": MagicMock(), + "job_duration": MagicMock(), + "active_workers": MagicMock(), + "retries_total": MagicMock(), + "queue_depth": MagicMock(), + "dlq_size": MagicMock(), + "worker_utilization": MagicMock(), + "resource_health": MagicMock(), + "resource_recreations": MagicMock(), + "resource_init_duration": MagicMock(), + "proxy_reconstruct_duration": MagicMock(), + "proxy_reconstruct_total": MagicMock(), + "proxy_reconstruct_errors": MagicMock(), + "intercept_duration": MagicMock(), + "intercept_strategy_total": MagicMock(), + "pool_size": MagicMock(), + "pool_active": MagicMock(), + "pool_idle": MagicMock(), + "pool_timeouts": MagicMock(), + } + + class TestPrometheusMiddleware: def test_before_increments_active_workers(self) -> None: prom = _try_import_prometheus() if prom is None: return - mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) import threading + metrics = _make_mock_metrics() + mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) + mw._metrics = metrics + mw._extra_labels_fn = None mw._start_times = {} mw._lock = threading.Lock() ctx = _make_ctx() mw.before(ctx) - prom._active_workers.inc.assert_called() + metrics["active_workers"].inc.assert_called() def test_after_tracks_counter_and_histogram(self) -> None: prom = _try_import_prometheus() @@ -205,16 +257,19 @@ def test_after_tracks_counter_and_histogram(self) -> None: import threading + metrics = _make_mock_metrics() mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) + mw._metrics = metrics + mw._extra_labels_fn = None mw._start_times = {"job-1": 0.0} 
mw._lock = threading.Lock() ctx = _make_ctx() mw.after(ctx, result="ok", error=None) - prom._active_workers.dec.assert_called() - prom._jobs_total.labels.assert_called_with(task="my_task", status="completed") - prom._jobs_total.labels().inc.assert_called() + metrics["active_workers"].dec.assert_called() + metrics["jobs_total"].labels.assert_called_with(task="my_task", status="completed") + metrics["jobs_total"].labels().inc.assert_called() def test_after_tracks_failure(self) -> None: prom = _try_import_prometheus() @@ -223,7 +278,10 @@ def test_after_tracks_failure(self) -> None: import threading + metrics = _make_mock_metrics() mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) + mw._metrics = metrics + mw._extra_labels_fn = None mw._start_times = {"job-1": 0.0} mw._lock = threading.Lock() @@ -231,7 +289,7 @@ def test_after_tracks_failure(self) -> None: exc = ValueError("fail") mw.after(ctx, result=None, error=exc) - prom._jobs_total.labels.assert_called_with(task="my_task", status="failed") + metrics["jobs_total"].labels.assert_called_with(task="my_task", status="failed") def _try_import_prometheus(): # type: ignore[no-untyped-def] @@ -257,14 +315,6 @@ def _try_import_prometheus(): # type: ignore[no-untyped-def] del sys.modules["taskito.contrib.prometheus"] from taskito.contrib import prometheus - # Replace module-level metric singletons with mocks - prometheus._jobs_total = MagicMock() - prometheus._job_duration = MagicMock() - prometheus._active_workers = MagicMock() - prometheus._retries_total = MagicMock() - prometheus._queue_depth = MagicMock() - prometheus._dlq_size = MagicMock() - prometheus._worker_utilization = MagicMock() return prometheus except Exception: return None diff --git a/tests/python/test_customizability.py b/tests/python/test_customizability.py new file mode 100644 index 0000000..057e919 --- /dev/null +++ b/tests/python/test_customizability.py @@ -0,0 +1,298 @@ +"""Tests for customizability configuration options.""" + +from 
__future__ import annotations + +import time +from typing import Any +from unittest.mock import MagicMock + +from taskito.app import Queue +from taskito.events import EventType +from taskito.middleware import TaskMiddleware +from taskito.webhooks import WebhookManager + +# ── Middleware Hooks ────────────────────────────────────────────────── + + +class RecordingMiddleware(TaskMiddleware): + """Middleware that records all hook calls for testing.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, Any]] = [] + + def before(self, ctx: Any) -> None: + self.calls.append(("before", ctx.task_name)) + + def after(self, ctx: Any, result: Any, error: Any) -> None: + self.calls.append(("after", ctx.task_name)) + + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + self.calls.append(("on_enqueue", task_name)) + + def on_dead_letter(self, ctx: Any, error: Exception) -> None: + self.calls.append(("on_dead_letter", ctx.task_name)) + + def on_timeout(self, ctx: Any) -> None: + self.calls.append(("on_timeout", ctx.task_name)) + + def on_cancel(self, ctx: Any) -> None: + self.calls.append(("on_cancel", ctx.task_name)) + + +class TestMiddlewareHooks: + def test_on_enqueue_called(self, tmp_path: Any) -> None: + mw = RecordingMiddleware() + q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) + + @q.task() + def my_task() -> None: + pass + + my_task.delay() + assert ("on_enqueue", my_task.name) in mw.calls + + def test_on_enqueue_can_mutate_options(self, tmp_path: Any) -> None: + """on_enqueue can modify the options dict to change enqueue params.""" + + class PriorityBoostMiddleware(TaskMiddleware): + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + options["priority"] = 99 + + mw = PriorityBoostMiddleware() + q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) + + @q.task() + def my_task() -> None: + pass + + result = my_task.delay() + job = q.get_job(result.id) + assert 
job is not None + assert job.to_dict()["priority"] == 99 + + def test_default_hooks_are_noop(self) -> None: + """Base TaskMiddleware hooks should not raise.""" + mw = TaskMiddleware() + mw.on_enqueue("test", (), {}, {}) + mw.on_dead_letter(MagicMock(), Exception("test")) + mw.on_timeout(MagicMock()) + mw.on_cancel(MagicMock()) + + +# ── Event System ───────────────────────────────────────────────────── + + +class TestEventSystem: + def test_new_event_types_exist(self) -> None: + assert EventType.WORKER_STARTED.value == "worker.started" + assert EventType.WORKER_STOPPED.value == "worker.stopped" + assert EventType.QUEUE_PAUSED.value == "queue.paused" + assert EventType.QUEUE_RESUMED.value == "queue.resumed" + + def test_event_workers_param(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db"), event_workers=2) + assert q._event_bus._executor._max_workers == 2 + + def test_on_event_public_api(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + received: list[Any] = [] + + def callback(event_type: EventType, payload: dict) -> None: + received.append((event_type, payload)) + + q.on_event(EventType.JOB_ENQUEUED, callback) + + @q.task() + def my_task() -> None: + pass + + my_task.delay() + time.sleep(0.2) + assert len(received) == 1 + assert received[0][0] == EventType.JOB_ENQUEUED + + +# ── Webhook Configuration ──────────────────────────────────────────── + + +class TestWebhookConfig: + def test_add_webhook_with_retry_params(self) -> None: + mgr = WebhookManager() + mgr.add_webhook( + "https://example.com/hook", + max_retries=5, + timeout=30.0, + retry_backoff=3.0, + ) + wh = mgr._webhooks[0] + assert wh["max_retries"] == 5 + assert wh["timeout"] == 30.0 + assert wh["retry_backoff"] == 3.0 + + def test_add_webhook_defaults(self) -> None: + mgr = WebhookManager() + mgr.add_webhook("https://example.com/hook") + wh = mgr._webhooks[0] + assert wh["max_retries"] == 3 + assert wh["timeout"] == 10.0 + assert wh["retry_backoff"] 
== 2.0 + + def test_queue_add_webhook_passes_params(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + q.add_webhook( + "https://example.com/hook", + max_retries=1, + timeout=5.0, + retry_backoff=1.5, + ) + wh = q._webhook_manager._webhooks[0] + assert wh["max_retries"] == 1 + assert wh["timeout"] == 5.0 + + +# ── Queue Configuration ────────────────────────────────────────────── + + +class TestQueueConfig: + def test_scheduler_timing_params(self, tmp_path: Any) -> None: + q = Queue( + db_path=str(tmp_path / "test.db"), + scheduler_poll_interval_ms=100, + scheduler_reap_interval=50, + scheduler_cleanup_interval=600, + ) + # These are passed to the Rust side — just verify they don't error + assert q._inner is not None + + def test_scheduler_timing_defaults(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + # Defaults should work fine + assert q._inner is not None + + +# ── Per-Task Configuration ─────────────────────────────────────────── + + +class TestPerTaskConfig: + def test_max_retry_delay_param(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task(max_retry_delay=60) + def my_task() -> None: + pass + + config = q._task_configs[-1] + assert config.max_retry_delay == 60 + + def test_max_retry_delay_default(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task() + def my_task() -> None: + pass + + config = q._task_configs[-1] + assert config.max_retry_delay is None + + def test_max_concurrent_param(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task(max_concurrent=5) + def my_task() -> None: + pass + + config = q._task_configs[-1] + assert config.max_concurrent == 5 + + def test_max_concurrent_default(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task() + def my_task() -> None: + pass + + config = q._task_configs[-1] + assert config.max_concurrent is None + + +# ── 
Per-Task Serializer ────────────────────────────────────────────── + + +class TestPerTaskSerializer: + def test_task_level_serializer_used_for_enqueue(self, tmp_path: Any) -> None: + """Per-task serializer is used instead of queue-level serializer.""" + mock_serializer = MagicMock() + mock_serializer.dumps.return_value = b"\x80\x04\x95" + + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task(serializer=mock_serializer) + def my_task(x: int) -> None: + pass + + my_task.delay(42) + mock_serializer.dumps.assert_called_once() + + def test_queue_serializer_used_when_no_task_serializer(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + + @q.task() + def my_task() -> None: + pass + + # Should use the default CloudpickleSerializer without error + my_task.delay() + assert my_task.name not in q._task_serializers + + +# ── Flask CLI ───────────────────────────────────────────────────────── + + +# ── Queue-Level Limits ──────────────────────────────────────────────── + + +class TestQueueLevelLimits: + def test_set_queue_rate_limit(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + q.set_queue_rate_limit("default", "100/m") + assert q._queue_configs["default"]["rate_limit"] == "100/m" + + def test_set_queue_concurrency(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + q.set_queue_concurrency("default", 10) + assert q._queue_configs["default"]["max_concurrent"] == 10 + + def test_set_both_on_same_queue(self, tmp_path: Any) -> None: + q = Queue(db_path=str(tmp_path / "test.db")) + q.set_queue_rate_limit("emails", "50/m") + q.set_queue_concurrency("emails", 5) + assert q._queue_configs["emails"]["rate_limit"] == "50/m" + assert q._queue_configs["emails"]["max_concurrent"] == 5 + + def test_queue_configs_serialized_to_json(self, tmp_path: Any) -> None: + import json + + q = Queue(db_path=str(tmp_path / "test.db")) + q.set_queue_rate_limit("default", "10/s") + q.set_queue_concurrency("default", 
3) + serialized = json.dumps(q._queue_configs) + parsed = json.loads(serialized) + assert parsed["default"]["rate_limit"] == "10/s" + assert parsed["default"]["max_concurrent"] == 3 + + +# ── Flask CLI ───────────────────────────────────────────────────────── + + +class TestFlaskConfig: + def test_cli_group_param(self) -> None: + from taskito.contrib.flask import Taskito + + ext = Taskito(cli_group="jobs") + assert ext._cli_group == "jobs" + + def test_cli_group_default(self) -> None: + from taskito.contrib.flask import Taskito + + ext = Taskito() + assert ext._cli_group == "taskito" diff --git a/tests/python/test_events.py b/tests/python/test_events.py index ec4f264..f1ab61e 100644 --- a/tests/python/test_events.py +++ b/tests/python/test_events.py @@ -80,5 +80,9 @@ def test_all_event_types_exist(): "job.retrying", "job.dead", "job.cancelled", + "worker.started", + "worker.stopped", + "queue.paused", + "queue.resumed", } assert {e.value for e in EventType} == expected