-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
This is my code:
class Accurate_Modle(torch.nn.Module):
    """Two-input patch network producing a sigmoid score map.

    Both inputs are passed through the same feature extractor
    (``self.conv1``), the resulting feature maps are concatenated along the
    channel dimension, and a stack of 1x1 convolutions (``self.full``) ending
    in a sigmoid reduces them to a single output channel.

    ``OctConv2d`` and ``OctReLU`` are defined elsewhere in the project
    (presumably octave-convolution layers -- confirm against their source).
    """

    def __init__(self, data_kind, padding):
        """Build the layers for the given dataset variant.

        Args:
            data_kind: Dataset selector. Only ``'mb'`` builds ``self.conv1``
                and ``self.full``; for any other value those attributes are
                left undefined and ``forward`` will fail in ``'train'`` mode.
            padding: Accepted for interface compatibility; not used by any
                layer constructed here.
        """
        # NOTE: the method must be named ``__init__`` (with the double
        # underscores) for Python to call it on construction and for
        # torch.nn.Module to register the sub-modules assigned below.
        super().__init__()

        # 1x1 convolutions of the decision head. The first one takes
        # 2 * 112 channels: two 112-channel feature maps concatenated.
        self.con6 = torch.nn.Conv2d(2 * 112, 384, kernel_size=1, stride=1, padding=0)
        self.con7 = torch.nn.Conv2d(384, 384, kernel_size=1, stride=1, padding=0)
        self.con8 = torch.nn.Conv2d(384, 384, kernel_size=1, stride=1, padding=0)
        self.con9 = torch.nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)

        if data_kind == 'mb':
            # Shared feature extractor applied to each input separately.
            self.conv1 = torch.nn.Sequential(
                OctConv2d('first', in_channels=1, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('regular', in_channels=112, out_channels=112, kernel_size=3),
                OctReLU(),
                OctConv2d('last', in_channels=112, out_channels=112, kernel_size=3),
                # After the 'last' layer a plain ReLU is used instead of
                # OctReLU; referenced via ``torch.nn`` like every other layer.
                torch.nn.ReLU(),
            )
            # Decision head: 224 -> 384 -> 384 -> 384 -> 1, then sigmoid.
            self.full = torch.nn.Sequential(
                self.con6,
                torch.nn.ReLU(),
                self.con7,
                torch.nn.ReLU(),
                self.con8,
                torch.nn.ReLU(),
                self.con9,
                torch.nn.Sigmoid(),
            )
        # TODO: add the 'kitt' variant (placeholder in the original code).

    def forward(self, x0, x1, flag):
        """Score a pair of inputs.

        Args:
            x0: First input (the left patch, per the original comment).
            x1: Second input (presumably the matching right patch -- confirm).
            flag: Mode selector; only ``'train'`` is handled.

        Returns:
            The output of ``self.full`` on the channel-wise concatenation of
            both feature maps when ``flag == 'train'``; ``None`` otherwise.
        """
        if flag == 'train':
            y0 = self.conv1(x0)  # left_patch
            y1 = self.conv1(x1)
            # Concatenate along dim 1 (channels) so con6 sees 2 * 112 inputs.
            y3 = torch.cat((y0, y1), 1)
            y = self.full(y3)
            return y
        # No other mode is implemented in the code shown.
        return None
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels