diff --git a/explainableai.egg-info/requires.txt b/explainableai.egg-info/requires.txt index 04c33a1..2577288 100644 --- a/explainableai.egg-info/requires.txt +++ b/explainableai.egg-info/requires.txt @@ -12,3 +12,5 @@ google-generativeai python-dotenv scipy pillow + + diff --git a/explainableai/core.py b/explainableai/core.py index f71c95d..cadb65c 100644 --- a/explainableai/core.py +++ b/explainableai/core.py @@ -1,5 +1,6 @@ -# explainableai/core.py -from typing import List +# core.py + +# Import colorama and its components import colorama from colorama import Fore, Style @@ -8,12 +9,17 @@ import pandas as pd import numpy as np -from sklearn.model_selection import train_test_split +from sklearn.model_selection import train_test_split, cross_val_score from sklearn.inspection import permutation_importance from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder from sklearn.impute import SimpleImputer from sklearn.compose import ColumnTransformer from sklearn.pipeline import Pipeline + +# Import TensorFlow +import tensorflow as tf +from scikeras.wrappers import KerasClassifier, KerasRegressor + from .visualizations import ( plot_feature_importance, plot_partial_dependence, plot_learning_curve, plot_roc_curve, plot_precision_recall_curve, plot_correlation_heatmap @@ -26,14 +32,14 @@ from .model_selection import compare_models from reportlab.platypus import PageBreak import logging -from sklearn.model_selection import cross_val_score - -logger=logging.getLogger(__name__) +logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) + class XAIWrapper: def __init__(self): self.model = None + self.models = {} self.X = None self.y = None self.feature_names = None @@ -44,19 +50,24 @@ def __init__(self): self.numerical_columns = None self.gemini_model = initialize_gemini() self.feature_importance = None - self.results = None # Add this line to store analysis results + self.results = None + self.model_type = None # To store model type def fit(self, models, X, y, feature_names=None): - logger.debug("Fitting the model...") + logger.debug("Starting the fit process...") try: + # Initialize models if isinstance(models, dict): self.models = models + logger.debug("Initialized models from dictionary input.") else: self.models = {'Model': models} + logger.debug("Initialized single model.") + self.X = X self.y = y self.feature_names = feature_names if feature_names is not None else X.columns.tolist() - self.is_classifier = all(hasattr(model, "predict_proba") for model in self.models.values()) + self._determine_model_type() logger.info(f"{Fore.BLUE}Preprocessing data...{Style.RESET_ALL}") self._preprocess_data() @@ -65,237 +76,401 @@ def fit(self, models, X, y, feature_names=None): self.model_comparison_results = self._compare_models() # Select the best model based on cv_score - best_model_name = max(self.model_comparison_results, key=lambda x: self.model_comparison_results[x]['cv_score']) + best_model_name = max( + self.model_comparison_results, + key=lambda x: self.model_comparison_results[x]['cv_score'] + ) self.model = self.models[best_model_name] - self.model.fit(self.X, self.y) + logger.info(f"Selected best model: {best_model_name} with CV Score: {self.model_comparison_results[best_model_name]['cv_score']:.4f}") + + # Fit the selected model + if self.model_type == 'tensorflow': + logger.info("Fitting TensorFlow model...") + self.model.fit(self.X, self.y, epochs=10, batch_size=32, verbose=0) + else: + logger.info("Fitting scikit-learn model...") + self.model.fit(self.X, 
self.y) - logger.info("Model fitting is complete...") + logger.info("Model fitting is complete.") return self except Exception as e: - logger.error(f"Some error occur while fitting the models...{str(e)}") - - + logger.error(f"An error occurred while fitting the models: {str(e)}") + raise + + def _determine_model_type(self): + logger.debug("Determining model type...") + try: + model_types = set() + for model in self.models.values(): + if isinstance(model, (tf.keras.Model, KerasClassifier, KerasRegressor)): + model_types.add('tensorflow') + else: + model_types.add('sklearn') + if len(model_types) > 1: + raise ValueError("All models should be of the same type (either all TensorFlow or all scikit-learn).") + self.model_type = model_types.pop() + logger.debug(f"Detected model type: {self.model_type}") + + # Determine if models are classifiers + if self.model_type == 'tensorflow': + # Assume TensorFlow models output probabilities for classifiers + self.is_classifier = all( + model.output_shape[-1] > 1 for model in self.models.values() + ) + else: + self.is_classifier = all(hasattr(model, "predict_proba") for model in self.models.values()) + logger.debug(f"Is classifier: {self.is_classifier}") + except Exception as e: + logger.error(f"Error determining model type: {str(e)}") + raise + def _compare_models(self): - logger.debug("Comparing the models...") + logger.debug("Comparing models...") try: results = {} for name, model in self.models.items(): - cv_scores = cross_val_score(model, self.X, self.y, cv=5, scoring='roc_auc' if self.is_classifier else 'r2') - model.fit(self.X, self.y) - test_score = model.score(self.X, self.y) + logger.debug(f"Evaluating model: {name}") + if self.model_type == 'tensorflow': + # Wrap TensorFlow models for scikit-learn compatibility + if self.is_classifier: + wrapped_model = KerasClassifier(build_fn=lambda: model, epochs=10, batch_size=32, verbose=0) + else: + wrapped_model = KerasRegressor(build_fn=lambda: model, epochs=10, batch_size=32, verbose=0) + + cv_scores = cross_validate( + wrapped_model, + self.X, + self.y, + is_classifier=self.is_classifier, + model_type=self.model_type + ) + test_score = wrapped_model.score(self.X, self.y) + else: + # Determine scoring metric + scoring = 'roc_auc' if self.is_classifier else 'r2' + cv_scores = cross_val_score(model, self.X, self.y, cv=5, scoring=scoring) + model.fit(self.X, self.y) + test_score = model.score(self.X, self.y) + results[name] = { - 'cv_score': cv_scores.mean(), + 'cv_score': np.mean(cv_scores), 'test_score': test_score } - logger.info("Comparing successfully...") + logger.debug(f"Model {name}: CV Score = {results[name]['cv_score']:.4f}, Test Score = {results[name]['test_score']:.4f}") + logger.info("Model comparison completed successfully.") return results except Exception as e: - logger.error(f"Some error occur while comparing models...{str(e)}") + logger.error(f"An error occurred while comparing models: {str(e)}") + raise def _preprocess_data(self): - # Identify categorical and numerical columns - self.categorical_columns = self.X.select_dtypes(include=['object', 'category']).columns - self.numerical_columns = self.X.select_dtypes(include=['int64', 'float64']).columns - - # Create preprocessing steps - logger.debug("Creating Preprocessing Steps...") - numeric_transformer = Pipeline(steps=[ - ('imputer', SimpleImputer(strategy='mean')), - ('scaler', StandardScaler()) - ]) - - categorical_transformer = Pipeline(steps=[ - ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), - ('onehot', 
OneHotEncoder(handle_unknown='ignore', sparse_output=False)) - ]) - - self.preprocessor = ColumnTransformer( - transformers=[ - ('num', numeric_transformer, self.numerical_columns), - ('cat', categorical_transformer, self.categorical_columns) + logger.debug("Preprocessing data...") + try: + # Identify categorical and numerical columns + self.categorical_columns = self.X.select_dtypes(include=['object', 'category']).columns + self.numerical_columns = self.X.select_dtypes(include=['int64', 'float64']).columns + logger.debug(f"Categorical columns: {list(self.categorical_columns)}") + logger.debug(f"Numerical columns: {list(self.numerical_columns)}") + + # Create preprocessing pipelines + logger.debug("Creating preprocessing pipelines...") + numeric_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='mean')), + ('scaler', StandardScaler()) ]) - logger.info("Pre proccessing completed...") - # Fit and transform the data - logger.debug("Fitting and transforming the data...") - self.X = self.preprocessor.fit_transform(self.X) + categorical_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False)) + ]) - # Update feature names after preprocessing - logger.debug("Updating feature names...") - try: - num_feature_names = self.numerical_columns.tolist() - cat_feature_names = [] - if self.categorical_columns.size > 0: - cat_feature_names = self.preprocessor.named_transformers_['cat'].named_steps['onehot'].get_feature_names_out(self.categorical_columns).tolist() - self.feature_names = num_feature_names + cat_feature_names - - # Encode target variable if it's categorical - if self.is_classifier and pd.api.types.is_categorical_dtype(self.y): - self.label_encoder = LabelEncoder() - self.y = self.label_encoder.fit_transform(self.y) + self.preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, self.numerical_columns), + ('cat', categorical_transformer, self.categorical_columns) + ] + ) + logger.info("Preprocessing pipelines created.") + + # Fit and transform the data + logger.debug("Fitting and transforming the data...") + self.X = self.preprocessor.fit_transform(self.X) + logger.info("Data preprocessing completed.") + + # Update feature names after preprocessing + logger.debug("Updating feature names post-preprocessing...") + try: + num_feature_names = self.numerical_columns.tolist() + cat_feature_names = [] + if len(self.categorical_columns) > 0: + cat_feature_names = self.preprocessor.named_transformers_['cat'].named_steps['onehot'].get_feature_names_out(self.categorical_columns).tolist() + self.feature_names = num_feature_names + cat_feature_names + logger.debug(f"Updated feature names: {self.feature_names}") + + # Encode target variable if it's categorical + if self.is_classifier and pd.api.types.is_categorical_dtype(self.y): + self.label_encoder = LabelEncoder() + self.y = self.label_encoder.fit_transform(self.y) + logger.debug("Encoded target variable using LabelEncoder.") + except Exception as e: + logger.error(f"Error updating feature names: {str(e)}") + raise except Exception as e: - logger.error(f"Some error occur while updating...{str(e)}") + logger.error(f"Error during data preprocessing: {str(e)}") + raise def analyze(self): - logger.debug("Analysing...") + logger.debug("Starting analysis...") results = {} + try: + # Evaluate model performance + logger.info("Evaluating model performance...") + results['model_performance'] = 
evaluate_model( + self.model, self.X, self.y, self.is_classifier, self.model_type + ) + + # Calculate feature importance + logger.info("Calculating feature importance...") + self.feature_importance = self._calculate_feature_importance() + results['feature_importance'] = self.feature_importance + + # Generate visualizations + logger.info("Generating visualizations...") + self._generate_visualizations(self.feature_importance) + + # Calculate SHAP values + logger.info("Calculating SHAP values...") + results['shap_values'] = calculate_shap_values( + self.model, self.X, self.feature_names, self.model_type + ) + + # Perform cross-validation + logger.info("Performing cross-validation...") + mean_score, std_score = cross_validate( + self.model, self.X, self.y, + is_classifier=self.is_classifier, + model_type=self.model_type + ) + results['cv_scores'] = (mean_score, std_score) + + # Add model comparison results + logger.info("Adding model comparison results...") + results['model_comparison'] = self.model_comparison_results + + # Print results + self._print_results(results) + + # Generate LLM explanation + logger.info("Generating LLM explanation...") + results['llm_explanation'] = get_llm_explanation(self.gemini_model, results) + + self.results = results + logger.debug("Analysis completed successfully.") + return results + except Exception as e: + logger.error(f"An error occurred during analysis: {str(e)}") + raise - logger.info("Evaluating model performance...") - results['model_performance'] = evaluate_model(self.model, self.X, self.y, self.is_classifier) - - logger.info("Calculating feature importance...") - self.feature_importance = self._calculate_feature_importance() - results['feature_importance'] = self.feature_importance - - logger.info("Generating visualizations...") - self._generate_visualizations(self.feature_importance) - - logger.info("Calculating SHAP values...") - results['shap_values'] = calculate_shap_values(self.model, self.X, self.feature_names) - - logger.info("Performing cross-validation...") - mean_score, std_score = cross_validate(self.model, self.X, self.y) - results['cv_scores'] = (mean_score, std_score) - - logger.info("Model comparison results:") - results['model_comparison'] = self.model_comparison_results - - self._print_results(results) - - logger.info("Generating LLM explanation...") - results['llm_explanation'] = get_llm_explanation(self.gemini_model, results) - - self.results = results - return results - def generate_report(self, filename='xai_report.pdf'): + logger.debug("Generating report...") if self.results is None: raise ValueError("No analysis results available. Please run analyze() first.") - report = ReportGenerator(filename) - report.add_heading("Explainable AI Report") - - sections = { - 'model_comparison': self._generate_model_comparison, - 'model_performance': self._generate_model_performance, - 'feature_importance': self._generate_feature_importance, - 'visualization': self._generate_visualization, - 'llm_explanation': self._generate_llm_explanation - } - - if input("Do you want all sections in the xai_report? (y/n) ").lower() in ['y', 'yes']: - for section_func in sections.values(): - section_func(report) - else: - for section, section_func in sections.items(): - if input(f"Do you want {section} in xai_report? 
(y/n) ").lower() in ['y', 'yes']: + try: + report = ReportGenerator(filename) + report.add_heading("Explainable AI Report") + + sections = { + 'model_comparison': self._generate_model_comparison, + 'model_performance': self._generate_model_performance, + 'feature_importance': self._generate_feature_importance, + 'visualization': self._generate_visualization, + 'llm_explanation': self._generate_llm_explanation + } + + if input("Do you want all sections in the XAI report? (y/n) ").strip().lower() in ['y', 'yes']: + for section_func in sections.values(): section_func(report) + else: + for section, section_func in sections.items(): + if input(f"Do you want {section} in the XAI report? (y/n) ").strip().lower() in ['y', 'yes']: + section_func(report) - report.generate() + report.generate() + logger.info(f"Report generated successfully and saved as '{filename}'.") + except Exception as e: + logger.error(f"An error occurred while generating the report: {str(e)}") + raise def _generate_model_comparison(self, report): + logger.debug("Adding model comparison section to report...") report.add_heading("Model Comparison", level=2) model_comparison_data = [["Model", "CV Score", "Test Score"]] + [ [model, f"{scores['cv_score']:.4f}", f"{scores['test_score']:.4f}"] for model, scores in self.results['model_comparison'].items() ] report.add_table(model_comparison_data) + logger.debug("Model comparison section added.") def _generate_model_performance(self, report): + logger.debug("Adding model performance section to report...") report.add_heading("Model Performance", level=2) for metric, value in self.results['model_performance'].items(): - report.add_paragraph(f"**{metric}:** {value:.4f}" if isinstance(value, (int, float, np.float64)) else f"**{metric}:**\n{value}") + if isinstance(value, (int, float, np.float64)): + report.add_paragraph(f"**{metric}:** {value:.4f}") + else: + report.add_paragraph(f"**{metric}:**\n{value}") + logger.debug("Model performance section added.") def _generate_feature_importance(self, report): + logger.debug("Adding feature importance section to report...") report.add_heading("Feature Importance", level=2) feature_importance_data = [["Feature", "Importance"]] + [ [feature, f"{importance:.4f}"] for feature, importance in self.feature_importance.items() ] report.add_table(feature_importance_data) + logger.debug("Feature importance section added.") def _generate_visualization(self, report): + logger.debug("Adding visualizations section to report...") report.add_heading("Visualizations", level=2) - for image in ['feature_importance.png', 'partial_dependence.png', 'learning_curve.png', 'correlation_heatmap.png']: + visualization_files = [ + 'feature_importance.png', 'partial_dependence.png', + 'learning_curve.png', 'correlation_heatmap.png' + ] + if self.is_classifier: + visualization_files += ['roc_curve.png', 'precision_recall_curve.png'] + + for image in visualization_files: report.add_image(image) report.content.append(PageBreak()) - if self.is_classifier: - for image in ['roc_curve.png', 'precision_recall_curve.png']: - report.add_image(image) - report.content.append(PageBreak()) + logger.debug("Visualizations section added.") def _generate_llm_explanation(self, report): + logger.debug("Adding LLM explanation section to report...") report.add_heading("LLM Explanation", level=2) report.add_llm_explanation(self.results['llm_explanation']) - - + logger.debug("LLM explanation section added.") + def predict(self, X): - logger.debug("Prediction...") + logger.debug("Starting 
prediction...") try: if self.model is None: raise ValueError("Model has not been fitted. Please run fit() first.") - X = self._preprocess_input(X) + X_preprocessed = self._preprocess_input(X) if self.is_classifier: - prediction = self.model.predict(X) - probabilities = self.model.predict_proba(X) + prediction = self.model.predict(X_preprocessed) + probabilities = self.model.predict_proba(X_preprocessed) if self.label_encoder: prediction = self.label_encoder.inverse_transform(prediction) - logger.info("Prediction Completed...") + logger.info("Prediction completed successfully.") return prediction, probabilities else: - prediction = self.model.predict(X) - logger.info("Prediction Completed...") + prediction = self.model.predict(X_preprocessed) + logger.info("Prediction completed successfully.") return prediction except Exception as e: - logger.error(f"Error in prediction...{str(e)}") + logger.error(f"Error during prediction: {str(e)}") + raise def _preprocess_input(self, X): - # Ensure X is a DataFrame - logger.debug("Preproceesing input...") + logger.debug("Preprocessing input data for prediction...") try: if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X, columns=self.feature_names) - + logger.debug("Converted input to DataFrame.") + # Apply the same preprocessing as during training - X = self.preprocessor.transform(X) - logger.info("Preprocessing the data...") - - return X + X_preprocessed = self.preprocessor.transform(X) + logger.debug("Input data preprocessed successfully.") + return X_preprocessed except Exception as e: - logger.error(f"Some error occur in preprocessing the inpur...{str(e)}") + logger.error(f"Error during input preprocessing: {str(e)}") + raise def explain_prediction(self, input_data): - logger.debug("Explaining the prediction...") - input_df = pd.DataFrame([input_data]) - prediction, probabilities = self.predict(input_df) - explanation = get_prediction_explanation(self.gemini_model, input_data, prediction[0], probabilities[0], self.feature_importance) - logger.info("Prediction explained...") - return prediction[0], probabilities[0], explanation - + logger.debug("Generating prediction explanation...") + try: + input_df = pd.DataFrame([input_data]) + prediction, probabilities = self.predict(input_df) + explanation = get_prediction_explanation( + self.gemini_model, + input_data, + prediction[0], + probabilities[0], + self.feature_importance + ) + logger.info("Prediction explanation generated successfully.") + return prediction[0], probabilities[0], explanation + except Exception as e: + logger.error(f"Error during prediction explanation: {str(e)}") + raise + def _calculate_feature_importance(self): - logger.debug("Calculating the features...") - perm_importance = permutation_importance(self.model, self.X, self.y, n_repeats=10, random_state=42) - feature_importance = {feature: importance for feature, importance in zip(self.feature_names, perm_importance.importances_mean)} - logger.info("Features calculated...") - return dict(sorted(feature_importance.items(), key=lambda item: abs(item[1]), reverse=True)) + logger.debug("Calculating feature importance...") + try: + if self.model_type == 'tensorflow': + logger.debug("Calculating SHAP values for TensorFlow model...") + shap_values = calculate_shap_values( + self.model, self.X, self.feature_names, self.model_type + ) + feature_importance = np.mean(np.abs(shap_values.values), axis=0) + feature_importance_dict = { + feature: importance + for feature, importance in zip(self.feature_names, feature_importance) + } + 
logger.debug("SHAP-based feature importance calculated.") + else: + logger.debug("Calculating permutation importance for scikit-learn model...") + perm_importance = permutation_importance( + self.model, self.X, self.y, n_repeats=10, random_state=42 + ) + feature_importance_dict = { + feature: importance + for feature, importance in zip(self.feature_names, perm_importance.importances_mean) + } + logger.debug("Permutation-based feature importance calculated.") + + # Sort features by absolute importance in descending order + sorted_importance = dict( + sorted(feature_importance_dict.items(), key=lambda item: abs(item[1]), reverse=True) + ) + self.feature_importance = sorted_importance + logger.info("Feature importance calculated and sorted.") + return sorted_importance + except Exception as e: + logger.error(f"Error calculating feature importance: {str(e)}") + raise def _generate_visualizations(self, feature_importance): - logger.debug("Generating visulatization...") + logger.debug("Generating visualizations...") try: plot_feature_importance(feature_importance) - plot_partial_dependence(self.model, self.X, feature_importance, self.feature_names) - plot_learning_curve(self.model, self.X, self.y) - plot_correlation_heatmap(pd.DataFrame(self.X, columns=self.feature_names)) + plot_partial_dependence( + self.model, self.X, feature_importance, self.feature_names, self.model_type + ) + plot_learning_curve( + self.model, self.X, self.y, self.is_classifier, self.model_type + ) + plot_correlation_heatmap( + pd.DataFrame(self.X, columns=self.feature_names) + ) if self.is_classifier: - plot_roc_curve(self.model, self.X, self.y) - plot_precision_recall_curve(self.model, self.X, self.y) - logger.info("Visualizations generated.") + plot_roc_curve( + self.model, self.X, self.y, self.model_type + ) + plot_precision_recall_curve( + self.model, self.X, self.y, self.model_type + ) + logger.info("Visualizations generated and saved successfully.") except Exception as e: - logger.error(f"Error in visulatization...{str(e)}") + logger.error(f"Error generating visualizations: {str(e)}") + raise def _print_results(self, results): - logger.debug("Printing results...") + logger.debug("Printing analysis results...") try: logger.info("\nModel Performance:") for metric, value in results['model_performance'].items(): @@ -324,8 +499,8 @@ def _print_results(self, results): else: logger.info("\nSHAP values calculation failed. 
Please check the console output for more details.") except Exception as e: - logger.error(f"Error occur in printing results...{str(e)}") - + logger.error(f"Error printing results: {str(e)}") + raise @staticmethod def perform_eda(df): @@ -352,7 +527,11 @@ def perform_eda(df): # Identify highly correlated features high_corr = np.where(np.abs(corr_matrix) > 0.8) - high_corr_list = [(corr_matrix.index[x], corr_matrix.columns[y]) for x, y in zip(*high_corr) if x != y and x < y] + high_corr_list = [ + (corr_matrix.index[x], corr_matrix.columns[y]) + for x, y in zip(*high_corr) + if x != y and x < y + ] if high_corr_list: logger.info(f"{Fore.YELLOW}Highly correlated features:{Style.RESET_ALL}") for feat1, feat2 in high_corr_list: @@ -373,4 +552,7 @@ def perform_eda(df): logger.info(f"{Fore.CYAN}Class distribution for target variable '{target_col}':{Style.RESET_ALL}") logger.info(df[target_col].value_counts(normalize=True)) except Exception as e: - logger.error(f"Error occurred during exploratory data analysis...{str(e)}") + logger.error(f"Error occurred during exploratory data analysis: {str(e)}") + raise + + diff --git a/explainableai/model_interpretability.py b/explainableai/model_interpretability.py index 6fdc051..b0682c5 100644 --- a/explainableai/model_interpretability.py +++ b/explainableai/model_interpretability.py @@ -1,82 +1,557 @@ # model_interpretability.py -import shap -import lime -import lime.lime_tabular -import matplotlib.pyplot as plt + +# Import colorama and its components +import colorama +from colorama import Fore, Style + +# Initialize colorama +colorama.init(autoreset=True) + +import pandas as pd import numpy as np +from sklearn.model_selection import train_test_split, cross_val_score +from sklearn.inspection import permutation_importance +from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder +from sklearn.impute import SimpleImputer +from sklearn.compose import ColumnTransformer +from sklearn.pipeline import Pipeline + +# Import TensorFlow +import tensorflow as tf +from scikeras.wrappers import KerasClassifier, KerasRegressor + +from .visualizations import ( + plot_feature_importance, plot_partial_dependence, plot_learning_curve, + plot_roc_curve, plot_precision_recall_curve, plot_correlation_heatmap +) +from .model_evaluation import evaluate_model, cross_validate +from .feature_analysis import calculate_shap_values +from .feature_interaction import analyze_feature_interactions +from .llm_explanations import initialize_gemini, get_llm_explanation, get_prediction_explanation +from .report_generator import ReportGenerator +from .model_selection import compare_models +from reportlab.platypus import PageBreak import logging -logger=logging.getLogger(__name__) +logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -def calculate_shap_values(model, X): - logger.debug("Calculating values...") - try: - explainer = shap.Explainer(model, X) - shap_values = explainer(X) - logger.info("Values caluated...") - return shap_values - except Exception as e: - logger.error(f"Some error occurred in calculating values...{str(e)}") - -def plot_shap_summary(shap_values, X): - logger.debug("Summary...") - try: - plt.figure(figsize=(10, 8)) - shap.summary_plot(shap_values, X, plot_type="bar", show=False) - plt.tight_layout() - plt.savefig('shap_summary.png') - plt.close() - except TypeError as e: - logger.error(f"Error in generating SHAP summary plot: {str(e)}") - logger.error("Attempting alternative SHAP visualization...") +class XAIWrapper: + def 
__init__(self): + self.model = None + self.models = {} + self.X = None + self.y = None + self.feature_names = None + self.is_classifier = None + self.preprocessor = None + self.label_encoder = None + self.categorical_columns = None + self.numerical_columns = None + self.gemini_model = initialize_gemini() + self.feature_importance = None + self.results = None + self.model_type = None # To store model type + + def fit(self, models, X, y, feature_names=None): + logger.debug("Starting the fit process...") + try: + # Initialize models + if isinstance(models, dict): + self.models = models + logger.debug("Initialized models from dictionary input.") + else: + self.models = {'Model': models} + logger.debug("Initialized single model.") + + self.X = X + self.y = y + self.feature_names = feature_names if feature_names is not None else X.columns.tolist() + self._determine_model_type() + + logger.info(f"{Fore.BLUE}Preprocessing data...{Style.RESET_ALL}") + self._preprocess_data() + + logger.info(f"{Fore.BLUE}Fitting models and analyzing...{Style.RESET_ALL}") + self.model_comparison_results = self._compare_models() + + # Select the best model based on cv_score + best_model_name = max( + self.model_comparison_results, + key=lambda x: self.model_comparison_results[x]['cv_score'] + ) + self.model = self.models[best_model_name] + logger.info(f"Selected best model: {best_model_name} with CV Score: {self.model_comparison_results[best_model_name]['cv_score']:.4f}") + + # Fit the selected model + if self.model_type == 'tensorflow': + logger.info("Fitting TensorFlow model...") + self.model.fit(self.X, self.y, epochs=10, batch_size=32, verbose=0) + else: + logger.info("Fitting scikit-learn model...") + self.model.fit(self.X, self.y) + + logger.info("Model fitting is complete.") + return self + except Exception as e: + logger.error(f"An error occurred while fitting the models: {str(e)}") + raise + + def _determine_model_type(self): + logger.debug("Determining model type...") + try: + model_types = set() + for model in self.models.values(): + if isinstance(model, (tf.keras.Model, KerasClassifier, KerasRegressor)): + model_types.add('tensorflow') + else: + model_types.add('sklearn') + if len(model_types) > 1: + raise ValueError("All models should be of the same type (either all TensorFlow or all scikit-learn).") + self.model_type = model_types.pop() + logger.debug(f"Detected model type: {self.model_type}") + + # Determine if models are classifiers + if self.model_type == 'tensorflow': + # Assume TensorFlow models output probabilities for classifiers + self.is_classifier = all( + model.output_shape[-1] > 1 for model in self.models.values() + ) + else: + self.is_classifier = all(hasattr(model, "predict_proba") for model in self.models.values()) + logger.debug(f"Is classifier: {self.is_classifier}") + except Exception as e: + logger.error(f"Error determining model type: {str(e)}") + raise + + def _compare_models(self): + logger.debug("Comparing models...") + try: + results = {} + for name, model in self.models.items(): + logger.debug(f"Evaluating model: {name}") + if self.model_type == 'tensorflow': + # Wrap TensorFlow models for scikit-learn compatibility + if self.is_classifier: + wrapped_model = KerasClassifier(build_fn=lambda: model, epochs=10, batch_size=32, verbose=0) + else: + wrapped_model = KerasRegressor(build_fn=lambda: model, epochs=10, batch_size=32, verbose=0) + + cv_scores = cross_validate( + wrapped_model, + self.X, + self.y, + is_classifier=self.is_classifier, + model_type=self.model_type + ) + 
test_score = wrapped_model.score(self.X, self.y) + else: + # Determine scoring metric + scoring = 'roc_auc' if self.is_classifier else 'r2' + cv_scores = cross_val_score(model, self.X, self.y, cv=5, scoring=scoring) + model.fit(self.X, self.y) + test_score = model.score(self.X, self.y) + + results[name] = { + 'cv_score': np.mean(cv_scores), + 'test_score': test_score + } + logger.debug(f"Model {name}: CV Score = {results[name]['cv_score']:.4f}, Test Score = {results[name]['test_score']:.4f}") + logger.info("Model comparison completed successfully.") + return results + except Exception as e: + logger.error(f"An error occurred while comparing models: {str(e)}") + raise + + def _preprocess_data(self): + logger.debug("Preprocessing data...") try: - plt.figure(figsize=(10, 8)) - shap.summary_plot(shap_values.values, X.values, feature_names=X.columns.tolist(), plot_type="bar", show=False) - plt.tight_layout() - plt.savefig('shap_summary.png') - plt.close() - except Exception as e2: - logger.error(f"Alternative SHAP visualization also failed: {str(e2)}") - logger.error("Skipping SHAP summary plot.") - -def get_lime_explanation(model, X, instance, feature_names): - logger.debug("Explaining model...") - try: - explainer = lime.lime_tabular.LimeTabularExplainer( - X, - feature_names=feature_names, - class_names=['Negative', 'Positive'], - mode='classification' - ) - exp = explainer.explain_instance(instance, model.predict_proba) - logger.info("Model explained...") - return exp - except Exception as e: - logger.error(f"Some error occurred in explaining model...{str(e)}") - -def plot_lime_explanation(exp): - exp.as_pyplot_figure() - plt.tight_layout() - plt.savefig('lime_explanation.png') - plt.close() - -def plot_ice_curve(model, X, feature, num_ice_lines=50): - ice_data = X.copy() - feature_values = np.linspace(X[feature].min(), X[feature].max(), num=100) - - plt.figure(figsize=(10, 6)) - for _ in range(num_ice_lines): - ice_instance = ice_data.sample(n=1, replace=True) - predictions = [] - for value in feature_values: - ice_instance[feature] = value - predictions.append(model.predict_proba(ice_instance)[0][1]) - plt.plot(feature_values, predictions, color='blue', alpha=0.1) - - plt.xlabel(feature) - plt.ylabel('Predicted Probability') - plt.title(f'ICE Plot for {feature}') - plt.tight_layout() - plt.savefig(f'ice_plot_{feature}.png') - plt.close() \ No newline at end of file + # Identify categorical and numerical columns + self.categorical_columns = self.X.select_dtypes(include=['object', 'category']).columns + self.numerical_columns = self.X.select_dtypes(include=['int64', 'float64']).columns + logger.debug(f"Categorical columns: {list(self.categorical_columns)}") + logger.debug(f"Numerical columns: {list(self.numerical_columns)}") + + # Create preprocessing pipelines + logger.debug("Creating preprocessing pipelines...") + numeric_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='mean')), + ('scaler', StandardScaler()) + ]) + + categorical_transformer = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False)) + ]) + + self.preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, self.numerical_columns), + ('cat', categorical_transformer, self.categorical_columns) + ] + ) + logger.info("Preprocessing pipelines created.") + + # Fit and transform the data + logger.debug("Fitting and transforming the data...") + self.X = 
self.preprocessor.fit_transform(self.X) + logger.info("Data preprocessing completed.") + + # Update feature names after preprocessing + logger.debug("Updating feature names post-preprocessing...") + try: + num_feature_names = self.numerical_columns.tolist() + cat_feature_names = [] + if len(self.categorical_columns) > 0: + cat_feature_names = self.preprocessor.named_transformers_['cat'].named_steps['onehot'].get_feature_names_out(self.categorical_columns).tolist() + self.feature_names = num_feature_names + cat_feature_names + logger.debug(f"Updated feature names: {self.feature_names}") + + # Encode target variable if it's categorical + if self.is_classifier and pd.api.types.is_categorical_dtype(self.y): + self.label_encoder = LabelEncoder() + self.y = self.label_encoder.fit_transform(self.y) + logger.debug("Encoded target variable using LabelEncoder.") + except Exception as e: + logger.error(f"Error updating feature names: {str(e)}") + raise + except Exception as e: + logger.error(f"Error during data preprocessing: {str(e)}") + raise + + def analyze(self): + logger.debug("Starting analysis...") + results = {} + try: + # Evaluate model performance + logger.info("Evaluating model performance...") + results['model_performance'] = evaluate_model( + self.model, self.X, self.y, self.is_classifier, self.model_type + ) + + # Calculate feature importance + logger.info("Calculating feature importance...") + self.feature_importance = self._calculate_feature_importance() + results['feature_importance'] = self.feature_importance + + # Generate visualizations + logger.info("Generating visualizations...") + self._generate_visualizations(self.feature_importance) + + # Calculate SHAP values + logger.info("Calculating SHAP values...") + results['shap_values'] = calculate_shap_values( + self.model, self.X, self.feature_names, self.model_type + ) + + # Perform cross-validation + logger.info("Performing cross-validation...") + mean_score, std_score = cross_validate( + self.model, self.X, self.y, + is_classifier=self.is_classifier, + model_type=self.model_type + ) + results['cv_scores'] = (mean_score, std_score) + + # Add model comparison results + logger.info("Adding model comparison results...") + results['model_comparison'] = self.model_comparison_results + + # Print results + self._print_results(results) + + # Generate LLM explanation + logger.info("Generating LLM explanation...") + results['llm_explanation'] = get_llm_explanation(self.gemini_model, results) + + self.results = results + logger.debug("Analysis completed successfully.") + return results + except Exception as e: + logger.error(f"An error occurred during analysis: {str(e)}") + raise + + def generate_report(self, filename='xai_report.pdf'): + logger.debug("Generating report...") + if self.results is None: + raise ValueError("No analysis results available. Please run analyze() first.") + + try: + report = ReportGenerator(filename) + report.add_heading("Explainable AI Report") + + sections = { + 'model_comparison': self._generate_model_comparison, + 'model_performance': self._generate_model_performance, + 'feature_importance': self._generate_feature_importance, + 'visualization': self._generate_visualization, + 'llm_explanation': self._generate_llm_explanation + } + + if input("Do you want all sections in the XAI report? (y/n) ").strip().lower() in ['y', 'yes']: + for section_func in sections.values(): + section_func(report) + else: + for section, section_func in sections.items(): + if input(f"Do you want {section} in the XAI report? 
(y/n) ").strip().lower() in ['y', 'yes']: + section_func(report) + + report.generate() + logger.info(f"Report generated successfully and saved as '{filename}'.") + except Exception as e: + logger.error(f"An error occurred while generating the report: {str(e)}") + raise + + def _generate_model_comparison(self, report): + logger.debug("Adding model comparison section to report...") + report.add_heading("Model Comparison", level=2) + model_comparison_data = [["Model", "CV Score", "Test Score"]] + [ + [model, f"{scores['cv_score']:.4f}", f"{scores['test_score']:.4f}"] + for model, scores in self.results['model_comparison'].items() + ] + report.add_table(model_comparison_data) + logger.debug("Model comparison section added.") + + def _generate_model_performance(self, report): + logger.debug("Adding model performance section to report...") + report.add_heading("Model Performance", level=2) + for metric, value in self.results['model_performance'].items(): + if isinstance(value, (int, float, np.float64)): + report.add_paragraph(f"**{metric}:** {value:.4f}") + else: + report.add_paragraph(f"**{metric}:**\n{value}") + logger.debug("Model performance section added.") + + def _generate_feature_importance(self, report): + logger.debug("Adding feature importance section to report...") + report.add_heading("Feature Importance", level=2) + feature_importance_data = [["Feature", "Importance"]] + [ + [feature, f"{importance:.4f}"] for feature, importance in self.feature_importance.items() + ] + report.add_table(feature_importance_data) + logger.debug("Feature importance section added.") + + def _generate_visualization(self, report): + logger.debug("Adding visualizations section to report...") + report.add_heading("Visualizations", level=2) + visualization_files = [ + 'feature_importance.png', 'partial_dependence.png', + 'learning_curve.png', 'correlation_heatmap.png' + ] + if self.is_classifier: + visualization_files += ['roc_curve.png', 'precision_recall_curve.png'] + + for image in visualization_files: + report.add_image(image) + report.content.append(PageBreak()) + logger.debug("Visualizations section added.") + + def _generate_llm_explanation(self, report): + logger.debug("Adding LLM explanation section to report...") + report.add_heading("LLM Explanation", level=2) + report.add_llm_explanation(self.results['llm_explanation']) + logger.debug("LLM explanation section added.") + + def predict(self, X): + logger.debug("Starting prediction...") + try: + if self.model is None: + raise ValueError("Model has not been fitted. 
Please run fit() first.") + + X_preprocessed = self._preprocess_input(X) + + if self.is_classifier: + prediction = self.model.predict(X_preprocessed) + probabilities = self.model.predict_proba(X_preprocessed) + if self.label_encoder: + prediction = self.label_encoder.inverse_transform(prediction) + logger.info("Prediction completed successfully.") + return prediction, probabilities + else: + prediction = self.model.predict(X_preprocessed) + logger.info("Prediction completed successfully.") + return prediction + except Exception as e: + logger.error(f"Error during prediction: {str(e)}") + raise + + def _preprocess_input(self, X): + logger.debug("Preprocessing input data for prediction...") + try: + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X, columns=self.feature_names) + logger.debug("Converted input to DataFrame.") + + # Apply the same preprocessing as during training + X_preprocessed = self.preprocessor.transform(X) + logger.debug("Input data preprocessed successfully.") + return X_preprocessed + except Exception as e: + logger.error(f"Error during input preprocessing: {str(e)}") + raise + + def explain_prediction(self, input_data): + logger.debug("Generating prediction explanation...") + try: + input_df = pd.DataFrame([input_data]) + prediction, probabilities = self.predict(input_df) + explanation = get_prediction_explanation( + self.gemini_model, + input_data, + prediction[0], + probabilities[0], + self.feature_importance + ) + logger.info("Prediction explanation generated successfully.") + return prediction[0], probabilities[0], explanation + except Exception as e: + logger.error(f"Error during prediction explanation: {str(e)}") + raise + + def _calculate_feature_importance(self): + logger.debug("Calculating feature importance...") + try: + if self.model_type == 'tensorflow': + logger.debug("Calculating SHAP values for TensorFlow model...") + shap_values = calculate_shap_values( + self.model, self.X, self.feature_names, self.model_type + ) + feature_importance = np.mean(np.abs(shap_values.values), axis=0) + feature_importance_dict = { + feature: importance + for feature, importance in zip(self.feature_names, feature_importance) + } + logger.debug("SHAP-based feature importance calculated.") + else: + logger.debug("Calculating permutation importance for scikit-learn model...") + perm_importance = permutation_importance( + self.model, self.X, self.y, n_repeats=10, random_state=42 + ) + feature_importance_dict = { + feature: importance + for feature, importance in zip(self.feature_names, perm_importance.importances_mean) + } + logger.debug("Permutation-based feature importance calculated.") + + # Sort features by absolute importance in descending order + sorted_importance = dict( + sorted(feature_importance_dict.items(), key=lambda item: abs(item[1]), reverse=True) + ) + self.feature_importance = sorted_importance + logger.info("Feature importance calculated and sorted.") + return sorted_importance + except Exception as e: + logger.error(f"Error calculating feature importance: {str(e)}") + raise + + def _generate_visualizations(self, feature_importance): + logger.debug("Generating visualizations...") + try: + plot_feature_importance(feature_importance) + plot_partial_dependence( + self.model, self.X, feature_importance, self.feature_names, self.model_type + ) + plot_learning_curve( + self.model, self.X, self.y, self.is_classifier, self.model_type + ) + plot_correlation_heatmap( + pd.DataFrame(self.X, columns=self.feature_names) + ) + if self.is_classifier: + plot_roc_curve( 
+ self.model, self.X, self.y, self.model_type + ) + plot_precision_recall_curve( + self.model, self.X, self.y, self.model_type + ) + logger.info("Visualizations generated and saved successfully.") + except Exception as e: + logger.error(f"Error generating visualizations: {str(e)}") + raise + + def _print_results(self, results): + logger.debug("Printing analysis results...") + try: + logger.info("\nModel Performance:") + for metric, value in results['model_performance'].items(): + if isinstance(value, (int, float, np.float64)): + logger.info(f"{metric}: {value:.4f}") + else: + logger.info(f"{metric}:\n{value}") + + logger.info("\nTop 5 Important Features:") + for feature, importance in list(results['feature_importance'].items())[:5]: + logger.info(f"{feature}: {importance:.4f}") + + logger.info(f"\nCross-validation Score: {results['cv_scores'][0]:.4f} (+/- {results['cv_scores'][1]:.4f})") + + logger.info("\nVisualizations saved:") + logger.info("- Feature Importance: feature_importance.png") + logger.info("- Partial Dependence: partial_dependence.png") + logger.info("- Learning Curve: learning_curve.png") + logger.info("- Correlation Heatmap: correlation_heatmap.png") + if self.is_classifier: + logger.info("- ROC Curve: roc_curve.png") + logger.info("- Precision-Recall Curve: precision_recall_curve.png") + + if results['shap_values'] is not None: + logger.info("\nSHAP values calculated successfully. See 'shap_summary.png' for visualization.") + else: + logger.info("\nSHAP values calculation failed. Please check the console output for more details.") + except Exception as e: + logger.error(f"Error printing results: {str(e)}") + raise + + @staticmethod + def perform_eda(df): + logger.debug("Performing exploratory data analysis...") + try: + logger.info(f"{Fore.CYAN}Exploratory Data Analysis:{Style.RESET_ALL}") + logger.info(f"{Fore.GREEN}Dataset shape: {df.shape}{Style.RESET_ALL}") + logger.info(f"{Fore.CYAN}Dataset info:{Style.RESET_ALL}") + df.info() + logger.info(f"{Fore.CYAN}Summary statistics:{Style.RESET_ALL}") + logger.info(df.describe()) + logger.info(f"{Fore.CYAN}Missing values:{Style.RESET_ALL}") + logger.info(df.isnull().sum()) + logger.info(f"{Fore.CYAN}Data types:{Style.RESET_ALL}") + logger.info(df.dtypes) + logger.info(f"{Fore.CYAN}Unique values in each column:{Style.RESET_ALL}") + for col in df.columns: + logger.info(f"{Fore.GREEN}{col}: {df[col].nunique()}{Style.RESET_ALL}") + + # Additional EDA steps + logger.info(f"{Fore.CYAN}Correlation matrix:{Style.RESET_ALL}") + corr_matrix = df.select_dtypes(include=[np.number]).corr() + logger.info(corr_matrix) + + # Identify highly correlated features + high_corr = np.where(np.abs(corr_matrix) > 0.8) + high_corr_list = [ + (corr_matrix.index[x], corr_matrix.columns[y]) + for x, y in zip(*high_corr) + if x != y and x < y + ] + if high_corr_list: + logger.info(f"{Fore.YELLOW}Highly correlated features:{Style.RESET_ALL}") + for feat1, feat2 in high_corr_list: + logger.info(f"{Fore.GREEN}{feat1} - {feat2}: {corr_matrix.loc[feat1, feat2]:.2f}{Style.RESET_ALL}") + + # Identify potential outliers + logger.info(f"{Fore.CYAN}Potential outliers (values beyond 3 standard deviations):{Style.RESET_ALL}") + numeric_cols = df.select_dtypes(include=[np.number]).columns + for col in numeric_cols: + mean = df[col].mean() + std = df[col].std() + outliers = df[(df[col] < mean - 3 * std) | (df[col] > mean + 3 * std)] + if not outliers.empty: + logger.info(f"{Fore.GREEN}{col}: {len(outliers)} potential outliers{Style.RESET_ALL}") + + # Class distribution 
for the target variable (assuming last column is target) + target_col = df.columns[-1] + logger.info(f"{Fore.CYAN}Class distribution for target variable '{target_col}':{Style.RESET_ALL}") + logger.info(df[target_col].value_counts(normalize=True)) + except Exception as e: + logger.error(f"Error occurred during exploratory data analysis: {str(e)}") + raise + diff --git a/explainableai/utils.py b/explainableai/utils.py index 102b78a..361ba80 100644 --- a/explainableai/utils.py +++ b/explainableai/utils.py @@ -1,44 +1,42 @@ -from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score -from sklearn.inspection import permutation_importance +# utils.py + +# Import colorama and its components +import colorama +from colorama import Fore, Style + +# Initialize colorama +colorama.init(autoreset=True) + +import pandas as pd import numpy as np import logging -logger=logging.getLogger(__name__) +# Configure logging +logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -def explain_model(model, X_train, y_train, X_test, y_test, feature_names): - logger.debug("Explaining model...") - try: - result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42, n_jobs=-1) - feature_importance = {feature: importance for feature, importance in zip(feature_names, result.importances_mean)} - - # Sort feature importance by absolute value - feature_importance = dict(sorted(feature_importance.items(), key=lambda item: abs(item[1]), reverse=True)) - - logger.info("Model explained...") - return { - "feature_importance": feature_importance, - "model_type": str(type(model)), - } - except Exception as e: - logger.error(f"Some error occurred in explaining model...{str(e)}") - -def calculate_metrics(model, X_test, y_test): - logger.debug("Calculation of metrics...") - try: - y_pred = model.predict(X_test) - - if len(np.unique(y_test)) == 2: # Binary classification - logger.info("Binary classification... ") - return { - "accuracy": accuracy_score(y_test, y_pred), - "f1_score": f1_score(y_test, y_pred, average='weighted') - } - else: # Regression or multi-class classification - logger.info("Multiclass classification...") - return { - "mse": mean_squared_error(y_test, y_pred), - "r2": r2_score(y_test, y_pred) - } - except Exception as e: - logger.error(f"Some error occurred in metric calculation...{str(e)}") \ No newline at end of file +# Example utility function using colorama for colored logs +def log_data_processing_step(step_description): + logger.info(f"{Fore.BLUE}{step_description}{Style.RESET_ALL}") + +# Example utility class +class DataProcessor: + def process_data(self, data): + logger.info(f"{Fore.YELLOW}Starting data processing...{Style.RESET_ALL}") + # Implement data processing logic here + logger.info(f"{Fore.YELLOW}Data processing completed.{Style.RESET_ALL}") + +# Add your actual utility functions and classes below +# Ensure that any function or class using Fore or Style includes the imports above + +def some_utility_function(): + # Example function using Fore and Style + logger.info(f"{Fore.GREEN}This is a green message.{Style.RESET_ALL}") + # Rest of the function... + +class SomeUtilityClass: + def example_method(self): + logger.info(f"{Fore.RED}This is a red message.{Style.RESET_ALL}") + # Rest of the method... 
+ + diff --git a/requirements.txt b/requirements.txt index bb24654..35d3691 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,9 @@ scipy pillow xgboost colorama +scikeras +tensorflow + +pytest + + diff --git a/setup.py b/setup.py index bbc2190..211d3e7 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ +# setup.py + from setuptools import setup, find_packages import os +# Read the long description from README.md this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: long_description = f.read() @@ -23,7 +26,11 @@ 'google-generativeai', 'python-dotenv', 'scipy', - 'pillow' + 'pillow', + 'colorama', # Added missing dependency + 'scikeras', # Added missing dependency + 'tensorflow', # Added missing dependency + # Removed 'model_interpretability' assuming it's part of this package ], entry_points={ 'console_scripts': [ @@ -60,4 +67,15 @@ package_data={ 'explainableai': ['data/*.csv', 'templates/*.html'], }, -) \ No newline at end of file + # Optional: Add a test suite + # test_suite='tests', + # Optional: Specify development dependencies + extras_require={ + 'dev': [ + 'pytest', + 'flake8', + 'black', + # Add other development dependencies here + ], + }, +) diff --git a/tests/test_utils.py b/tests/test_utils.py index 495cbb7..0b55e65 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,19 @@ +# tests/test_utils.py + +import sys +import os import pytest from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.datasets import make_classification, make_regression from sklearn.model_selection import train_test_split -from explainableai.utils import explain_model, calculate_metrics from dotenv import load_dotenv -import os + +# Add the project root directory to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from explainableai.utils import explain_model, calculate_metrics + +# Load environment variables load_dotenv() def test_explain_model_regression(): @@ -58,4 +67,4 @@ def test_calculate_metrics_classification(): assert "f1_score" in metrics if __name__ == "__main__": - pytest.main() \ No newline at end of file + pytest.main()
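# --- Hedged usage sketch (not part of the patch) ---
# Illustrates how the updated XAIWrapper.fit(models, X, y) might be driven with either
# scikit-learn estimators or a Keras model, based on the signatures shown in core.py above.
# The dataset, column names, layer sizes and epochs are illustrative assumptions, and the
# import path assumes the package re-exports XAIWrapper (otherwise import from explainableai.core).
import pandas as pd
import tensorflow as tf
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from explainableai import XAIWrapper

X, y = make_classification(n_samples=200, n_features=5, random_state=42)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])

# scikit-learn path: a dict of models is cross-validated and the best CV scorer is kept.
xai = XAIWrapper()
xai.fit({"rf": RandomForestClassifier(), "lr": LogisticRegression(max_iter=1000)}, X, y)
results = xai.analyze()              # analyze() also calls the Gemini LLM, so an API key is assumed
xai.generate_report("xai_report.pdf")  # prompts interactively for which sections to include

# TensorFlow path: a Keras model is detected by _determine_model_type(); a final layer with
# more than one unit is treated as a classifier per the heuristic above.
keras_model = tf.keras.Sequential([
    tf.keras.Input(shape=(X.shape[1],)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
keras_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
xai_tf = XAIWrapper()
xai_tf.fit(keras_model, X, y)        # fit() trains TensorFlow models with epochs=10, batch_size=32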