diff --git a/folktexts/benchmark.py b/folktexts/benchmark.py
index 833dce9..86e9304 100755
--- a/folktexts/benchmark.py
+++ b/folktexts/benchmark.py
@@ -42,6 +42,9 @@ class BenchmarkConfig:
     reuse_few_shot_examples : bool, optional
         Whether to reuse the same samples for few-shot prompting (or sample
         new ones every time), by default False.
+    balance_few_shot_examples : bool, optional
+        Whether to balance the samples for few-shot prompting with respect to
+        their labels, by default False.
     batch_size : int | None, optional
         The batch size to use for inference.
     context_size : int | None, optional
@@ -62,6 +65,7 @@ class BenchmarkConfig:
     numeric_risk_prompting: bool = False
     few_shot: int | None = None
     reuse_few_shot_examples: bool = False
+    balance_few_shot_examples: bool = False
     batch_size: int | None = None
     context_size: int | None = None
     correct_order_bias: bool = True
@@ -540,6 +544,7 @@ def make_benchmark(
             n_shots=config.few_shot,
             dataset=dataset,
             reuse_examples=config.reuse_few_shot_examples,
+            class_balancing=config.balance_few_shot_examples,
         )
 
     else:
diff --git a/folktexts/cli/run_acs_benchmark.py b/folktexts/cli/run_acs_benchmark.py
index 13fee93..e43803f 100755
--- a/folktexts/cli/run_acs_benchmark.py
+++ b/folktexts/cli/run_acs_benchmark.py
@@ -75,6 +75,13 @@ def list_of_strings(arg):
         default=False,
     )
 
+    parser.add_argument(
+        "--balance-few-shot-examples",
+        help="[bool] Whether to sample evenly from all classes in few-shot prompting",
+        action="store_true",
+        default=False,
+    )
+
     # Optionally, receive a list of features to use (subset of original list)
     parser.add_argument(
         "--use-feature-subset",
@@ -147,6 +154,7 @@ def main():
         few_shot=args.few_shot,
         numeric_risk_prompting=args.numeric_risk_prompting,
         reuse_few_shot_examples=args.reuse_few_shot_examples,
+        balance_few_shot_examples=args.balance_few_shot_examples,
         batch_size=args.batch_size,
         context_size=args.context_size,
         correct_order_bias=not args.dont_correct_order_bias,
diff --git a/folktexts/dataset.py b/folktexts/dataset.py
index 081faa1..c2548e8 100755
--- a/folktexts/dataset.py
+++ b/folktexts/dataset.py
@@ -288,6 +288,7 @@ def sample_n_train_examples(
         self,
         n: int,
         reuse_examples: bool = False,
+        class_balancing: bool = False,
     ) -> tuple[pd.DataFrame, pd.Series]:
         """Return a set of samples from the training set.
 
@@ -304,11 +305,37 @@
         X, y : tuple[pd.DataFrame, pd.Series]
             The features and target data for the sampled examples.
         """
-        # TODO: make sure examples are class-balanced?
-        if reuse_examples:
-            example_indices = self._train_indices[:n]
+        if class_balancing:
+
+            train_labels = self.get_target_data().iloc[self._train_indices]
+            unique_labels, counts = np.unique(train_labels, return_counts=True)
+
+            # Calculate number of samples to sample per label
+            per_label_n = n // len(unique_labels)
+            remaining = n % len(unique_labels)  # distribute extra samples
+
+            if min(counts) < per_label_n:
+                logging.error(
+                    f"Labels are very imbalanced: Attempting to sample {per_label_n}, "
+                    f"but minimal group size is {min(counts)}.")
+
+            example_indices = []
+            for i, label in enumerate(unique_labels):
+                class_indices = self._train_indices[train_labels == label]
+
+                if reuse_examples:
+                    selected = class_indices[:per_label_n + int(i < remaining)]
+                else:
+                    selected = self._rng.choice(class_indices, size=per_label_n + int(i < remaining), replace=False)
+                example_indices.extend(selected)
+
+            # shuffle indices to ensure classes are mixed
+            example_indices = self._rng.permutation(example_indices)
         else:
-            example_indices = self._rng.choice(self._train_indices, size=n, replace=False)
+            if reuse_examples:
+                example_indices = self._train_indices[:n]
+            else:
+                example_indices = self._rng.choice(self._train_indices, size=n, replace=False)
 
         return (
             self.data.iloc[example_indices][self.task.features],
diff --git a/folktexts/evaluation.py b/folktexts/evaluation.py
index 1abb088..323f01d 100644
--- a/folktexts/evaluation.py
+++ b/folktexts/evaluation.py
@@ -9,7 +9,7 @@
 import logging
 import statistics
 
-from typing import Callable, Optional
+from typing import Callable
 
 import numpy as np
 from netcal.metrics import ECE
diff --git a/folktexts/prompting.py b/folktexts/prompting.py
index 824c8dc..80960f5 100644
--- a/folktexts/prompting.py
+++ b/folktexts/prompting.py
@@ -64,6 +64,7 @@ def encode_row_prompt_few_shot(
     n_shots: int,
     question: QAInterface = None,
     reuse_examples: bool = False,
+    class_balancing: bool = False,
     custom_prompt_prefix: str = None,
 ) -> str:
     """Encode a question regarding a given row using few-shot prompting.
@@ -87,7 +88,11 @@
         The encoded few-shot prompt.
     """
     # Take `n_shots` random samples from the train set
-    X_examples, y_examples = dataset.sample_n_train_examples(n_shots, reuse_examples=reuse_examples)
+    X_examples, y_examples = dataset.sample_n_train_examples(
+        n_shots,
+        reuse_examples=reuse_examples,
+        class_balancing=class_balancing,
+    )
 
     # Start with task description
     prompt = ACS_FEW_SHOT_TASK_DESCRIPTION + "\n"