Skip to content

youngx123/vit-transformer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 

Repository files navigation

vit transformer

使用 mnist 数据对模型进行训练、测试

训练 100epoch, 测试集中使用分类精度为 0.993

import torch
from model import VIT

net = VIT(image_size=224,channels=3, patch_size=16, num_classes=10, 
          dim=256, transLayer=6, multiheads=8)

img = torch.randn(2, 3, 224, 224)

out = net(img) # (1, num_classes)
print(out.shape)
torch.Size([2, 10])

# 保存为 onnx 模型
net.eval()
torch.onnx.export(net, img, "vit.onnx", verbose=1, training=torch.onnx.TrainingMode.EVAL,
                    input_names=["inputNode"], output_names=["outNode1"], opset_version=11, )

VIT2 为测试将模型转为 onnx 时, einops 中一些维度变换的函数能否被转换

参考:

https://github.com/lucidrains/vit-pytorch

About

pytorch implement vit transformer

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages