-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathgpv_transforms.py
More file actions
31 lines (24 loc) · 824 Bytes
/
gpv_transforms.py
File metadata and controls
31 lines (24 loc) · 824 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import torch
import torch.nn as nn
class Compose(object):
    """Chain several transforms so each one's output feeds the next.

    Args:
        co_transforms: an iterable of callables, applied in iteration order.
    """

    def __init__(self, co_transforms):
        # Kept exactly as given; iterated afresh on every call.
        self.co_transforms = co_transforms

    def __call__(self, inputs):
        """Run ``inputs`` through every transform and return the final result.

        With no transforms, ``inputs`` is returned unchanged.
        """
        result = inputs
        for step in self.co_transforms:
            result = step(result)
        return result
class ArrayToTensor(object):
    """Convert a NumPy array into a float32 torch tensor with a new leading axis of size 1."""

    def __call__(self, array):
        """Return ``array`` as a float32 tensor of shape ``(1, *array.shape)``.

        The result never shares memory with ``array``.

        Args:
            array: a ``numpy.ndarray`` of any shape.

        Returns:
            A new ``torch.float32`` tensor.

        Raises:
            TypeError: if ``array`` is not a ``numpy.ndarray``. This is an
                explicit check rather than ``assert``, so it still runs under
                ``python -O``.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(
                "ArrayToTensor expects a numpy.ndarray, got %s" % type(array).__name__
            )
        # astype() always returns a fresh C-contiguous buffer. This does one copy
        # (instead of copy() followed by a dtype conversion), avoids aliasing the
        # caller's array, and makes negative-stride views safe for from_numpy.
        data = array.astype(np.float32, order="C")
        tensor = torch.from_numpy(data)
        return torch.unsqueeze(tensor, dim=0)
class oneD2twoD(object):
    """Reshape a flat input into a square ``(img_size, img_size)`` grid.

    Args:
        img_size: side length of the output square (default 32).
    """

    def __init__(self, img_size=32):
        self.img_size = img_size

    def __call__(self, inputs):
        """Return ``inputs`` reshaped to ``(img_size, img_size)`` via ``np.reshape``.

        ``np.reshape`` raises ``ValueError`` when the element count is not
        ``img_size ** 2``.
        """
        side = self.img_size
        square = (side, side)
        return np.reshape(inputs, square)