Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions folktexts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class BenchmarkConfig:
reuse_few_shot_examples : bool, optional
Whether to reuse the same samples for few-shot prompting (or sample new
ones every time), by default False.
balance_few_shot_examples : bool, optional
Whether to balance the samples for few-shot prompting with respect to
their labels, by default False.
batch_size : int | None, optional
The batch size to use for inference.
context_size : int | None, optional
Expand All @@ -62,6 +65,7 @@ class BenchmarkConfig:
numeric_risk_prompting: bool = False
few_shot: int | None = None
reuse_few_shot_examples: bool = False
balance_few_shot_examples: bool = False
batch_size: int | None = None
context_size: int | None = None
correct_order_bias: bool = True
Expand Down Expand Up @@ -540,6 +544,7 @@ def make_benchmark(
n_shots=config.few_shot,
dataset=dataset,
reuse_examples=config.reuse_few_shot_examples,
class_balancing=config.balance_few_shot_examples,
)

else:
Expand Down
8 changes: 8 additions & 0 deletions folktexts/cli/run_acs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def list_of_strings(arg):
default=False,
)

parser.add_argument(
"--balance-few-shot-examples",
help="[bool] Whether to sample evenly from all classes in few-shot prompting",
action="store_true",
default=False,
)

# Optionally, receive a list of features to use (subset of original list)
parser.add_argument(
"--use-feature-subset",
Expand Down Expand Up @@ -147,6 +154,7 @@ def main():
few_shot=args.few_shot,
numeric_risk_prompting=args.numeric_risk_prompting,
reuse_few_shot_examples=args.reuse_few_shot_examples,
balance_few_shot_examples=args.balance_few_shot_examples,
batch_size=args.batch_size,
context_size=args.context_size,
correct_order_bias=not args.dont_correct_order_bias,
Expand Down
35 changes: 31 additions & 4 deletions folktexts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def sample_n_train_examples(
self,
n: int,
reuse_examples: bool = False,
class_balancing: bool = False,
) -> tuple[pd.DataFrame, pd.Series]:
"""Return a set of samples from the training set.

Expand All @@ -304,11 +305,37 @@ def sample_n_train_examples(
X, y : tuple[pd.DataFrame, pd.Series]
The features and target data for the sampled examples.
"""
# TODO: make sure examples are class-balanced?
if reuse_examples:
example_indices = self._train_indices[:n]
if class_balancing:

train_labels = self.get_target_data().iloc[self._train_indices]
unique_labels, counts = np.unique(train_labels, return_counts=True)

# Calculate number of samples to sample per label
per_label_n = n // len(unique_labels)
remaining = n % len(unique_labels) # distribute extra samples

if min(counts) < per_label_n:
logging.error(
f"Labels are very imbalanced: Attempting to sample {per_label_n}, "
f"but minimal group size is {min(counts)}.")

example_indices = []
for i, label in enumerate(unique_labels):
class_indices = self._train_indices[train_labels == label]

if reuse_examples:
selected = class_indices[:per_label_n + int(i < remaining)]
else:
selected = self._rng.choice(class_indices, size=per_label_n + int(i < remaining), replace=False)
example_indices.extend(selected)

# shuffle indices to ensure classes are mixed
example_indices = self._rng.permutation(example_indices)
else:
example_indices = self._rng.choice(self._train_indices, size=n, replace=False)
if reuse_examples:
example_indices = self._train_indices[:n]
else:
example_indices = self._rng.choice(self._train_indices, size=n, replace=False)

return (
self.data.iloc[example_indices][self.task.features],
Expand Down
2 changes: 1 addition & 1 deletion folktexts/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging
import statistics
from typing import Callable, Optional
from typing import Callable

import numpy as np
from netcal.metrics import ECE
Expand Down
7 changes: 6 additions & 1 deletion folktexts/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def encode_row_prompt_few_shot(
n_shots: int,
question: QAInterface = None,
reuse_examples: bool = False,
class_balancing: bool = False,
custom_prompt_prefix: str = None,
) -> str:
"""Encode a question regarding a given row using few-shot prompting.
Expand All @@ -87,7 +88,11 @@ def encode_row_prompt_few_shot(
The encoded few-shot prompt.
"""
# Take `n_shots` random samples from the train set
X_examples, y_examples = dataset.sample_n_train_examples(n_shots, reuse_examples=reuse_examples)
X_examples, y_examples = dataset.sample_n_train_examples(
n_shots,
reuse_examples=reuse_examples,
class_balancing=class_balancing,
)

# Start with task description
prompt = ACS_FEW_SHOT_TASK_DESCRIPTION + "\n"
Expand Down