-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
62 lines (47 loc) · 1.89 KB
/
utils.py
File metadata and controls
62 lines (47 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause-Clear
import torch
from bert_score import BERTScorer
from rouge_score import rouge_scorer
from torch.utils.data import Dataset

from data import QualcommInteractiveCookingDataset
# Scorers are built once at import time and shared by get_rouge_score and
# get_bert_score, so the underlying models are not reloaded on every call.
_rouge_scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
# Hard-coding device="cuda" made importing this module fail on hosts without a
# CUDA device; fall back to CPU in that case (slower, but it still works).
_bert_scorer = BERTScorer(
    model_type="roberta-large",
    lang="en",
    device="cuda" if torch.cuda.is_available() else "cpu",
    rescale_with_baseline=True,
)
def mean(lst) -> float:
    """Return the arithmetic mean of ``lst``, or ``0.0`` when it is empty.

    Args:
        lst: A sized collection of numbers (anything supporting ``len`` and
            ``sum``).

    Returns:
        ``sum(lst) / len(lst)``, or ``0.0`` for an empty collection.
    """
    # Use an explicit length check rather than truthiness so array-like inputs
    # (whose truth value may be ambiguous) keep working as they did before.
    count = len(lst)
    if count == 0:
        return 0.0
    return sum(lst) / count
def get_bert_score(pred_text: str, gt_text: str) -> float:
    """Score a single prediction against a single reference with BERTScore.

    The shared scorer's ``score`` method takes lists of candidates and
    references, so each string is wrapped in a one-element list. Only the
    F1 component of the (precision, recall, F1) result is kept.

    Args:
        pred_text: Candidate text produced by the model.
        gt_text: Reference (ground-truth) text to compare against.

    Returns:
        The BERTScore F1 value for this pair, as a Python float.
    """
    candidates = [pred_text]
    references = [gt_text]
    _precision, _recall, f1_batch = _bert_scorer.score(
        candidates, references, verbose=False
    )
    return float(f1_batch[0])
def get_rouge_score(pred_text: str, gt_text: str) -> float:
    """Calculate the ROUGE-L F-measure between predicted and ground-truth text.

    Args:
        pred_text: The predicted text to evaluate.
        gt_text: The ground-truth text to compare against.

    Returns:
        The ROUGE-L F-measure as a float.
    """
    # RougeScorer.score expects (target, prediction). Passing the reference
    # first keeps the precision/recall fields of the result in the right
    # order. The F-measure is symmetric in P and R, so the returned value is
    # unchanged.
    scores = _rouge_scorer.score(gt_text, pred_text)
    return scores["rougeL"].fmeasure
def load_dataset(plan_set: str, split: str) -> Dataset:
    """Load the Qualcomm Interactive Cooking Dataset.

    Args:
        plan_set: The plan set to use; forwarded unchanged to
            ``QualcommInteractiveCookingDataset``.
        split: Name of the dataset split to load, forwarded unchanged
            (presumably e.g. 'train', 'validation', 'test'; confirm against
            the dataset class).

    Returns:
        The ``QualcommInteractiveCookingDataset`` instance for the requested
        plan set and split.
    """
    return QualcommInteractiveCookingDataset(plan_set=plan_set, split=split)