-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_pipeline.py
More file actions
123 lines (88 loc) · 2.79 KB
/
test_pipeline.py
File metadata and controls
123 lines (88 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import torch
import torch.nn as nn
from typing import Tuple
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchdistpackage import tpc, setup_distributed, test_comm
from torchdistpackage.parallel import partition_uniform, forward_backward
class DummyClsDataset(Dataset):
    """Synthetic classification dataset: random tensors with random 10-class labels."""

    def __init__(self, shape, num_samples=1000):
        # `shape` is the per-sample tensor shape (no batch dimension).
        self.shape = shape
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # The index is ignored: every access draws a fresh random sample,
        # which is fine for a throughput/smoke test.
        del index
        sample = torch.randn(self.shape)
        target = torch.randint(0, 10, (1,)).squeeze()
        return sample, target
def get_dataloaders(
    batch_size, batch_shape, train_samples, test_samples
) -> Tuple[DataLoader, DataLoader]:
    """Build (train, test) dataloaders over synthetic data.

    The training loader shards the dataset across ranks with a
    DistributedSampler; the test loader just shuffles the full set locally.
    """
    train_set = DummyClsDataset(batch_shape, train_samples)
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=DistributedSampler(train_set, shuffle=True),
    )
    test_loader = DataLoader(
        DummyClsDataset(batch_shape, test_samples),
        batch_size=batch_size,
        shuffle=True,
    )
    return train_loader, test_loader
def build_seq_model():
    """Construct a tiny 10->10 MLP used to exercise pipeline partitioning."""
    layers = [
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
# ---- Configuration for the pipeline-parallel smoke test ----
BATCH_SIZE = 512
NUM_EPOCHS = 2
NUM_CHUNKS = 2
NUM_MICRO_BATCHES = 4

setup_distributed()

# Total rank count comes from the SLURM launcher environment.
world_size = int(os.environ["SLURM_NTASKS"])
pp_size = 2
# Process-group sizes must be integers: use floor division here.  The
# original `world_size / pp_size` is true division and yields a float
# (e.g. 4 / 2 == 2.0), which is not a valid group size.
dist_config = [("pipe", pp_size), ("data", world_size // pp_size)]
tpc.setup_process_groups(dist_config)

# Partition the sequential model uniformly across pipeline stages; each
# rank keeps only its own stage's layers.
model = build_seq_model().cuda()
model = nn.Sequential(*partition_uniform(model)).cuda()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# build dataloaders over synthetic data (the unused DATA-root lookup was
# dead code and has been removed)
train_dataloader, test_dataloader = get_dataloaders(BATCH_SIZE, [10], 4000, 100)
test_dataloader = None  # evaluation is not exercised by this test
def pp_fwd_fn(ins):
    """Forward step for one pipeline stage.

    Runs this stage's partition of the module-level `model`.  The last
    stage reduces its output to a scalar loss; every other stage returns
    activations to forward to the next stage.
    """
    feat = model(ins)
    if not tpc.is_last_in_pipeline_group():
        return feat
    return feat.mean()
def local_forward_backward(inp, model):
    """Run one pipelined forward/backward pass over `inp`.

    Only the first pipeline stage feeds real data; downstream stages get
    activations via the pipeline communication inside `forward_backward`.
    NOTE(review): the `model` parameter is unused here — `pp_fwd_fn`
    closes over the module-level `model` instead; kept for call-site
    compatibility.
    """
    stage_inputs = [inp] if tpc.is_first_in_pipeline_group() else []
    forward_backward(
        optimizer,
        pp_fwd_fn,
        None,
        stage_inputs,
        num_microbatches=NUM_MICRO_BATCHES,
        forward_only=False,
        dtype=torch.float32,
        scatter_gather_tensors=False,
    )
# Main loop: one pipelined forward/backward plus an optimizer step per batch.
for epoch in range(NUM_EPOCHS):
    for img, label in train_dataloader:
        img, label = img.cuda(), label.cuda()
        local_forward_backward(img, model)
        optimizer.step()
        print("-------------")