Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
Dataset/*
!Dataset/.gitkeep
.venv/*
ControlDataset/*
Models/*
NotInUseModels/*
Result/*
Scalers/*
TestDataset/*
inference_results.csv
FeatureWeights/*
CombinedFeatureWeights.csv
Bladder/Dataset/*
Breast/Dataset/*
Brain/Dataset/*
Liver/Dataset/*
Test_data/*
198 changes: 198 additions & 0 deletions Baseline/baseline_aggregate_and_infer.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import joblib\n",
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_models(models_folder=\"Models\"):\n",
"    \"\"\"Load all saved per-cancer-type models from *models_folder*.\n",
"\n",
"    Files are expected to be named '<model_name>_<cancer_type>.joblib'.\n",
"    Returns a dict keyed by (model_name, cancer_type); note the cancer_type\n",
"    key still carries the '.joblib' extension, which run_inference strips.\n",
"    \"\"\"\n",
"    models = {}\n",
"    for file_name in os.listdir(models_folder):\n",
"        if file_name.endswith(\".joblib\"):\n",
"            # Split on the first underscore only, so model names that\n",
"            # themselves contain underscores do not break the unpacking.\n",
"            model_name, cancer_type = file_name.split(\"_\", 1)\n",
"            model_path = os.path.join(models_folder, file_name)\n",
"            models[(model_name, cancer_type)] = joblib.load(model_path)\n",
"            print(f\"Loaded model: {model_name} for cancer type: {cancer_type}\")\n",
"    return models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def preprocess_data(new_data_path, scaler=None):\n",
"    \"\"\"Read a CSV of features and return (scaled_features, raw_dataframe).\n",
"\n",
"    When *scaler* is None a fresh StandardScaler is fitted to the data;\n",
"    otherwise the supplied (already fitted) scaler is applied.\n",
"    \"\"\"\n",
"    data = pd.read_csv(new_data_path)\n",
"    # Label columns are not features; ignore them if they are absent.\n",
"    features = data.drop(['cancer_type', 'type'], axis=1, errors='ignore')\n",
"\n",
"    if scaler is None:\n",
"        X_scaled = StandardScaler().fit_transform(features)\n",
"    else:\n",
"        X_scaled = scaler.transform(features)\n",
"\n",
"    return X_scaled, data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run_inference(models, test_data_path, scalers_folder=\"Scalers\"):\n",
"    \"\"\"Score each row of the test CSV with every model and pick a final label.\n",
"\n",
"    Each (model, cancer type) pair contributes its positive-class probability;\n",
"    the cancer type with the highest confidence above 0.5 wins, otherwise the\n",
"    row is labelled 'normal'. Returns a list of per-row result dicts.\n",
"    \"\"\"\n",
"    test_df = pd.read_csv(test_data_path)\n",
"    test_features = test_df.drop(['cancer_type', 'type'], axis=1)\n",
"\n",
"    # Load every scaler once up front instead of re-reading it from disk\n",
"    # for every row x model combination.\n",
"    suffix = '_scaler.joblib'\n",
"    scalers = {\n",
"        f[:-len(suffix)]: joblib.load(os.path.join(scalers_folder, f))\n",
"        for f in os.listdir(scalers_folder) if f.endswith(suffix)\n",
"    }\n",
"\n",
"    results = []\n",
"\n",
"    for index, row in test_df.iterrows():\n",
"        row_predictions = {}  # model key -> prediction info for this row\n",
"\n",
"        for (model_name, cancer_type), model in models.items():\n",
"            # Keys from load_models keep the '.joblib' extension; strip it.\n",
"            cancer_type = cancer_type.split(\".\")[0]\n",
"            scaler = scalers.get(cancer_type)\n",
"            if scaler is None:\n",
"                print(f\"couldn't find {cancer_type}_scaler.joblib\")\n",
"                continue\n",
"\n",
"            # iloc[[index]] keeps a one-row DataFrame (with column names),\n",
"            # avoiding sklearn's feature-name warning for a bare value list.\n",
"            test_features_scaled = scaler.transform(test_features.iloc[[index]])\n",
"\n",
"            probabilities = model.predict_proba(test_features_scaled)\n",
"            confidence = probabilities[0][1]\n",
"\n",
"            row_predictions[f\"{model_name}-{cancer_type}\"] = {\n",
"                'cancer_type': cancer_type,\n",
"                'predicted_class': model.predict(test_features_scaled)[0],\n",
"                'confidence': confidence,\n",
"                'probabilities': probabilities[0].tolist()\n",
"            }\n",
"\n",
"        # Final label: highest confidence strictly above the 0.5 threshold.\n",
"        final_prediction = \"normal\"\n",
"        max_confidence = -1\n",
"\n",
"        for model_key, prediction_info in row_predictions.items():\n",
"            if prediction_info['confidence'] > max_confidence and prediction_info['confidence'] > 0.5:\n",
"                max_confidence = prediction_info['confidence']\n",
"                final_prediction = prediction_info['cancer_type']\n",
"\n",
"        results.append({\n",
"            'index': index,\n",
"            'cancer_type': final_prediction,\n",
"            # '!=' rather than 'is not': identity comparison against a string\n",
"            # literal is implementation-dependent and a SyntaxWarning in CPython.\n",
"            'confidence': max_confidence if final_prediction != \"normal\" else -1.0,\n",
"            'predictions': row_predictions\n",
"        })\n",
"\n",
"    return results\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_inference_results(results, output_file=\"inference_results.csv\"):\n",
"    \"\"\"Persist the per-row inference results as a CSV file.\"\"\"\n",
"    pd.DataFrame(results).to_csv(output_file, index=False)\n",
"    print(f\"Inference results saved to {output_file}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# End-to-end driver: load the models, score the test set, persist results.\n",
"models_folder = \"Models\"\n",
"new_data_path = \"TestDataset/test_data.csv\"\n",
"\n",
"models = load_models(models_folder=models_folder)\n",
"inference_results = run_inference(models, new_data_path)\n",
"save_inference_results(inference_results, output_file=\"inference_results.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 88.24%\n"
]
}
],
"source": [
"# Compare predicted labels against ground truth, row by row.\n",
"# Assumes inference preserved the test set's row order.\n",
"test_df = pd.read_csv(\"TestDataset/test_data.csv\")\n",
"predictions_df = pd.read_csv(\"inference_results.csv\")\n",
"\n",
"matches = predictions_df[\"cancer_type\"] == test_df[\"cancer_type\"]\n",
"accuracy = matches.mean()\n",
"print(f\"Accuracy: {accuracy:.2%}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading