diff --git a/inference/predictor.py b/inference/predictor.py index eff5c82..c161bc8 100644 --- a/inference/predictor.py +++ b/inference/predictor.py @@ -574,13 +574,11 @@ def _predict_reg(self, x_train:np.ndarray, y_train:np.ndarray, x_test:np.ndarray "threshold", 1), mixed_method=self.inference_config[id_pipe]["retrieval_config"].get( "mixed_method", "max"),device=self.device) - outputs.append(output) elif self.inference_with_DDP: inference = InferenceResultWithRetrieval(model=self.model, sample_selection_type="DDP") output = inference.inference(x_[:len(y_train)].squeeze(1), y_, x_[len(y_train):].squeeze(1), task_type="reg") - outputs.append(output) else: self.model.to(self.device) with(torch.autocast(device_type=self.device.type if isinstance(self.device, torch.device) else self.device, enabled=self.mix_precision), torch.inference_mode()): @@ -597,6 +595,7 @@ def _predict_reg(self, x_train:np.ndarray, y_train:np.ndarray, x_test:np.ndarray output = output['reg_output'] output = output if isinstance(output, dict) else output.squeeze(0) + # FIX: append exactly once per pipeline; previously retrieval/DDP branches appended twice and biased the mean. outputs.append(output) output = torch.stack(outputs).squeeze(2).mean(dim=0) diff --git a/utils/data_utils.py b/utils/data_utils.py index dc6c444..1b9c6bb 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -95,8 +95,9 @@ def __getitem__(self, idx: int) -> dict[str, list]: return dict( idx=int(idx), X_train=self.X_train[idx], # Training features for this step (retrieved) - X_test=self.X_test[idx], # Training labels for this step (retrieved) - y_train=self.y_train[idx], # The test sample features + # FIX: keep the dict keys semantically correct (X_test is the test feature row; y_train is the retrieved labels) + X_test=self.X_test[idx], # Test sample features for this step + y_train=self.y_train[idx], # Retrieved training labels for this step ) else: # Return only the test data; training data is assumed to be fixed and diff --git a/utils/inference_utils.py b/utils/inference_utils.py index ec8ac28..2c8f885 100644 --- a/utils/inference_utils.py +++ b/utils/inference_utils.py @@ -99,29 +99,35 @@ def generate_infenerce_config(args): use_type=None, ) + # FIX: do not reuse the same retrieval_config dict across pipelines; later updates should not mutate others. + retrieval_config_1 = retrieval_config.copy() + retrieval_config_2 = retrieval_config.copy() + retrieval_config_3 = retrieval_config.copy() + retrieval_config_4 = retrieval_config.copy() + config_list = [ dict(RebalanceFeatureDistribution=dict(worker_tags=["quantile"], discrete_flag=False, original_flag=True, svd_tag="svd"), CategoricalFeatureEncoder=dict(encoding_strategy="ordinal_strict_feature_shuffled"), FeatureShuffler=dict(mode="shuffle"), - retrieval_config=retrieval_config, + retrieval_config=retrieval_config_1, ), dict(RebalanceFeatureDistribution=dict(worker_tags=["quantile"], discrete_flag=False, original_flag=True, svd_tag="svd"), CategoricalFeatureEncoder=dict(encoding_strategy="ordinal_strict_feature_shuffled"), - FeatureShuffler=dict(mode="shuffle"), retrieval_config=retrieval_config, + FeatureShuffler=dict(mode="shuffle"), retrieval_config=retrieval_config_2, ), dict(RebalanceFeatureDistribution=dict(worker_tags=[None], discrete_flag=True, original_flag=False, svd_tag=None), CategoricalFeatureEncoder=dict(encoding_strategy="numeric"), FeatureShuffler=dict(mode="shuffle"), - retrieval_config=retrieval_config, + retrieval_config=retrieval_config_3, ), dict(RebalanceFeatureDistribution=dict(worker_tags=[None], discrete_flag=True, original_flag=False, svd_tag=None), CategoricalFeatureEncoder=dict(encoding_strategy="numeric"), FeatureShuffler=dict(mode="shuffle"), - retrieval_config=retrieval_config) + retrieval_config=retrieval_config_4) ] with open(args.inference_config_path, 'w') as f: