Skip to content

Commit d92444c

Browse files
committed
feat: add POST /submit_tasks endpoint + fix HuggingFace dataset compat
- Add POST /submit_tasks: validators send a task_ids array; the executor fetches the tasks from HF and runs them
- Fix FAIL_TO_PASS/PASS_TO_PASS serde aliases (HF uses uppercase keys)
- Add language, difficulty, difficulty_score, and quality_score fields to DatasetEntry
- Fix the /dataset default split to 'train' (the HF dataset only has a train split)
- Fix the difficulty filter to actually filter on the difficulty field
- Add a Clone derive on SweForgeTask for task-registry operations
1 parent b8848c4 commit d92444c

File tree

4 files changed

+243
-10
lines changed

4 files changed

+243
-10
lines changed

src/handlers.rs

Lines changed: 202 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub fn router(state: Arc<AppState>) -> Router {
4646
.route("/verify/{batch_id}", get(verify_batch))
4747
.route("/instance", get(instance_info))
4848
.route("/dataset", get(fetch_dataset))
49+
.route("/submit_tasks", post(submit_tasks))
4950
.route("/ws", get(ws::ws_handler))
5051
.with_state(state)
5152
}
@@ -536,7 +537,7 @@ async fn fetch_dataset(
536537
)
537538
})?;
538539

539-
let split = query.split.unwrap_or_else(|| "test".to_string());
540+
let split = query.split.unwrap_or_else(|| "train".to_string());
540541
let limit = query.limit.unwrap_or(10).min(100);
541542
let offset = query.offset.unwrap_or(0);
542543

@@ -559,11 +560,8 @@ async fn fetch_dataset(
559560
dataset
560561
.entries
561562
.iter()
562-
.filter(|_e| {
563-
// swe-forge puts difficulty in separate splits (easy, medium, hard)
564-
// so typically the split itself is the filter
565-
let _ = diff;
566-
true
563+
.filter(|e| {
564+
e.difficulty.as_deref().map(|d| d.eq_ignore_ascii_case(diff)).unwrap_or(false)
567565
})
568566
.collect()
569567
} else {
@@ -583,7 +581,10 @@ async fn fetch_dataset(
583581
"fail_to_pass": e.fail_to_pass,
584582
"pass_to_pass": e.pass_to_pass,
585583
"version": e.version,
586-
"language": e.hints_text.as_deref().unwrap_or("python"),
584+
"language": e.language,
585+
"difficulty": e.difficulty,
586+
"difficulty_score": e.difficulty_score,
587+
"quality_score": e.quality_score,
587588
})).collect::<Vec<_>>(),
588589
})))
589590
}
@@ -595,3 +596,197 @@ struct DatasetQuery {
595596
offset: Option<usize>,
596597
difficulty: Option<String>,
597598
}
599+
600+
/// Request body for /submit_tasks: validators provide task IDs to execute.
/// The executor fetches matching tasks from HuggingFace CortexLM/swe-forge,
/// pairs them with the uploaded agent archive, and runs them.
///
/// NOTE(review): the `submit_tasks` handler added in this change reads its
/// input from multipart form fields rather than deserializing this struct —
/// confirm this type is actually referenced somewhere, or it is dead code.
#[derive(serde::Deserialize)]
struct SubmitTasksRequest {
    /// List of instance_ids from the swe-forge dataset to execute
    task_ids: Vec<String>,
    /// HuggingFace dataset split (default: "train")
    #[serde(default = "default_train_split")]
    split: String,
}

