-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathdataset.py
More file actions
54 lines (46 loc) · 2.2 KB
/
dataset.py
File metadata and controls
54 lines (46 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import json
import random
from torch.utils.data import Dataset, DataLoader
import pandas as pd
sep_token = " "  # separator used when joining the prompt fields into one string


class CICERO2Dataset(Dataset):
    """Multiple-choice dataset for CICERO v2.

    Each line of the input file is a JSON object with keys "Question",
    "Target", "Dialogue", "Choices" and "Correct Answers".  For every
    answer choice two prompt variants are emitted (a plain "choice:"
    prompt and a chain-of-thought prompt); each prompt carries a binary
    label (1 iff the choice index equals "Correct Answers") and the raw
    choice text, so the three parallel lists stay aligned.

    Note: the original version appended two prompts per choice but only
    one label, desynchronising ``content`` and ``labels`` and making
    ``__getitem__`` raise IndexError past the midpoint; it also kept only
    the last instance's choice list.  Both are fixed here.
    """

    def __init__(self, f, shuffle):
        # Parallel accumulators: prompt text, 0/1 label, raw choice text.
        content, labels, choice_texts = [], [], []
        # `with` closes the handle deterministically (original leaked it).
        with open(f) as fp:
            lines = fp.readlines()
        if shuffle:
            random.shuffle(lines)
        for line in lines:
            instance = json.loads(line)
            context = sep_token.join(
                [instance["Question"],
                 "target: " + instance["Target"],
                 "context: " + " <utt> ".join(instance["Dialogue"])])
            # NOTE(review): assumes "Correct Answers" is a single int index —
            # confirm against the data; a list here would never match `k == l`.
            gold = instance["Correct Answers"]
            for k, c in enumerate(instance["Choices"]):
                label = 1 if k == gold else 0
                # Two prompt variants per choice; label and choice text are
                # appended once per variant so all three lists stay the same
                # length.
                content.append("{} \\n choice: {}".format(context, c))
                labels.append(label)
                choice_texts.append(c)
                content.append("{} \\n let's think step by step,and answer the two questions:, is the answer {} right and why?".format(context, c))
                labels.append(label)
                choice_texts.append(c)
        self.content, self.labels, self.choices = content, labels, choice_texts

    def __len__(self):
        # Number of prompt/label/choice triples.
        return len(self.content)

    def __getitem__(self, index):
        # Aligned triple: (prompt string, 0/1 label, raw choice text).
        return self.content[index], self.labels[index], self.choices[index]

    def collate_fn(self, data):
        """Transpose a batch of tuples into per-field lists via a DataFrame."""
        dat = pd.DataFrame(data)
        return [dat[i].tolist() for i in dat]
def configure_dataloaders(train_batch_size=16, eval_batch_size=16, shuffle=False):
    """Build train/val/test DataLoaders over the CICERO v2 splits.

    Args:
        train_batch_size: batch size for the training loader.
        eval_batch_size: batch size for the validation and test loaders.
        shuffle: whether the training DataLoader reshuffles batches each
            epoch (the train *dataset* is always line-shuffled once at load
            time, as before).  Default False preserves the original behavior.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    train_dataset = CICERO2Dataset("data/cicero/train.json", True)
    # `shuffle` was previously accepted but never used; it now controls
    # epoch-level shuffling of the training loader (default unchanged).
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                              shuffle=shuffle, collate_fn=train_dataset.collate_fn)

    val_dataset = CICERO2Dataset("data/cicero/val.json", False)
    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size,
                            collate_fn=val_dataset.collate_fn)

    test_dataset = CICERO2Dataset("data/cicero/test.json", False)
    # Bug fix: the test loader previously reused val_dataset.collate_fn.
    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                             collate_fn=test_dataset.collate_fn)

    return train_loader, val_loader, test_loader