|
| 1 | +# Copyright (c) 2023 Microsoft |
| 2 | +# Licensed under The MIT License [see LICENSE for details] |
| 3 | + |
| 4 | +import argparse |
| 5 | +from collections import defaultdict |
| 6 | +from typing import Dict, List, Tuple, DefaultDict |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | + |
def parse_arguments() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the filtering script.

    Returns:
        Parsed arguments with ``load_path``, ``save_path`` and ``percentile``.
    """
    parser = argparse.ArgumentParser(
        description="Filter compressed prompts based on metrics."
    )
    parser.add_argument(
        "--load_path",
        default="../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.pt",
        help="path to load data",
    )
    parser.add_argument(
        "--save_path",
        default="../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt",
        help="path to save filtered data",
    )
    parser.add_argument(
        "--percentile",
        type=int,
        default=90,
        help="percentile threshold for filtering",
    )
    return parser.parse_args()
| 30 | + |
def filter_by_metric(
    data: DefaultDict[str, List],
    metric_name: str,
    percentile: float,
) -> Tuple[DefaultDict[str, List], DefaultDict[str, List]]:
    """Split *data* into kept/filtered partitions by a percentile cutoff.

    The threshold is the given percentile of ``data[metric_name]``. Samples
    whose metric value is >= the threshold are routed to the ``filtered``
    partition; everything below it is routed to ``kept``. All metric columns
    listed below are transferred in lockstep so row alignment is preserved.

    Args:
        data: Column-oriented store mapping metric name -> list of per-sample
            values (all listed columns must have equal length).
        metric_name: Column used to compute the cutoff and route samples.
        percentile: Percentile (0-100) passed to ``np.percentile``.

    Returns:
        Tuple of (kept_data, filtered_data), each a ``defaultdict(list)``.
    """
    threshold = np.percentile(data[metric_name], percentile)

    kept: DefaultDict[str, List] = defaultdict(list)
    filtered: DefaultDict[str, List] = defaultdict(list)

    # Columns carried over to both partitions (row order preserved).
    metrics = [
        "labels", "origin", "comp", "retrieval", "comp_rate",
        "variation_rate", "hitting_rate", "matching_rate", "alignment_gap",
    ]
    columns = [data[name] for name in metrics]

    for row in zip(*columns):
        record = dict(zip(metrics, row))
        # Below the cutoff -> keep the sample; at or above -> filter it out.
        bucket = kept if record[metric_name] < threshold else filtered
        for name in metrics:
            bucket[name].append(record[name])

    return kept, filtered
| 71 | + |
def main():
    """Run the two-stage prompt-filtering pipeline and save the kept samples."""
    args = parse_arguments()

    # Load the annotated samples produced by the compression step.
    # weights_only=False: the checkpoint stores plain Python containers.
    samples = torch.load(args.load_path, weights_only=False)
    print(f"Initial sample count: {len(samples['variation_rate'])}")

    # Stage 1: drop samples whose variation rate reaches the percentile cutoff.
    stage1_kept, _ = filter_by_metric(
        data=samples,
        metric_name="variation_rate",
        percentile=args.percentile,
    )

    # Stage 2: on the survivors, drop those whose alignment gap reaches the cutoff.
    stage2_kept, _ = filter_by_metric(
        data=stage1_kept,
        metric_name="alignment_gap",
        percentile=args.percentile,
    )

    # Persist only the samples that survived both stages.
    torch.save(stage2_kept, args.save_path)

    print(f"Samples after first filter: {len(stage1_kept['variation_rate'])}")
    print(f"Final kept samples: {len(stage2_kept['variation_rate'])}")
| 100 | + |
# Script entry point: run the filtering pipeline when executed directly.
if __name__ == "__main__":
    main()
0 commit comments