This project is a machine learning tool for training and prediction based on KAN (Kolmogorov-Arnold Network), designed to identify optimal metabolites from metabolite expression data. It offers Chinese and English language modes and runs on both Windows and Linux. Base learners (XGBoost, Random Forest, SVM, and Gradient Boosting) are combined through a KAN network that fuses their outputs for binary classification.
Main features include:
- Model Training: Hyperparameter tuning with Optuna, training of base learners, and fusion via KAN.
- Data Prediction: Loading pre-trained models for classification on new datasets, outputting probabilities and prediction results.
- Multi-language Support: Chinese (`KANMB_CN.py`) and English (`KANMB_EN.py`) versions provide language-specific log output.
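
How the base-learner outputs reach the KAN fusion step is not specified here. A common stacking scheme, assumed in this sketch rather than confirmed for KANMB, trains each base learner on k-1 folds and records its predicted probability on the held-out fold, so every sample gets one out-of-fold probability per learner; those columns then become the input features of the fusion model. The pure-Python sketch below shows only that bookkeeping. `kfold_indices`, `out_of_fold_features`, and the toy `mean_learner` are hypothetical names, and the toy learner stands in for the real XGBoost, Random Forest, SVM, and Gradient Boosting models.

```python
from typing import Callable, Iterator, List, Sequence, Tuple

# Hypothetical learner interface: fit on (rows, labels) and return a function
# that maps one feature row to a class-1 probability.
Predictor = Callable[[Sequence[float]], float]
Learner = Callable[[List[Sequence[float]], List[int]], Predictor]


def kfold_indices(n: int, k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, valid_idx) for k contiguous, non-overlapping folds."""
    if not 2 <= k <= n:
        raise ValueError("k must be between 2 and the number of samples")
    start = 0
    for fold in range(k):
        # Spread the remainder over the first n % k folds.
        size = n // k + (1 if fold < n % k else 0)
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        yield train, valid
        start += size


def out_of_fold_features(
    X: List[Sequence[float]], y: List[int], learners: List[Learner], k: int
) -> List[List[float]]:
    """Return one row per sample holding one out-of-fold probability per learner."""
    meta = [[0.0] * len(learners) for _ in X]
    for train_idx, valid_idx in kfold_indices(len(X), k):
        X_tr = [X[i] for i in train_idx]
        y_tr = [y[i] for i in train_idx]
        for j, fit in enumerate(learners):
            predict = fit(X_tr, y_tr)  # the learner never sees the held-out fold
            for i in valid_idx:
                meta[i][j] = predict(X[i])
    return meta


def mean_learner(X: List[Sequence[float]], y: List[int]) -> Predictor:
    """Toy stand-in learner: always predicts the class-1 rate of its training labels."""
    rate = sum(y) / len(y)
    return lambda row: rate
```

With `learners=[mean_learner, mean_learner]` and `k=3`, each row of the result holds two probabilities. Under this assumed scheme, the real base learners would fill those columns and `k` would correspond to `--num_folds`.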
A conda environment is required:

```bash
# Windows
conda env create -f KANMB_Win_environment.yaml
# Linux
conda env create -f KANMB_Linux_environment.yaml
```

Or use manual installation:
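```bash
conda create -n KANMB python=3.10
conda activate KANMB
pip install tqdm matplotlib PyYAML numpy pandas joblib scikit-learn xgboost torch torchvision pykan optuna
```

Training example (shown with `KANMB_CN.py`; `KANMB_EN.py` is the English-logging version):

```bash
python KANMB_CN.py --mode train \
    --num_folds 5 \
    --n_trials 30 \
    --train_file "./Data/TrainandVaild.csv" \
    --model_output_dir "./test"
```

Prediction example:

```bash
python KANMB_CN.py --mode pre \
    --pred_file "./Data/TestData.csv" \
    --output_file "./test/KAN.csv" \
    --model_dir "./test"
```

The trailing `\` line continuations assume a Linux shell; in the Windows Command Prompt, use `^` instead or write each command on a single line.

Command-line parameters:

| Category | Parameter | Type | Required | Default | Description |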
|---|---|---|---|---|---|
| General | `--mode` | str | Yes | train | Running mode: `train` (model training) / `pre` (prediction) |
| | `--num_folds` | int | No | 5 | Number of cross-validation folds |
| | `--n_trials` | int | No | 30 | Number of hyperparameter tuning trials |
| Training Parameters | `--train_file` | str | Yes (train) | - | Path to the training dataset (CSV) |
| | `--model_output_dir` | str | Yes (train) | - | Directory in which to store trained models |
| Prediction Parameters | `--pred_file` | str | Yes (pre) | - | Path to the input CSV file for prediction |
| | `--output_file` | str | Yes (pre) | - | Path for saving prediction results |
| | `--model_dir` | str | Yes (pre) | - | Directory containing the pre-trained models |
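
For readers scripting around the CLI, the table can be expressed as an `argparse` definition like the one below. This is a hypothetical reconstruction, not the code in `KANMB_CN.py` or `KANMB_EN.py`: flag names, types, and defaults come from the table, while the `choices` restriction and the mode-dependent required check (`MODE_REQUIRED`, `parse_cli`) are assumptions based on the required-parameter rules in the notes.

```python
import argparse
from typing import List, Optional

# Flags each mode needs, taken from the parameter table (hypothetical mapping).
MODE_REQUIRED = {
    "train": ["train_file", "model_output_dir"],
    "pre": ["pred_file", "output_file", "model_dir"],
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="KAN-based metabolite classifier (sketch)")
    # General parameters
    parser.add_argument("--mode", type=str, choices=["train", "pre"], default="train",
                        help="train: model training / pre: prediction")
    parser.add_argument("--num_folds", type=int, default=5, help="cross-validation folds")
    parser.add_argument("--n_trials", type=int, default=30, help="hyperparameter tuning trials")
    # Training parameters
    parser.add_argument("--train_file", type=str, help="training dataset (CSV)")
    parser.add_argument("--model_output_dir", type=str, help="directory for trained models")
    # Prediction parameters
    parser.add_argument("--pred_file", type=str, help="input CSV for prediction")
    parser.add_argument("--output_file", type=str, help="where to save predictions")
    parser.add_argument("--model_dir", type=str, help="directory with pre-trained models")
    return parser


def parse_cli(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    missing = [name for name in MODE_REQUIRED[args.mode] if getattr(args, name) is None]
    if missing:
        flags = ", ".join("--" + name for name in missing)
        # parser.error prints the message to stderr and exits with status 2.
        parser.error(f"--mode {args.mode} requires {flags}")
    return args
```

For example, `parse_cli(["--mode", "train", "--train_file", "t.csv"])` exits with status 2 because `--model_output_dir` is missing. Checking after parsing, rather than marking flags `required=True`, lets one parser serve both modes.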

Training data (`--train_file`):
- Format: CSV file with a header row
- Required column: `TARGET` (classification label, 0/1)
- Other columns: numeric features (no preprocessing required)
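
A quick way to confirm that a training file meets these requirements before launching a long Optuna run is a standalone check such as the one below. It is a sketch using only Python's `csv` module; the function name `validate_training_csv` is hypothetical and not part of KANMB, and "numeric" is approximated as "parses with `float()`".

```python
import csv
from typing import List


def validate_training_csv(path: str, target: str = "TARGET") -> List[str]:
    """Return a list of problems found in a training CSV (an empty list means OK)."""
    problems: List[str] = []
    with open(path, newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        header = reader.fieldnames or []
        if target not in header:
            return [f"missing required column {target!r}"]
        features = [col for col in header if col != target]
        for row_no, row in enumerate(reader, start=1):  # row 1 = first data row
            # Short rows yield None for absent fields, which fails both checks.
            if row[target] not in ("0", "1"):
                problems.append(f"row {row_no}: {target} must be 0 or 1, got {row[target]!r}")
            for col in features:
                try:
                    float(row[col])
                except (TypeError, ValueError):
                    problems.append(f"row {row_no}: column {col!r} is not numeric")
    return problems
```

An empty return value means every data row has a 0/1 label and numeric feature values.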

Prediction data (`--pred_file`):
- Format: CSV file with a header row
- Feature columns: must match the training dataset columns (excluding `TARGET`)
- Index column: including a unique identifier column (e.g., `ID`) is recommended
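
Because the feature columns must match, comparing the two headers before prediction can catch a mismatch early. The sketch below is hypothetical: `read_header`, `compare_columns`, and the `id_column="ID"` default (matching the example identifier above) are illustrative names, and the check compares column names only, not their order or values.

```python
import csv
from typing import Dict, List


def read_header(path: str) -> List[str]:
    """Return the header row of a CSV file (empty list for an empty file)."""
    with open(path, newline="", encoding="utf-8") as fh:
        return next(csv.reader(fh), [])


def compare_columns(train_path: str, pred_path: str,
                    target: str = "TARGET", id_column: str = "ID") -> Dict[str, List[str]]:
    """Report feature columns missing from, or unexpected in, the prediction file."""
    expected = [c for c in read_header(train_path) if c not in (target, id_column)]
    actual = [c for c in read_header(pred_path) if c != id_column]
    return {
        "missing": [c for c in expected if c not in actual],
        "unexpected": [c for c in actual if c not in expected],
    }
```

Two empty lists mean the prediction file carries exactly the training features, plus the optional identifier column.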

Notes:
- Path Rules: Use absolute or relative paths (e.g., `./Data/`), and avoid Chinese characters in file paths.
- Directory Setup: For training mode, it is recommended to create `model_output_dir` in advance.
- Dependencies: Install the KAN library with `pip install pykan` (latest version) to ensure compatibility with PyTorch.
- CUDA Support: For large datasets, consider configuring CUDA as described in the official PyTorch documentation. For most cases, CPU computation is sufficient.
- Required Parameters:
  - Training mode requires `--train_file` and `--model_output_dir`
  - Prediction mode requires `--pred_file`, `--output_file`, and `--model_dir`
- Log Files: The training process writes `kan_training.log` and the prediction process writes `predict.log`; both are useful for troubleshooting.
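
Several of these notes can be checked before a run starts. The helpers below are a hypothetical pre-flight sketch, not shipped with KANMB. `non_ascii_paths` flags any path with non-ASCII characters, which is stricter than the note (it names only Chinese characters). `missing_modules` uses `importlib.util.find_spec` to report packages from the manual install line that cannot be found; the import names differ from the pip names for PyYAML (`yaml`), scikit-learn (`sklearn`), and pykan (`kan`). `tail` prints the end of a log file such as `kan_training.log`.

```python
import importlib.util
from collections import deque
from typing import Iterable, List

# Import names for the packages in the manual install line
# (pip name and import name differ for PyYAML, scikit-learn, and pykan).
REQUIRED_MODULES = [
    "tqdm", "matplotlib", "yaml", "numpy", "pandas", "joblib",
    "sklearn", "xgboost", "torch", "torchvision", "kan", "optuna",
]


def non_ascii_paths(paths: Iterable[str]) -> List[str]:
    """Return the paths that contain non-ASCII characters (e.g. Chinese)."""
    return [p for p in paths if not p.isascii()]


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return top-level module names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def tail(path: str, n: int = 20) -> List[str]:
    """Return the last n lines of a log file such as kan_training.log or predict.log."""
    with open(path, encoding="utf-8", errors="replace") as fh:
        return [line.rstrip("\n") for line in deque(fh, maxlen=n)]


if __name__ == "__main__":
    print("non-ASCII paths:", non_ascii_paths(["./Data/TrainandVaild.csv", "./test"]))
    print("missing modules:", missing_modules())
```

`find_spec` only locates a module without fully importing it, so the dependency check stays fast even for large packages such as `torch`.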