-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathCNNs.py
More file actions
27 lines (19 loc) · 904 Bytes
/
CNNs.py
File metadata and controls
27 lines (19 loc) · 904 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, ResNet34_Weights
from torchvision.models import resnet18, ResNet18_Weights
from typing import Union, Any, Tuple
def load_model(model_name: str,
               dataset_name: str,
               device: Union[torch.device, str] = torch.device('cuda'),
               ) -> nn.Module:
    """Load a pretrained classifier for *dataset_name* and move it to *device*.

    Args:
        model_name: Architecture name. For CIFAR datasets it is joined to the
            dataset name as ``"<dataset_name>_<model_name>"`` to form the
            ``torch.hub`` entry point. For ImageNet only ``"resnet18"`` and
            ``"resnet34"`` are supported.
        dataset_name: A name containing ``"cifar"`` selects the
            ``chenyaofo/pytorch-cifar-models`` hub repository. A name
            containing ``"imagenet"`` selects torchvision's ResNets with their
            ``DEFAULT`` pretrained weights. The checks are substring matches,
            and ``"cifar"`` is tested first.
        device: Device the returned model is moved to.

    Returns:
        The pretrained model, placed on *device*.

    Raises:
        ValueError: If ``dataset_name`` matches neither family, or if an
            ImageNet ``model_name`` is not one of the supported ResNets.
    """
    if 'cifar' in dataset_name:
        # The hub repository exposes entry points named "<dataset>_<model>".
        full_name = dataset_name + "_" + model_name
        # pretrained=True: weights may be downloaded on first use (hub cache).
        model = torch.hub.load("chenyaofo/pytorch-cifar-models", full_name, pretrained=True)
    elif 'imagenet' in dataset_name:
        if model_name == "resnet34":
            model = resnet34(weights=ResNet34_Weights.DEFAULT)
        elif model_name == "resnet18":
            model = resnet18(weights=ResNet18_Weights.DEFAULT)
        else:
            # Without this branch `model` would be unbound at the return below.
            raise ValueError(
                f"Unsupported ImageNet model {model_name!r}; "
                "expected 'resnet18' or 'resnet34'."
            )
    else:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected a name "
            "containing 'cifar' or 'imagenet'."
        )
    # nn.Module.to moves parameters in place and returns the module itself.
    return model.to(device)