-
Notifications
You must be signed in to change notification settings - Fork 3
ONNX MODULE UPDATE #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fa806c1
287a7b8
ae4bf61
494b914
d27ee78
afea0f4
f39c914
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| { | ||
| "githubPullRequests.ignoredPullRequestBranches": [ | ||
| "main" | ||
| ] | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| { | ||
| "per_device_eval_batch_size": 8, | ||
| "sequence_length": 64, | ||
| "model_type": "forecasting", | ||
| "evaluation_metrics": ["mae", "rmse", "mape", "smape"], | ||
| "output_dir": "/tmp/onnx_validation" | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| from unittest.mock import MagicMock, patch | ||
| from pathlib import Path | ||
| import sys | ||
|
|
||
|
|
||
| project_root = Path(__file__).parent | ||
| sys.path.insert(0, str(project_root)) | ||
|
|
||
| from validator.validation_runner import ValidationRunner | ||
| from validator.modules.onnx import ( | ||
| ONNXValidationModule, | ||
| ONNXConfig, | ||
| ONNXInputData, | ||
| ONNXMetrics, | ||
| ) | ||
|
|
||
|
|
||
| def load_and_preprocess_demo_data(): | ||
| """Load demo CSV data and apply feature engineering""" | ||
| from pathlib import Path | ||
|
|
||
| demo_path = ( | ||
| Path(__file__).parent | ||
| / "validator" | ||
| / "modules" | ||
| / "onnx" | ||
| / "demo_data" | ||
| / "test.csv" | ||
| ) | ||
|
|
||
| if not demo_path.exists(): | ||
| print(f"Demo data not found at: {demo_path}") | ||
| return None | ||
|
|
||
| # Read the demo CSV file | ||
| df = pd.read_csv(demo_path) | ||
| print(f"Loaded demo data with shape: {df.shape}") | ||
| print(f"Original columns: {df.columns.tolist()}") | ||
|
|
||
| df["Date"] = pd.to_datetime(df["Date"]) | ||
| df["year"] = df["Date"].dt.year | ||
| df["month"] = df["Date"].dt.month | ||
| df["day"] = df["Date"].dt.day | ||
| df["dayofweek"] = df["Date"].dt.dayofweek | ||
| df["dayofyear"] = df["Date"].dt.dayofyear | ||
|
|
||
| df = df.sort_values(["store", "product", "Date"]) | ||
|
|
||
| for lag in [1, 2, 3, 7]: | ||
| df[f"number_sold_lag_{lag}"] = df.groupby(["store", "product"])[ | ||
| "number_sold" | ||
| ].shift(lag) | ||
|
|
||
| for window in [3, 7, 14]: | ||
| df[f"number_sold_rolling_{window}"] = ( | ||
| df.groupby(["store", "product"])["number_sold"] | ||
| .rolling(window=window) | ||
| .mean() | ||
| .values | ||
| ) | ||
|
|
||
| df = df.dropna() | ||
|
|
||
| # Select only numerical feature columns | ||
| feature_columns = [ | ||
| "store", | ||
| "product", | ||
| "year", | ||
| "month", | ||
| "day", | ||
| "dayofweek", | ||
| "dayofyear", | ||
| "number_sold_lag_1", | ||
| "number_sold_lag_2", | ||
| "number_sold_lag_3", | ||
| "number_sold_lag_7", | ||
| "number_sold_rolling_3", | ||
| "number_sold_rolling_7", | ||
| "number_sold_rolling_14", | ||
| "number_sold", # Keep target column | ||
| ] | ||
|
|
||
| df_final = df[feature_columns] | ||
|
|
||
| print(f"After feature engineering: {df_final.shape}") | ||
| print(f"Final columns: {df_final.columns.tolist()}") | ||
|
|
||
| processed_csv = df_final.to_csv(index=False) | ||
| return processed_csv | ||
|
|
||
|
|
||
| @patch("validator.validation_runner.FedLedger") | ||
| @patch("requests.get") | ||
| def test_onnx_validation_works(mock_requests, mock_fedledger): | ||
| """Test that ONNX validation can complete successfully using real HuggingFace model""" | ||
|
|
||
| test_csv = load_and_preprocess_demo_data() | ||
| if test_csv is None: | ||
| print("Failed to load demo data") | ||
| return False | ||
|
|
||
| # Mock API | ||
| mock_api = MagicMock() | ||
| mock_api.list_tasks.return_value = [ | ||
| {"id": 1, "task_type": "onnx", "title": "Test", "data": {}} | ||
| ] | ||
| mock_api.mark_assignment_as_failed = MagicMock() | ||
| mock_fedledger.return_value = mock_api | ||
|
|
||
| # Mock HTTP requests for CSV data (use real HuggingFace download for model) | ||
| def mock_get_side_effect(url): | ||
| response = MagicMock() | ||
| response.raise_for_status.return_value = None | ||
| response.text = test_csv # CSV contains both features and target | ||
| return response | ||
|
|
||
| mock_requests.side_effect = mock_get_side_effect | ||
|
|
||
| runner = ValidationRunner( | ||
|
Comment on lines
+96
to
+121
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion: Make the test fully offline and deterministic by mocking model download/inference.
The test currently performs a real HuggingFace download and actual ONNX inference, which is brittle in CI/offline environments. Mock `hf_hub_download` to return a local test artifact and mock `onnxruntime.InferenceSession.run` to return plausible outputs. Example (concise):
from unittest.mock import MagicMock, patch
@patch("validator.validation_runner.FedLedger")
@patch("requests.get")
@patch("validator.modules.onnx.hf_hub_download", return_value="dummy.onnx")
@patch("validator.modules.onnx.ort.InferenceSession")
def test_onnx_validation_works(mock_sess, mock_hf, mock_requests, mock_fedledger):
# Configure InferenceSession mock to output a vector matching input rows
sess = MagicMock()
sess.get_inputs.return_value = [MagicMock(name="input_0")]
# Suppose test_features has shape (N, F); return (N,1) predictions
def run_side_effect(_, feed):
x = next(iter(feed.values()))
return [x[:, :1]]
sess.run.side_effect = run_side_effect
mock_sess.return_value = sess
...
Also applies to: 129-145 |
||
| module="onnx", | ||
| task_ids=[1], | ||
| flock_api_key="test_key", | ||
| hf_token="test_token", | ||
| test_mode=True, | ||
| ) | ||
|
|
||
| input_data = ONNXInputData( | ||
| model_repo_id="Fan9494/test_onnx", | ||
| model_filename="model.onnx", | ||
| revision="main", | ||
| test_data_url="https://example.com/test.csv", | ||
| target_column="number_sold", | ||
| task_type="forecasting", | ||
| task_id=1, | ||
| required_metrics=[ | ||
| "mae", | ||
| "rmse", | ||
| "mape", | ||
| "smape", | ||
| "r2_score", | ||
| "directional_accuracy", | ||
| ], | ||
| ) | ||
|
|
||
| # Perform validation | ||
| print("Running ONNX validation...") | ||
| metrics = runner.perform_validation("assignment_123", 1, input_data) | ||
|
|
||
| print(f"Validation result: {metrics}") | ||
|
|
||
| if metrics is None: | ||
| print("Validation returned None - something went wrong") | ||
| print("Checking mocks:") | ||
| print(f" - HTTP requests called: {mock_requests.call_count}") | ||
| return False | ||
| else: | ||
| print("Validation completed successfully!") | ||
| print(f" - Type: {type(metrics)}") | ||
| if hasattr(metrics, "mae"): | ||
| print(f" - MAE: {metrics.mae}") | ||
| if hasattr(metrics, "rmse"): | ||
| print(f" - RMSE: {metrics.rmse}") | ||
| if hasattr(metrics, "mape"): | ||
| print(f" - MAPE: {metrics.mape}") | ||
| if hasattr(metrics, "smape"): | ||
| print(f" - SMAPE: {metrics.smape}") | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| print("Testing ONNX Module") | ||
| print("=" * 50) | ||
|
|
||
| # Run tests | ||
| print() | ||
| test_passed = test_onnx_validation_works() | ||
|
|
||
| if test_passed: | ||
| print("\nAll ONNX tests passed!") | ||
| sys.exit(0) | ||
| else: | ||
| print("\nSome tests failed") | ||
| sys.exit(1) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,15 +1,19 @@ | ||||||
| from abc import ABC, abstractmethod | ||||||
| from pydantic import BaseModel | ||||||
|
|
||||||
|
|
||||||
| class BaseConfig(BaseModel, frozen=True): | ||||||
| pass | ||||||
|
|
||||||
|
|
||||||
| class BaseInputData(BaseModel, frozen=True): | ||||||
| pass | ||||||
|
|
||||||
|
|
||||||
| class BaseMetrics(BaseModel, frozen=True): | ||||||
| pass | ||||||
|
|
||||||
|
|
||||||
| class BaseValidationModule(ABC): | ||||||
| config_schema: type[BaseConfig] | ||||||
| input_data_schema: type[BaseInputData] | ||||||
|
|
@@ -18,17 +22,13 @@ class BaseValidationModule(ABC): | |||||
|
|
||||||
| @abstractmethod | ||||||
| def __init__(self, config: BaseConfig, **kwargs): | ||||||
| """ | ||||||
| """. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Fix incomplete docstring: the docstring's opening line is just `""".` — remove the stray trailing period:
- """.
+ """
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
| Perform any global, one-time setup needed for this module. | ||||||
| """ | ||||||
| pass | ||||||
|
|
||||||
| @abstractmethod | ||||||
| def validate( | ||||||
| self, | ||||||
| data: BaseInputData, | ||||||
| **kwargs | ||||||
| ) -> BaseMetrics: | ||||||
| def validate(self, data: BaseInputData, **kwargs) -> BaseMetrics: | ||||||
| """ | ||||||
| Download/prep the repo/revision, run validation, and return metrics parsed into a Pydantic model. | ||||||
| """ | ||||||
|
|
@@ -39,4 +39,4 @@ def cleanup(self): | |||||
| """ | ||||||
| Clean up any resources (e.g., temp files, models in memory). | ||||||
| """ | ||||||
| pass | ||||||
| pass | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Consider adding this file to .gitignore
IDE-specific configuration files like `.vscode/settings.json` are typically excluded from version control as they contain user-specific preferences. Consider adding `.vscode/` to `.gitignore` to avoid committing personal workspace settings.
🏁 Script executed:
Length of output: 462
Action: Add .vscode/ to .gitignore or remove committed workspace settings
Verified: .vscode/settings.json is tracked in the repo and .gitignore does not contain a .vscode entry. Workspace-specific VS Code settings are typically not committed.
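Files to address: `.vscode/settings.json` (currently tracked) and `.gitignore` (missing a `.vscode/` entry).
Suggested change (add to .gitignore): `.vscode/`
If you want to remove the tracked file after updating .gitignore: `git rm --cached .vscode/settings.json`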
🤖 Prompt for AI Agents