Repository: alisulmanpro / diabetes-prediction
Dataset: Pima Indians Diabetes Database from the UCI Machine Learning Repository
A compact, production-oriented project that trains an XGBoost classifier to predict diabetes from the Pima dataset and exposes a FastAPI inference endpoint. The repo includes data cleaning, EDA, baseline and XGBoost training, threshold tuning for recall, SHAP explanations, and a minimal FastAPI service for serving predictions.
- Data cleaning for biologically invalid zeros (Glucose, BloodPressure, SkinThickness, Insulin, BMI)
- Baseline model (Logistic Regression) + decision-threshold tuning to prioritize recall
- Stronger model: XGBoost with class-imbalance handling (`scale_pos_weight`)
- Model metrics (on test split): ROC-AUC ≈ 0.817, Recall ≈ 0.80 (using the tuned threshold)
- Explainability: SHAP global and per-prediction explanations
- Serving: FastAPI endpoint with Pydantic validation and Swagger docs
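The imbalance handling and recall-oriented threshold tuning listed above can be sketched as follows. This is an illustrative sketch on synthetic data, not the repo's actual training code: logistic regression stands in for XGBoost to keep the example dependency-light, and the `scale_pos_weight` line only shows how the ratio is computed.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the Pima data (the real pipeline reads data/clean/diabetes_clean.csv)
X, y = make_classification(n_samples=768, n_features=8, weights=[0.65, 0.35], random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

# scale_pos_weight = (# negatives) / (# positives): the ratio XGBoost uses to upweight the positive class
scale_pos_weight = (y_tr == 0).sum() / (y_tr == 1).sum()

# Logistic regression stands in for the XGBoost classifier in this sketch
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
probs = clf.predict_proba(X_te)[:, 1]

# Pick the highest decision threshold that still achieves the target recall
# (lower thresholds catch more positives, i.e. fewer false negatives)
precision, recall, thresholds = precision_recall_curve(y_te, probs)
target_recall = 0.80
ok = recall[:-1] >= target_recall  # recall[i] corresponds to thresholds[i]
tuned_threshold = thresholds[ok][-1] if ok.any() else 0.5
preds = (probs >= tuned_threshold).astype(int)
```

The same idea applies unchanged to XGBoost probabilities; only the model object differs.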
```
diabetes-prediction/
├─ data/
│  ├─ raw/diabetes.csv
│  └─ clean/diabetes_clean.csv
├─ models/
│  ├─ xgb_model.pkl
│  └─ baseline_model.pkl
├─ src/
│  ├─ load_data.py
│  ├─ clean_data.py
│  ├─ eda.py
│  ├─ train_baseline.py
│  ├─ train_xgb.py
│  ├─ explain_model.py
│  └─ api.py
├─ main.py
├─ .gitignore
├─ requirements.txt
└─ README.md
```
- Clone the repo:

  ```bash
  git clone https://github.com/alisulmanpro/diabetes-prediction.git
  cd diabetes-prediction
  ```

- Create and activate a virtual environment:

  ```bash
  python -m venv .venv
  # macOS / Linux
  source .venv/bin/activate
  # Windows (PowerShell)
  .venv\Scripts\Activate.ps1
  ```

- Install dependencies:

  ```bash
  pip install --upgrade pip
  pip install -r requirements.txt
  ```
- Add the dataset: place `diabetes.csv` (the Pima Indians Diabetes CSV) in `data/raw/`.
- Clean the data:

  ```bash
  python src/clean_data.py
  # produces data/clean/diabetes_clean.csv
  ```

- Run EDA (optional; opens plots):

  ```bash
  python src/eda.py
  ```

- Train the XGBoost model:

  ```bash
  python src/train_xgb.py
  # saved to models/xgb_model.pkl
  ```

- Run SHAP explainability:

  ```bash
  python src/explain_model.py
  # generates SHAP summary (global) and bar plots
  ```

- Run the API locally:

  ```bash
  uvicorn src.api:app --reload
  # visit http://127.0.0.1:8000/docs for the Swagger UI
  ```

Endpoint: `POST /predict`
Input (JSON):
```json
{
  "Pregnancies": 2,
  "Glucose": 150,
  "BloodPressure": 80,
  "SkinThickness": 30,
  "Insulin": 120,
  "BMI": 33.6,
  "DiabetesPedigreeFunction": 0.627,
  "Age": 50
}
```

Sample curl:

```bash
curl -X POST "http://127.0.0.1:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"Pregnancies":2,"Glucose":150,"BloodPressure":80,"SkinThickness":30,"Insulin":120,"BMI":33.6,"DiabetesPedigreeFunction":0.627,"Age":50}'
```

Response:
```json
{
  "diabetes_probability": 0.7801,
  "prediction": 1
}
```

Note: the API uses the XGBoost model and returns the probability plus a binary prediction made with the tuned threshold (`0.40` by default). Adjust the threshold in `src/api.py` if you prefer a different operating point.
- Missing / invalid zeros: several features in the Pima dataset use `0` to indicate a missing value. The cleaning script replaces those zeros with `NaN` and imputes the column median for `Glucose`, `BloodPressure`, `SkinThickness`, `Insulin`, and `BMI`.
- Scaling: XGBoost was trained on raw features (no scaling). If you switch to a model that requires scaling (e.g., logistic regression as the deployed model), export and load the scaler alongside the model.
- Threshold tuning: in medical use cases we prioritized recall (to reduce false negatives), which is why inference uses a threshold lower than 0.5.
- Explainability: `src/explain_model.py` uses SHAP's `TreeExplainer` to produce global and local explanations for the XGBoost model. Include these plots in your README or demo to increase trust.
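The zero-handling described above can be sketched with pandas. A minimal illustration, assuming column names matching the Pima CSV; the repo's `src/clean_data.py` presumably does the equivalent on the full dataset:

```python
import numpy as np
import pandas as pd

# Columns where 0 is biologically impossible and therefore means "missing"
ZERO_AS_MISSING = ["Glucose", "BloodPressure", "SkinThickness", "Insulin", "BMI"]

def clean(df: pd.DataFrame) -> pd.DataFrame:
    """Replace invalid zeros with NaN, then impute each column's median."""
    out = df.copy()
    out[ZERO_AS_MISSING] = out[ZERO_AS_MISSING].replace(0, np.nan)
    out[ZERO_AS_MISSING] = out[ZERO_AS_MISSING].fillna(out[ZERO_AS_MISSING].median())
    return out

# Tiny worked example: the 0 glucose becomes the median of the remaining values (110)
toy = pd.DataFrame({"Glucose": [0, 100, 120], "BloodPressure": [70, 0, 80],
                    "SkinThickness": [20, 25, 0], "Insulin": [0, 90, 110],
                    "BMI": [0.0, 30.0, 34.0]})
cleaned = clean(toy)
```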
- The project uses a single stratified train/test split (random seed = 42). For robust evaluation, add k-fold cross-validation and report mean ± std for metrics.
- Current test metrics (from a held-out split): ROC-AUC ≈ 0.817, Recall ≈ 0.80 (XGBoost, tuned threshold).
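The cross-validation suggestion above might look like the sketch below: synthetic data stands in for the cleaned dataset, and logistic regression stands in for the XGBoost model, so the numbers are illustrative only.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in for data/clean/diabetes_clean.csv
X, y = make_classification(n_samples=768, n_features=8, weights=[0.65, 0.35], random_state=42)

# Stratified folds preserve the class ratio in every split
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=cv, scoring="roc_auc")

# Report mean +/- std rather than a single split's number
print(f"ROC-AUC: {scores.mean():.3f} +/- {scores.std():.3f}")
```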
- Add GridSearchCV / Optuna hyperparameter tuning.
- Use cross-validation and report confidence intervals.
- Containerize with Docker and add CI/CD for model retraining.
- Build a small Streamlit demo for interactive per-patient explanations (SHAP force plots).
- Add input validation rules and monitoring (data drift alerts, model performance tracking).
- Fork the repo
- Create a feature branch (`git checkout -b feat/your-feature`)
- Commit and push your changes (`git push origin feat/your-feature`)
- Open a pull request with a clear description and tests where relevant
Recommended: MIT License; change as you see fit.