Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def add(self, data, add_results, contribution_round):
add_results["A_sum"] += data["site_ols_params"]['A']
add_results["B_sum"] += data["site_ols_params"]['B']
add_results["combined_hessian"] += data["site_hessian"]
add_results["exog_names"] = data["exog_names"]

def get_result(self, add_results, contribution_round, target_accuracy, **kwargs):
next_beta = np.linalg.inv(add_results["A_sum"]).dot(add_results["B_sum"])
Expand All @@ -246,13 +247,13 @@ def get_result(self, add_results, contribution_round, target_accuracy, **kwargs)
# Stop if the result is already accurate enough
if np.all(np.greater(target_accuracy, accuracy)):
print(f"Reached accuracy threshold")
return None, {"beta": next_beta, "fed_stderror": fed_stderror, "signal": 'ABORT', "Reached accuracy threshold": True}
return None, {"beta": next_beta, "fed_stderror": fed_stderror, "variable_names_by_betas_order": data["exog_names"], "signal": 'ABORT', "Reached accuracy threshold": True}

add_results["A_sum"] = 0
add_results["B_sum"] = 0
add_results["combined_hessian"] = 0
print(f"next beta after contribution round {contribution_round} is {next_beta}")
return None, {"site_info": {"params": next_beta}, "beta": next_beta, "fed_stderror": fed_stderror, "Reached accuracy threshold": False}
return None, {"site_info": {"params": next_beta}, "beta": next_beta, "fed_stderror": fed_stderror, "variable_names_by_betas_order": data["exog_names"], "Reached accuracy threshold": False}


# Registry used to look up an optimizer implementation by its short name.
OPTIMIZERS = {
    "NR": NewtonRaphson,
    "IRLS": IRLS,
}
47 changes: 39 additions & 8 deletions examples/nvflare/regression-glm-coeff/custom/regression_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@

from coeff_optimizer import OPTIMIZERS

# Signs that make the smf.glm function fail when they appear in a column name.
# Order is preserved because the list is echoed verbatim in error messages.
INVALID_SIGNS = (
    ['<', '>', '=']
    + ['+', '-', '*', '/', '//', '%', '**']
    + ['(', ')', ':', '~', '|', '^', ',']
    + ['.', "'", '"', '@', '#', '$']
    + ['[', ']', '{', '}', '?', '!', ' ']
)

class GLMTrainer(Executor):
def __init__(
Expand Down Expand Up @@ -81,14 +88,6 @@ def __init__(
self.site_info = dict()
print(f"Initialized Federated Client. {x_values=}, {y_values=}, {formula=}, {glm_type=}, {add_intercept=}, {cast_to_string_fields=}")

if not self.formula and not (self.x_values and self.y_values):
print("Either formula or x_values and y_values must be provided.")
raise ValueError("Either formula or x_values and y_values must be provided.")

if self.offset and self.family_class != sm.families.Poisson:
print("Offset is only supported for Poisson distribution family.")
raise ValueError("Offset is only supported for Poisson distribution family.")

# Load dataset
datasets_path = '/input/datasets'
dataset_uid = next(os.walk(datasets_path))[1][0]
Expand All @@ -111,6 +110,38 @@ def __init__(
if self.data_x is not None:
self.data_x["Intercept"] = 1

def _validate_input(self):
    """
    Validate the input parameters for the model.

    The checks run in this order:
    - Either formula or both x_values and y_values are provided.
    - A supplied formula has exactly one "~" separating the dependent
      variable from the independent variables.
    - The supplied variables (from the formula or from x_values and
      y_values) are all present in the dataset columns.
    - The supplied variables do not contain any of INVALID_SIGNS.
    - The offset is only used with the Poisson distribution family.

    Raises:
        ValueError: If any of the checks above fails.
    """
    # Validate either formula or explicit columns are supplied
    if not self.formula and not (self.x_values and self.y_values):
        print("Either formula or x_values and y_values must be provided.")
        raise ValueError("Either formula or x_values and y_values must be provided.")

    # Collect the variable names referenced by the model
    if self.formula:
        formula = self.formula.replace(" ", "")
        dependent_var, separator, independent_vars = formula.partition("~")
        # Exactly one "~" is required; report it explicitly instead of
        # surfacing an opaque tuple-unpacking error.
        if not separator or "~" in independent_vars:
            raise ValueError(
                f"The given formula must contain exactly one '~' separating the dependent "
                f"variable from the independent variables: {self.formula}")
        # Terms are split on "+" only, so every term is expected to be a
        # plain column name.
        formula_parts = independent_vars.split("+") + [dependent_var]
        source_label = "formula"
    else:
        formula_parts = self.x_values + self.y_values
        source_label = "y_values or x values"

    # Validate that every referenced variable exists in the dataset
    missing_parts = [part for part in formula_parts if part not in self.data.columns]
    if missing_parts:
        # The label is computed outside the f-string: reusing the enclosing
        # quote type inside an f-string is a syntax error before Python 3.12.
        raise ValueError(
            f"The given {source_label} contains variables that are missing from the data columns: {missing_parts}")

    # Validate that no referenced variable contains a sign that breaks the model
    if any(sign in part for part in formula_parts for sign in INVALID_SIGNS):
        raise ValueError(f"Column headers with the signs {INVALID_SIGNS} are invalid, please modify the dataset columns and the model's formula.")

    # Validate use of offset
    if self.offset and self.family_class != sm.families.Poisson:
        print("Offset is only supported for Poisson distribution family.")
        raise ValueError("Offset is only supported for Poisson distribution family.")

def handle_event(self, event_type: str, fl_ctx: FLContext):
    """Ignore the event: no action is taken for any event type."""
    return None

Expand Down