Skip to content

Commit db3ba95

Browse files
committed
feat: integrate HuggingFace dataset handler with task/evaluation system
- Add src/swe_forge/ module with HuggingFace dataset client and types
  - types.rs: DatasetEntry, HuggingFaceDataset, DatasetConfig for CortexLM/swe-forge dataset schema (repo, instance_id, base_commit, patch, test_patch, problem_statement, hints_text, version, etc.)
  - client.rs: HuggingFaceClient with async fetch_dataset() and fetch_entry() using HuggingFace Dataset Viewer REST API
- Restructure src/task.rs into src/task/ module directory
  - types.rs: SweForgeTaskFields struct for SWE-forge-specific fields
  - config.rs: TaskSource enum (Local/HuggingFace) and TaskConfig
  - registry.rs: TaskRegistry with load_from_huggingface() that converts DatasetEntry items to SweForgeTask format, including test script generation from test_patch and fail_to_pass fields
- Add swe_forge_fields: Option<SweForgeTaskFields> to SweForgeTask
- Add swe_forge module declaration to main.rs
- All existing tests pass, 27 new tests added
1 parent f6d165c commit db3ba95

File tree

9 files changed

+791
-1
lines changed

9 files changed

+791
-1
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod executor;
66
mod handlers;
77
mod metrics;
88
mod session;
9+
mod swe_forge;
910
mod task;
1011
mod validator_whitelist;
1112
mod ws;

