-
Notifications
You must be signed in to change notification settings - Fork 0
/
toONNX.py
29 lines (23 loc) · 806 Bytes
/
toONNX.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torchvision.models as models
import torch.nn as nn
# Checkpoint location WITHOUT a file extension: the script appends ".pth" to
# read the trained weights and ".onnx" to name the exported model.
PATH_TO_MODEL: str = "models/Res50_lr=0.0005_batchSize=10"
def load(path, num_classes=4):
    """Build a ResNet-50 classifier and fill it with weights from a checkpoint.

    Args:
        path: Path to a file holding a ``state_dict`` for a torchvision
            ResNet-50 whose final ``fc`` layer has ``num_classes`` outputs.
        num_classes: Output size of the replacement ``fc`` head. Defaults to 4,
            the value this script was written for; it must match the checkpoint.

    Returns:
        The ResNet-50 ``nn.Module`` with the checkpoint weights loaded. It is
        left in training mode; callers should call ``.eval()`` before inference
        or export.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint's keys or tensor shapes do not match the model.
    """
    print("Loading from ", path)
    # Build without pretrained weights. Every parameter and buffer is
    # overwritten by the strict load_state_dict below, so fetching ImageNet
    # weights first (weights='DEFAULT') was a wasted download and failed offline.
    net = models.resnet50(weights=None)
    # Swap the 1000-way ImageNet head for a num_classes-way head so the
    # shapes line up with the fine-tuned checkpoint.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    # Map tensors to the CPU so a checkpoint saved from a GPU still loads on a
    # CPU-only machine. The freshly built model's parameters live on the CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. On torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(path, map_location="cpu")
    net.load_state_dict(state_dict)
    print("Loaded")
    return net
# Script section: load the fine-tuned checkpoint and export it to ONNX,
# writing the .onnx file next to the .pth checkpoint.
# NOTE(review): `device` is not used below; the export runs on the CPU. The
# module-level name is kept in case anything imports it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = load(PATH_TO_MODEL + ".pth")
# Inference mode: BatchNorm uses its running statistics, so the exported graph
# is the inference graph rather than the training one.
model.eval()
# Example input used to trace the model. No dynamic_axes are given, so the
# exported input shape is fixed to this one (batch size 1).
# NOTE(review): 128x128 presumably matches the training resolution; confirm.
dummy_input = torch.randn(1, 3, 128, 128)
output_names = ["output"]
torch.onnx.export(
    model,
    dummy_input,
    PATH_TO_MODEL + ".onnx",
    verbose=False,
    # Name the graph output. Previously this list was built but never passed,
    # so the exporter assigned an auto-generated output name.
    output_names=output_names,
    # NOTE(review): exposes the weights as graph inputs as well as
    # initializers; confirm the downstream consumer actually needs this.
    keep_initializers_as_inputs=True,
    export_params=True,
)