/// Serde default for `SubmitTasksRequest::split` — the swe-forge dataset
/// currently only publishes a "train" split.
fn default_train_split() -> String {
    "train".to_string()
}
615+
616+
/// Accept a list of task_ids from validators, fetch them from HuggingFace,
617+
/// and execute them with the agent code from the uploaded archive.
618+
async fn submit_tasks(
619+
State(state): State<Arc<AppState>>,
620+
headers: axum::http::HeaderMap,
621+
mut multipart: Multipart,
622+
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
623+
// Auth check
624+
let auth_headers = auth::extract_auth_headers(&headers).ok_or_else(|| {
625+
(
626+
StatusCode::UNAUTHORIZED,
627+
Json(serde_json::json!({
628+
"error": "missing_auth",
629+
"message": "Missing required headers: X-Hotkey, X-Nonce, X-Signature"
630+
})),
631+
)
632+
})?;
633+
634+
if state.validator_whitelist.validator_count() > 0 {
635+
if let Err(e) = auth::verify_request(
636+
&auth_headers,
637+
&state.nonce_store,
638+
&state.validator_whitelist,
639+
) {
640+
return Err((
641+
StatusCode::UNAUTHORIZED,
642+
Json(serde_json::json!({
643+
"error": e.code(),
644+
"message": e.message(),
645+
})),
646+
));
647+
}
648+
}
649+
650+
// Parse multipart: expect "task_ids" (JSON) and "archive" (file)
651+
let mut task_ids: Option<Vec<String>> = None;
652+
let mut split = "train".to_string();
653+
let mut archive_data: Option<Vec<u8>> = None;
654+
655+
while let Ok(Some(field)) = multipart.next_field().await {
656+
let name = field.name().unwrap_or("").to_string();
657+
match name.as_str() {
658+
"task_ids" => {
659+
let text = field.text().await.map_err(|e| {
660+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("Failed to read task_ids: {}", e)})))
661+
})?;
662+
task_ids = Some(serde_json::from_str::<Vec<String>>(&text).map_err(|e| {
663+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("Invalid task_ids JSON: {}", e)})))
664+
})?);
665+
}
666+
"split" => {
667+
split = field.text().await.unwrap_or_else(|_| "train".to_string());
668+
}
669+
"archive" | "file" => {
670+
let mut buf = Vec::new();
671+
use futures::TryStreamExt;
672+
let mut stream = field;
673+
while let Some(chunk) = stream.try_next().await.map_err(|e| {
674+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("Upload failed: {}", e)})))
675+
})? {
676+
buf.extend_from_slice(&chunk);
677+
}
678+
archive_data = Some(buf);
679+
}
680+
_ => {}
681+
}
682+
}
683+
684+
let task_ids = task_ids.ok_or_else(|| {
685+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "Missing task_ids field"})))
686+
})?;
687+
688+
let archive_bytes = archive_data.ok_or_else(|| {
689+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "Missing archive file with agent code"})))
690+
})?;
691+
692+
if task_ids.is_empty() || task_ids.len() > 50 {
693+
return Err((
694+
StatusCode::BAD_REQUEST,
695+
Json(serde_json::json!({"error": "task_ids must have 1-50 entries"})),
696+
));
697+
}
698+
699+
// Fetch full dataset from HuggingFace to find matching tasks
700+
let hf_client = crate::swe_forge::client::HuggingFaceClient::new().map_err(|e| {
701+
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": format!("HF client error: {}", e)})))
702+
})?;
703+
704+
let config = crate::swe_forge::types::DatasetConfig {
705+
dataset_id: "CortexLM/swe-forge".to_string(),
706+
split,
707+
limit: 100, // fetch all (dataset has 66 rows currently)
708+
offset: 0,
709+
};
710+
711+
let dataset = hf_client.fetch_dataset(&config).await.map_err(|e| {
712+
(StatusCode::BAD_GATEWAY, Json(serde_json::json!({"error": format!("Failed to fetch HF dataset: {}", e)})))
713+
})?;
714+
715+
// Match requested task_ids
716+
let matched: Vec<&crate::swe_forge::types::DatasetEntry> = dataset
717+
.entries
718+
.iter()
719+
.filter(|e| task_ids.contains(&e.instance_id))
720+
.collect();
721+
722+
let not_found: Vec<&String> = task_ids
723+
.iter()
724+
.filter(|id| !matched.iter().any(|e| &e.instance_id == *id))
725+
.collect();
726+
727+
if matched.is_empty() {
728+
return Err((
729+
StatusCode::NOT_FOUND,
730+
Json(serde_json::json!({
731+
"error": "No matching tasks found in dataset",
732+
"requested": task_ids,
733+
"available_count": dataset.entries.len(),
734+
})),
735+
));
736+
}
737+
738+
// Convert HF entries to SweForgeTask + build archive with tasks/ dirs
739+
let mut registry = crate::task::registry::TaskRegistry::new();
740+
let hf_dataset = crate::swe_forge::types::HuggingFaceDataset {
741+
dataset_id: dataset.dataset_id.clone(),
742+
split: dataset.split.clone(),
743+
entries: matched.into_iter().cloned().collect(),
744+
total_count: dataset.total_count,
745+
};
746+
registry.load_from_huggingface(&hf_dataset).map_err(|e| {
747+
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": format!("Failed to load tasks: {}", e)})))
748+
})?;
749+
750+
// Extract agent code from uploaded archive
751+
let extract_dir = state.config.workspace_base.join("_extract_submit_tasks");
752+
let _ = tokio::fs::remove_dir_all(&extract_dir).await;
753+
let extracted = crate::task::extract_uploaded_archive(&archive_bytes, &extract_dir)
754+
.await
755+
.map_err(|e| {
756+
(StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": format!("Failed to extract agent archive: {}", e)})))
757+
})?;
758+
let _ = tokio::fs::remove_dir_all(&extract_dir).await;
759+
760+
// Replace the tasks from archive with the HF tasks, but keep the agent code
761+
let hf_tasks: Vec<crate::task::SweForgeTask> = registry.get_tasks().to_vec();
762+
let final_archive = crate::task::ExtractedArchive {
763+
tasks: hf_tasks,
764+
agent_code: extracted.agent_code,
765+
agent_language: extracted.agent_language,
766+
};
767+
768+
if state.sessions.has_active_batch() {
769+
return Err((
770+
StatusCode::SERVICE_UNAVAILABLE,
771+
Json(serde_json::json!({"error": "A batch is already running"})),
772+
));
773+
}
774+
775+
let total_tasks = final_archive.tasks.len();
776+
let batch = state.sessions.create_batch(total_tasks);
777+
let batch_id = batch.id.clone();
778+
let concurrent = state.config.max_concurrent_tasks;
779+
780+
state.executor.spawn_batch(batch, final_archive, concurrent);
781+
782+
Ok((
783+
StatusCode::ACCEPTED,
784+
Json(serde_json::json!({
785+
"batch_id": batch_id,
786+
"total_tasks": total_tasks,
787+
"matched_task_ids": task_ids.iter().filter(|id| !not_found.contains(id)).collect::<Vec<_>>(),
788+
"not_found": not_found,
789+
"ws_url": format!("/ws?batch_id={}", batch_id),
790+
})),
791+
))
792+
}

