Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions inference/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,13 +574,11 @@ def _predict_reg(self, x_train:np.ndarray, y_train:np.ndarray, x_test:np.ndarray
"threshold", 1),
mixed_method=self.inference_config[id_pipe]["retrieval_config"].get(
"mixed_method", "max"),device=self.device)
outputs.append(output)
elif self.inference_with_DDP:
inference = InferenceResultWithRetrieval(model=self.model,
sample_selection_type="DDP")
output = inference.inference(x_[:len(y_train)].squeeze(1), y_, x_[len(y_train):].squeeze(1),
task_type="reg")
outputs.append(output)
else:
self.model.to(self.device)
with(torch.autocast(device_type=self.device.type if isinstance(self.device, torch.device) else self.device, enabled=self.mix_precision), torch.inference_mode()):
Expand All @@ -597,6 +595,7 @@ def _predict_reg(self, x_train:np.ndarray, y_train:np.ndarray, x_test:np.ndarray
output = output['reg_output']

output = output if isinstance(output, dict) else output.squeeze(0)
# FIX: append exactly once per pipeline; previously retrieval/DDP branches appended twice and biased the mean.
outputs.append(output)

output = torch.stack(outputs).squeeze(2).mean(dim=0)
Expand Down
5 changes: 3 additions & 2 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def __getitem__(self, idx: int) -> dict[str, list]:
return dict(
idx=int(idx),
X_train=self.X_train[idx], # Training features for this step (retrieved)
X_test=self.X_test[idx], # Training labels for this step (retrieved)
y_train=self.y_train[idx], # The test sample features
# FIX: correct the inline comments so they match the dict keys (X_test is the test feature row; y_train holds the retrieved labels)
X_test=self.X_test[idx], # Test sample features for this step
y_train=self.y_train[idx], # Retrieved training labels for this step
)
else:
# Return only the test data; training data is assumed to be fixed and
Expand Down
14 changes: 10 additions & 4 deletions utils/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,35 @@ def generate_infenerce_config(args):
use_type=None,
)

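# FIX: do not reuse the same retrieval_config dict across pipelines; an update to one pipeline's config must not leak into the others.
# NOTE: dict.copy() is shallow, so only top-level keys are isolated; nested mutable values are still shared (use copy.deepcopy if those are ever mutated).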
retrieval_config_1 = retrieval_config.copy()
retrieval_config_2 = retrieval_config.copy()
retrieval_config_3 = retrieval_config.copy()
retrieval_config_4 = retrieval_config.copy()

config_list = [
dict(RebalanceFeatureDistribution=dict(worker_tags=["quantile"], discrete_flag=False, original_flag=True,
svd_tag="svd"),
CategoricalFeatureEncoder=dict(encoding_strategy="ordinal_strict_feature_shuffled"),
FeatureShuffler=dict(mode="shuffle"),
retrieval_config=retrieval_config,
retrieval_config=retrieval_config_1,
),
dict(RebalanceFeatureDistribution=dict(worker_tags=["quantile"], discrete_flag=False, original_flag=True,
svd_tag="svd"),
CategoricalFeatureEncoder=dict(encoding_strategy="ordinal_strict_feature_shuffled"),
FeatureShuffler=dict(mode="shuffle"), retrieval_config=retrieval_config,
FeatureShuffler=dict(mode="shuffle"), retrieval_config=retrieval_config_2,
),
dict(RebalanceFeatureDistribution=dict(worker_tags=[None], discrete_flag=True, original_flag=False,
svd_tag=None),
CategoricalFeatureEncoder=dict(encoding_strategy="numeric"),
FeatureShuffler=dict(mode="shuffle"),
retrieval_config=retrieval_config,
retrieval_config=retrieval_config_3,
),
dict(RebalanceFeatureDistribution=dict(worker_tags=[None], discrete_flag=True, original_flag=False,
svd_tag=None),
CategoricalFeatureEncoder=dict(encoding_strategy="numeric"),
FeatureShuffler=dict(mode="shuffle"),
retrieval_config=retrieval_config)
retrieval_config=retrieval_config_4)
]

with open(args.inference_config_path, 'w') as f:
Expand Down