# alloyxai

> **A modular machine learning pipeline for data augmentation and explainable modeling in superalloy design**

---

## 🔬 Project Overview

**`alloyxai`** is a research-oriented Python toolkit that integrates *data generation*, *imbalance handling*, and *model interpretability* into a unified machine learning pipeline, specifically designed for **superalloy composition optimization and microstructure-performance prediction**.

The project combines several data augmentation techniques (MCMC, WGAN-GP, SMOGN) with interpretability analysis (SHAP), targeting typical materials-science problems such as **superalloy composition design, phase-coarsening modeling, and high-temperature property prediction**.

---

## 🧩 Core Modules

| Module | Description |
|-------------------|------|
| `MCMCSampler` | Bayesian generator of element proportions (Dirichlet + TruncatedNormal) |
| `WGANGPRegressor` | Small-sample data generator for regression, with a conditional critic and gradient penalty |
| `SMOGNAugmentor` | Regression-oriented oversampling for imbalanced targets (long-tailed, highly skewed distributions) |
| `SHAPAnalyzer` | Multi-level model explanations: main effects, interactions, beeswarm and dependence plots |
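
A minimal sketch of how the modules might compose in Python (the CLI entry point is shown in the workflow below). The import paths, file names, and target column here are illustrative assumptions, not the packaged API:

```python
import pandas as pd

from mcmc_sampler import MCMCSampler    # hypothetical module path
from shap_analyzer import SHAPAnalyzer  # hypothetical module path

# 1) Generate candidate compositions via MCMC (Dirichlet + TruncatedNormal).
sampler = MCMCSampler(
    data_path="data/alloys.csv",            # hypothetical input CSV
    trace_save_path="outputs/trace.csv",
    sample_save_path="outputs/samples.csv",
)
samples_df, trace = sampler.run(plot=False)

# 2) Explain a model trained on the augmented data with SHAP.
train_df = pd.read_csv("data/augmented_train.csv")  # hypothetical file
analyzer = SHAPAnalyzer(target_col="property")      # hypothetical target column
analyzer.fit(train_data=train_df, test_data=train_df)
analyzer.save_feature_importance("outputs/feature_importance.csv")
```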

---

## 🚀 Example Workflow

```bash
# Install dependencies
pip install -r requirements.txt

# Run the main pipeline (MCMC + WGAN + SHAP enabled by default)
python pipeline.py
```

import os
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns


class MCMCSampler:
"""
使用 PyMC 对高温合金元素组成与温度进行 MCMC 采样。

- 元素组成建模为 Dirichlet 分布(强约束:总和为100%)
- 温度建模为 Truncated Normal 分布
"""

def __init__(self,
data_path,
trace_save_path,
sample_save_path,
elements_cols=None,
t_col='T',
draws=4000,
tune=1000,
chains=4,
cores=4,
seed=42,
concentration=100):
"""
初始化采样器

Parameters:
data_path (str): 原始CSV数据路径
trace_save_path (str): 轨迹保存路径
sample_save_path (str): 生成样本保存路径
elements_cols (list): 元素列名(默认10种常见元素)
t_col (str): 温度列名
draws (int): 每条链的采样步数
tune (int): 调优步数
chains (int): 链数
cores (int): 并行核数
seed (int): 随机种子
concentration (float): Dirichlet浓度参数
"""
self.data_path = data_path
self.trace_save_path = trace_save_path
self.sample_save_path = sample_save_path
self.elements_cols = elements_cols or ['Co', 'Al', 'W', 'Ta', 'Ti', 'Nb', 'Ni', 'Cr', 'V', 'Mo']
self.t_col = t_col
self.draws = draws
self.tune = tune
self.chains = chains
self.cores = cores
self.seed = seed
self.concentration = concentration
self.EPSILON = 1e-6

    def load_data(self):
        """Read the data and validate the required columns."""
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")
        self.data = pd.read_csv(self.data_path)

        for col in self.elements_cols + [self.t_col]:
            if col not in self.data.columns:
                raise ValueError(f"Missing column: {col}; please check the data file format.")

        # Replace exact zeros so the strictly positive Dirichlet support is respected.
        self.elements_data = self.data[self.elements_cols].replace(0, 1e-5)
        self.t_data = self.data[self.t_col]

    def _compute_dirichlet_alpha(self):
        """Derive the Dirichlet concentration vector α from the mean element proportions."""
        # α_i = (mean fraction of element i) * concentration, floored at EPSILON so
        # every component of the Dirichlet parameter stays strictly positive.
        mean_props = self.elements_data.mean(axis=0) / 100.0
        alpha = np.maximum(mean_props * self.concentration, self.EPSILON)
        return alpha

    def build_model(self):
        """Build the PyMC model and draw posterior samples."""
        alpha = self._compute_dirichlet_alpha()
        t_mu, t_sigma = self.t_data.mean(), self.t_data.std()
        t_min, t_max = self.t_data.min(), self.t_data.max()

        with pm.Model() as self.model:
            proportions = pm.Dirichlet("proportions", a=alpha, shape=(len(self.elements_cols),))
            # Scale fractions to percentages so generated samples match the raw data.
            elements_generated = pm.Deterministic("elements_generated", proportions * 100)
            # Slightly widened bounds keep the prior from clipping at the observed extremes.
            t_prior = pm.TruncatedNormal("T_prior", mu=t_mu, sigma=t_sigma,
                                         lower=t_min - 10, upper=t_max + 10)

self.trace = pm.sample(
draws=self.draws,
tune=self.tune,
chains=self.chains,
cores=self.cores,
target_accept=0.95,
random_seed=self.seed,
return_inferencedata=True
)

    def check_convergence(self):
        """Run ArviZ convergence diagnostics."""
        summary = az.summary(self.trace, var_names=["proportions", "T_prior"])
        # r_hat near 1.0 indicates converged chains; > 1.05 is a common warning threshold.
        if summary["r_hat"].max() > 1.05:
            print("⚠️ Warning: some parameters did not converge; increase draws or revise the model.")
        return summary

    def save_trace(self):
        """Save the MCMC trace to CSV."""
        # Stack (chain, draw) into one sample axis, yielding shape (n_samples, n_elements).
        proportions_trace = self.trace.posterior['proportions'].stack(sample=("chain", "draw")).values.transpose(1, 0)
t_trace = self.trace.posterior['T_prior'].stack(sample=("chain", "draw")).values.flatten()
trace_df = pd.DataFrame(proportions_trace, columns=[f"proportions_{el}" for el in self.elements_cols])
trace_df["T_prior"] = t_trace

os.makedirs(os.path.dirname(self.trace_save_path), exist_ok=True)
trace_df.to_csv(self.trace_save_path, index=False)

    def extract_samples(self):
        """Extract the generated posterior samples into a DataFrame."""
posterior = self.trace.posterior
self.samples_df = pd.DataFrame({
col: posterior['elements_generated'][..., i].values.flatten()
for i, col in enumerate(self.elements_cols)
})
self.samples_df['T'] = posterior['T_prior'].values.flatten()

    def save_samples(self):
        """Save the posterior samples to CSV."""
os.makedirs(os.path.dirname(self.sample_save_path), exist_ok=True)
self.samples_df.to_csv(self.sample_save_path, index=False)

    def plot_distributions(self, save_dir=None):
        """Plot original vs. generated distributions for comparison (optionally saved)."""
        for col in self.elements_cols + ['T']:
            plt.figure(figsize=(8, 4))
            sns.kdeplot(self.data[col], label="Original data", fill=True)
            sns.kdeplot(self.samples_df[col], label="Generated data", fill=True)
            plt.title(f"{col} distribution comparison")
            plt.xlabel("Value")
            plt.ylabel("Density")
plt.legend()
plt.tight_layout()
if save_dir:
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, f"{col}_kde.png"))
plt.show()

    def run(self, plot=True, save_plot_dir=None):
        """Execute the full MCMC workflow."""
        print("🔄 Starting MCMC workflow...")
self.load_data()
self.build_model()
self.check_convergence()
self.save_trace()
self.extract_samples()
self.save_samples()
if plot:
self.plot_distributions(save_dir=save_plot_dir)
print("✅ MCMC流程完成!")
return self.samples_df, self.trace
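
# A minimal usage sketch. The CSV path, output locations, and reduced sampler
# settings below are illustrative assumptions, not files or defaults shipped
# with the project.
if __name__ == "__main__":
    sampler = MCMCSampler(
        data_path="data/alloys.csv",                # hypothetical input CSV
        trace_save_path="outputs/mcmc_trace.csv",
        sample_save_path="outputs/mcmc_samples.csv",
        draws=1000,   # smaller than the default 4000 for a quick smoke test
        chains=2,
        cores=2,
    )
    samples_df, trace = sampler.run(plot=False)
    print(samples_df.describe())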

import os
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from xgboost import XGBRegressor
from sklearn.model_selection import cross_val_predict, KFold
from sklearn.metrics import r2_score


class SHAPAnalyzer:
"""
使用XGBoost + SHAP进行特征重要性分析和交互作用分析。
"""

def __init__(self, target_col, feature_name_mapping=None, random_state=42):
self.target_col = target_col
self.feature_name_mapping = feature_name_mapping or {}
self.random_state = random_state

def fit(self, train_data, test_data, model_params=None):
self.train_data = train_data
self.test_data = test_data

self.X_train = self.train_data.drop(columns=[self.target_col])
self.y_train = self.train_data[self.target_col]
self.X_test = self.test_data.drop(columns=[self.target_col], errors='ignore')

self.features = self.X_train.columns.tolist()
self.feature_display_names = [self.feature_name_mapping.get(col, col) for col in self.features]

self.model_params = model_params or {
'colsample_bytree': 1.0,
'gamma': 2.0,
'learning_rate': 0.1,
'max_depth': 10,
'n_estimators': 50,
'subsample': 0.7,
'eval_metric': 'rmse',
'n_jobs': -1,
'random_state': self.random_state
}

xgb_model = XGBRegressor(**self.model_params)
kf = KFold(n_splits=10, shuffle=True, random_state=self.random_state)
y_pred = cross_val_predict(xgb_model, self.X_train, self.y_train, cv=kf)

self.r2_score_cv = r2_score(self.y_train, y_pred)
        self.y_cv_pred = y_pred  # keep the cross-validated predictions
print(f"Cross-validated R²: {self.r2_score_cv:.4f}")

self.final_model = xgb_model.fit(self.X_train, self.y_train)

self.explainer = shap.TreeExplainer(self.final_model, feature_perturbation='tree_path_dependent')
self.shap_values = self.explainer(self.X_test).values
self.shap_interaction_values = self.explainer.shap_interaction_values(self.X_test)
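        # Note: 'tree_path_dependent' explains the model using the trees' own cover
        # statistics rather than a background dataset; SHAP interaction values are
        # available for tree-based models such as XGBoost.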

def save_feature_importance(self, path):
xgb_importance = self.final_model.feature_importances_
shap_importance = np.abs(self.shap_values).mean(axis=0)

importance_df = pd.DataFrame({
'Feature': self.features,
'DisplayName': self.feature_display_names,
'XGBoost_Importance': xgb_importance,
'SHAP_Importance': shap_importance
}).sort_values('SHAP_Importance', ascending=False)

os.makedirs(os.path.dirname(path), exist_ok=True)
importance_df.to_csv(path, index=False, float_format="%.6f")
print(f"✅ 特征重要性保存到: {path}")

def save_shap_values(self, path):
shap_df = pd.DataFrame(self.shap_values, columns=self.features)
os.makedirs(os.path.dirname(path), exist_ok=True)
shap_df.to_csv(path, index=False, float_format="%.6f")
print(f"✅ SHAP值保存到: {path}")

def save_shap_summary_plot(self, path):
plt.figure(figsize=(10, 8))
shap.summary_plot(self.shap_values, self.X_test, feature_names=self.feature_display_names, show=False)
plt.title("SHAP Summary Plot")
plt.tight_layout()
plt.savefig(path, dpi=300, bbox_inches='tight')
plt.close()
print(f"✅ SHAP蜂群图保存到: {path}")

def save_interaction_heatmap(self, path):
plt.figure(figsize=(10, 8))
shap.summary_plot(self.shap_interaction_values, self.X_test, plot_type="compact_dot", show=False)
plt.title("SHAP Interaction Heatmap")
plt.tight_layout()
plt.savefig(path, dpi=300, bbox_inches='tight')
plt.close()
print(f"✅ 交互热力图保存到: {path}")

def save_interaction_strengths(self, path):
strength = np.mean(np.abs(self.shap_interaction_values), axis=0)

interaction_records = []
for i in range(len(self.features)):
for j in range(i+1, len(self.features)):
interaction_records.append({
'Feature_A': self.features[i],
'Feature_B': self.features[j],
'Interaction_Strength': strength[i, j]
})

interaction_df = pd.DataFrame(interaction_records).sort_values('Interaction_Strength', ascending=False)
os.makedirs(os.path.dirname(path), exist_ok=True)
interaction_df.to_csv(path, index=False, float_format="%.6f")
print(f"✅ 全局交互强度保存到: {path}")

    def plot_dependence(self, feature, interaction_feature=None, path=None):
        shap.dependence_plot(
            feature,
            self.shap_values,
            self.X_test,
            interaction_index=interaction_feature,
            show=False
        )
        # Only mention the interaction feature in the title when one was requested.
        if interaction_feature:
            plt.title(f"{feature} interaction with {interaction_feature}")
        else:
            plt.title(f"Dependence plot: {feature}")
        if path:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            plt.savefig(path, dpi=300, bbox_inches='tight')
            plt.close()
            print(f"✅ Dependence plot saved to: {path}")
        else:
            plt.show()
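
# A minimal usage sketch. The data file, target column, feature names, and
# output paths are hypothetical placeholders, not part of the repository.
if __name__ == "__main__":
    df = pd.read_csv("data/augmented_train.csv")           # hypothetical dataset
    analyzer = SHAPAnalyzer(target_col="coarsening_rate")  # hypothetical target
    # For simplicity the training set doubles as the explanation set.
    analyzer.fit(train_data=df, test_data=df)
    analyzer.save_feature_importance("outputs/feature_importance.csv")
    analyzer.save_shap_summary_plot("outputs/shap_summary.png")
    analyzer.plot_dependence("Al", interaction_feature="Cr",
                             path="outputs/dependence_Al_Cr.png")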