src/swe_forge/types.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ pub struct DatasetEntry {
1414
pub created_at: Option<String>,
1515
#[serde(default)]
1616
pub version: Option<String>,
17-
#[serde(default)]
17+
#[serde(default, alias = "FAIL_TO_PASS")]
1818
pub fail_to_pass: Option<String>,
19-
#[serde(default)]
19+
#[serde(default, alias = "PASS_TO_PASS")]
2020
pub pass_to_pass: Option<String>,
2121
#[serde(default)]
2222
pub environment_setup_commit: Option<String>,
23+
#[serde(default)]
24+
pub language: Option<String>,
25+
#[serde(default)]
26+
pub difficulty: Option<String>,
27+
#[serde(default)]
28+
pub difficulty_score: Option<u8>,
29+
#[serde(default)]
30+
pub quality_score: Option<f64>,
2331
}
2432

2533
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -135,4 +143,30 @@ mod tests {
135143
assert_eq!(back.dataset_id, "CortexLM/swe-forge");
136144
assert_eq!(back.total_count, 0);
137145
}
146+
147+
    #[test]
    fn test_dataset_entry_huggingface_uppercase_keys() {
        // HuggingFace API returns FAIL_TO_PASS and PASS_TO_PASS in uppercase;
        // the `#[serde(alias = "...")]` attributes on DatasetEntry must accept
        // them while still deserializing the lowercase field names.
        // The fixture also exercises the newly added optional fields
        // (language, difficulty, difficulty_score, quality_score).
        let json = r#"{
            "repo": "django/django",
            "instance_id": "django__django-12345",
            "base_commit": "abc123",
            "patch": "diff",
            "test_patch": "",
            "problem_statement": "Fix bug",
            "FAIL_TO_PASS": "[\"pytest tests/test_fix.py\"]",
            "PASS_TO_PASS": "[\"pytest tests/test_basic.py\"]",
            "language": "python",
            "difficulty": "hard",
            "difficulty_score": 3,
            "quality_score": 0.82
        }"#;
        let entry: DatasetEntry = serde_json::from_str(json).expect("should deserialize HF format");
        // FAIL_TO_PASS / PASS_TO_PASS stay raw JSON-array strings here; they
        // are parsed further downstream.
        assert_eq!(entry.fail_to_pass.as_deref(), Some("[\"pytest tests/test_fix.py\"]"));
        assert_eq!(entry.pass_to_pass.as_deref(), Some("[\"pytest tests/test_basic.py\"]"));
        assert_eq!(entry.language.as_deref(), Some("python"));
        assert_eq!(entry.difficulty.as_deref(), Some("hard"));
        assert_eq!(entry.difficulty_score, Some(3));
        // Tolerance comparison: quality_score is an f64.
        assert!((entry.quality_score.unwrap() - 0.82).abs() < 0.001);
    }
138172
}

src/task/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub struct WorkspaceConfig {
4242
pub prompt: Option<String>,
4343
}
4444

45-
#[derive(Debug)]
45+
#[derive(Debug, Clone)]
4646
pub struct SweForgeTask {
4747
pub id: String,
4848
pub workspace: WorkspaceConfig,

src/task/registry.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ mod tests {
353353
fail_to_pass: Some(r#"["tests/test_orm.py::test_query"]"#.to_string()),
354354
pass_to_pass: None,
355355
environment_setup_commit: None,
356+
language: Some("python".to_string()),
357+
difficulty: Some("medium".to_string()),
358+
difficulty_score: Some(2),
359+
quality_score: Some(0.75),
356360
}
357361
}
358362
}

0 commit comments

Comments
 (0)