src/swe_forge/client.rs

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
use anyhow::{Context, Result};
2+
use tracing::{debug, info};
3+
4+
use super::types::{DatasetConfig, DatasetEntry, HfRowsResponse, HuggingFaceDataset};
5+
6+
/// URL of the HuggingFace Dataset Viewer `rows` endpoint; every page request made by
/// `fetch_page` is issued against it.
const HF_DATASET_VIEWER_BASE: &str = "https://datasets-server.huggingface.co/rows";
7+
/// Request timeout, in seconds, applied to the HTTP client built by `HuggingFaceClient::new`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
8+
/// Largest number of rows `fetch_dataset` asks for in a single page request.
/// NOTE(review): presumably mirrors the Dataset Viewer's own per-request cap — confirm.
const MAX_PAGE_SIZE: usize = 100;
9+
10+
/// Client that reads dataset rows over HTTP from the HuggingFace Dataset Viewer.
pub struct HuggingFaceClient {
11+
/// HTTP client built once in `new` (with a request timeout) and reused for every request.
client: reqwest::Client,
12+
}
13+
14+
impl HuggingFaceClient {
15+
pub fn new() -> Result<Self> {
16+
let client = reqwest::Client::builder()
17+
.timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS))
18+
.build()
19+
.context("Failed to build HTTP client for HuggingFace")?;
20+
Ok(Self { client })
21+
}
22+
23+
pub async fn fetch_dataset(&self, config: &DatasetConfig) -> Result<HuggingFaceDataset> {
24+
info!(
25+
"Fetching HuggingFace dataset: {} (split={}, offset={}, limit={})",
26+
config.dataset_id, config.split, config.offset, config.limit
27+
);
28+
29+
let mut all_entries = Vec::new();
30+
let mut offset = config.offset;
31+
let mut total_count = 0;
32+
let remaining = config.limit;
33+
34+
while all_entries.len() < remaining {
35+
let page_size = MAX_PAGE_SIZE.min(remaining - all_entries.len());
36+
37+
let response = self
38+
.fetch_page(&config.dataset_id, &config.split, offset, page_size)
39+
.await?;
40+
41+
if let Some(total) = response.num_rows_total {
42+
total_count = total;
43+
}
44+
45+
let row_count = response.rows.len();
46+
if row_count == 0 {
47+
break;
48+
}
49+
50+
for wrapper in response.rows {
51+
all_entries.push(wrapper.row);
52+
}
53+
54+
offset += row_count;
55+
56+
if row_count < page_size {
57+
break;
58+
}
59+
}
60+
61+
info!(
62+
"Fetched {} entries from {} (total available: {})",
63+
all_entries.len(),
64+
config.dataset_id,
65+
total_count
66+
);
67+
68+
Ok(HuggingFaceDataset {
69+
dataset_id: config.dataset_id.clone(),
70+
split: config.split.clone(),
71+
entries: all_entries,
72+
total_count,
73+
})
74+
}
75+
76+
pub async fn fetch_entry(
77+
&self,
78+
dataset_id: &str,
79+
split: &str,
80+
index: usize,
81+
) -> Result<DatasetEntry> {
82+
debug!(
83+
"Fetching single entry from {} at index {}",
84+
dataset_id, index
85+
);
86+
87+
let response = self.fetch_page(dataset_id, split, index, 1).await?;
88+
89+
response
90+
.rows
91+
.into_iter()
92+
.next()
93+
.map(|w| w.row)
94+
.context(format!(
95+
"No entry found at index {} in dataset {}",
96+
index, dataset_id
97+
))
98+
}
99+
100+
async fn fetch_page(
101+
&self,
102+
dataset_id: &str,
103+
split: &str,
104+
offset: usize,
105+
length: usize,
106+
) -> Result<HfRowsResponse> {
107+
let url = format!(
108+
"{}?dataset={}&config=default&split={}&offset={}&length={}",
109+
HF_DATASET_VIEWER_BASE, dataset_id, split, offset, length
110+
);
111+
112+
debug!("Requesting HuggingFace API: {}", url);
113+
114+
let resp = self
115+
.client
116+
.get(&url)
117+
.send()
118+
.await
119+
.context("Failed to send request to HuggingFace dataset viewer")?;
120+
121+
let status = resp.status();
122+
if !status.is_success() {
123+
let body = resp.text().await.unwrap_or_default();
124+
anyhow::bail!(
125+
"HuggingFace API returned HTTP {}: {}",
126+
status.as_u16(),
127+
&body[..body.len().min(500)]
128+
);
129+
}
130+
131+
let response: HfRowsResponse = resp
132+
.json()
133+
.await
134+
.context("Failed to parse HuggingFace API response")?;
135+
136+
Ok(response)
137+
}
138+
}
139+
140+
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` as an `HfRowsResponse`, panicking with "should parse" on failure.
    fn parse_rows(raw: &str) -> HfRowsResponse {
        serde_json::from_str(raw).expect("should parse")
    }

    /// Building the client with the default settings must succeed.
    #[test]
    fn test_client_creation() {
        assert!(HuggingFaceClient::new().is_ok());
    }

    /// A response with one wrapped row and a total count deserializes field by field.
    #[test]
    fn test_hf_rows_response_deserialize() {
        let json = r#"{
            "rows": [
                {
                    "row": {
                        "repo": "psf/requests",
                        "instance_id": "psf__requests-1234",
                        "base_commit": "abc123",
                        "patch": "diff --git a/f.py b/f.py",
                        "test_patch": "diff --git a/t.py b/t.py",
                        "problem_statement": "Fix bug"
                    }
                }
            ],
            "num_rows_total": 2294
        }"#;

        let HfRowsResponse {
            rows,
            num_rows_total,
        } = parse_rows(json);

        assert_eq!(rows.len(), 1);
        assert_eq!(num_rows_total, Some(2294));
        assert_eq!(rows[0].row.repo, "psf/requests");
    }

    /// `num_rows_total` is optional: a response without it still parses.
    #[test]
    fn test_hf_rows_response_empty() {
        let parsed = parse_rows(r#"{"rows": []}"#);

        assert!(parsed.rows.is_empty());
        assert!(parsed.num_rows_total.is_none());
    }
}

src/swe_forge/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[allow(dead_code)]
2+
pub mod client;
3+
#[allow(dead_code)]
4+
pub mod types;

src/swe_forge/types.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
/// One row of the dataset, as carried in the `row` object of a Dataset Viewer response.
///
/// The first six fields are mandatory: deserialization fails if any of them is missing.
/// The remaining fields are optional and deserialize to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
4+
pub struct DatasetEntry {
5+
// Test fixtures in this file use values such as "psf/requests".
pub repo: String,
6+
pub instance_id: String,
7+
pub base_commit: String,
8+
pub patch: String,
9+
pub test_patch: String,
10+
pub problem_statement: String,
11+
// Optional fields start here.
#[serde(default)]
12+
pub hints_text: Option<String>,
13+
#[serde(default)]
14+
pub created_at: Option<String>,
15+
#[serde(default)]
16+
pub version: Option<String>,
17+
// NOTE(review): the test fixture stores a JSON-encoded array of test ids in this
// string (and in `pass_to_pass`) — confirm against the real dataset.
#[serde(default)]
18+
pub fail_to_pass: Option<String>,
19+
#[serde(default)]
20+
pub pass_to_pass: Option<String>,
21+
#[serde(default)]
22+
pub environment_setup_commit: Option<String>,
23+
}
24+
25+
/// Result of a dataset fetch: the rows retrieved plus the identifiers they were requested with.
#[derive(Debug, Clone, Serialize, Deserialize)]
26+
pub struct HuggingFaceDataset {
27+
pub dataset_id: String,
28+
pub split: String,
29+
/// Rows actually retrieved.
pub entries: Vec<DatasetEntry>,
30+
/// Row total reported by the API (`num_rows_total`), not `entries.len()`; the client
/// leaves it at 0 when the API reports no total.
pub total_count: usize,
31+
}
32+
33+
/// Selects which rows of which dataset to fetch.
///
/// Only `dataset_id` is required when deserializing; the other fields fall back to the
/// defaults noted below.
#[derive(Debug, Clone, Serialize, Deserialize)]
34+
pub struct DatasetConfig {
35+
pub dataset_id: String,
36+
// Falls back to `default_split()` when the key is absent.
#[serde(default = "default_split")]
37+
pub split: String,
38+
// Maximum number of rows to fetch; falls back to `default_limit()` when absent.
#[serde(default = "default_limit")]
39+
pub limit: usize,
40+
// Index of the first row to fetch; falls back to 0 (`usize::default()`) when absent.
#[serde(default)]
41+
pub offset: usize,
42+
}
43+
44+
/// Serde fallback for `DatasetConfig::split`: the `"test"` split.
fn default_split() -> String {
    String::from("test")
}
47+
48+
/// Serde fallback for `DatasetConfig::limit`: 100 rows.
fn default_limit() -> usize {
    const FALLBACK_LIMIT: usize = 100;
    FALLBACK_LIMIT
}
51+
52+
impl Default for DatasetConfig {
53+
fn default() -> Self {
54+
Self {
55+
dataset_id: "CortexLM/swe-forge".to_string(),
56+
split: default_split(),
57+
limit: default_limit(),
58+
offset: 0,
59+
}
60+
}
61+
}
62+
63+
/// Subset of the Dataset Viewer `rows` response body that the client reads.
#[derive(Debug, Deserialize)]
64+
pub(crate) struct HfRowsResponse {
65+
/// Rows of the requested page, each wrapped in an object with a `row` key.
pub rows: Vec<HfRowWrapper>,
66+
/// Row total reported by the API; `None` when the key is absent from the response.
#[serde(default)]
67+
pub num_rows_total: Option<usize>,
68+
}
69+
70+
/// One element of the `rows` array: the dataset entry sits under the `row` key.
#[derive(Debug, Deserialize)]
71+
pub(crate) struct HfRowWrapper {
72+
pub row: DatasetEntry,
73+
}
74+
75+
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` as a `DatasetEntry`, panicking with "should deserialize" on failure.
    fn parse_entry(raw: &str) -> DatasetEntry {
        serde_json::from_str(raw).expect("should deserialize")
    }

    /// `Default` points at the swe-forge dataset, the "test" split, 100 rows, offset 0.
    #[test]
    fn test_dataset_config_default() {
        let DatasetConfig {
            dataset_id,
            split,
            limit,
            offset,
        } = DatasetConfig::default();

        assert_eq!(dataset_id, "CortexLM/swe-forge");
        assert_eq!(split, "test");
        assert_eq!(limit, 100);
        assert_eq!(offset, 0);
    }

    /// Only the six mandatory keys are needed; optional fields come back as `None`.
    #[test]
    fn test_dataset_entry_deserialize() {
        let entry = parse_entry(
            r#"{
            "repo": "psf/requests",
            "instance_id": "psf__requests-1234",
            "base_commit": "abc123def456",
            "patch": "diff --git a/file.py b/file.py",
            "test_patch": "diff --git a/test_file.py b/test_file.py",
            "problem_statement": "Fix the bug in requests library"
        }"#,
        );

        assert_eq!(entry.repo, "psf/requests");
        assert_eq!(entry.instance_id, "psf__requests-1234");
        assert!(entry.hints_text.is_none());
        assert!(entry.version.is_none());
    }

    /// Optional keys, when present, are carried through to the entry.
    #[test]
    fn test_dataset_entry_deserialize_with_optional_fields() {
        let entry = parse_entry(
            r#"{
            "repo": "psf/requests",
            "instance_id": "psf__requests-1234",
            "base_commit": "abc123def456",
            "patch": "diff --git a/file.py b/file.py",
            "test_patch": "diff --git a/test_file.py b/test_file.py",
            "problem_statement": "Fix the bug",
            "hints_text": "Check the encoding",
            "version": "2.31.0",
            "fail_to_pass": "[\"test_requests.py::test_encoding\"]",
            "pass_to_pass": "[\"test_requests.py::test_basic\"]"
        }"#,
        );

        assert_eq!(entry.hints_text.as_deref(), Some("Check the encoding"));
        assert_eq!(entry.version.as_deref(), Some("2.31.0"));
        assert!(entry.fail_to_pass.is_some());
    }

    /// An empty dataset survives a serialize/deserialize round trip.
    #[test]
    fn test_huggingface_dataset_serialize_roundtrip() {
        let original = HuggingFaceDataset {
            dataset_id: "CortexLM/swe-forge".to_string(),
            split: "test".to_string(),
            entries: Vec::new(),
            total_count: 0,
        };

        let encoded = serde_json::to_string(&original).expect("should serialize");
        let decoded: HuggingFaceDataset =
            serde_json::from_str(&encoded).expect("should deserialize");

        assert_eq!(decoded.dataset_id, "CortexLM/swe-forge");
        assert_eq!(decoded.total_count, 0);
    }
}

0 commit comments

Comments
 (0)