Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(validate): update DL fromat model validate #151

Merged
merged 5 commits into from
Nov 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 76 additions & 69 deletions pkg/model/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"os"
"path"
"strings"
)

// Format is the definition of model format.
Expand Down Expand Up @@ -77,48 +78,55 @@ func (f Format) ValidateDirectory(rootPath string) error {
return nil
}

func ValidateError(modelPath string, modelName string, modelNum int32) error {
if modelNum != 1 {
return fmt.Errorf("Expected one %v file in %v directory, but found %v .", modelName, modelPath, modelNum)
}
return nil
}

func (f Format) validateForSavedModel(modelPath string, files []os.FileInfo) error {
var pbFileFlag bool
var variablesDirFlag bool
var pbFileNum int32
var variablesDirNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pb" {
pbFileFlag = true
if file.Name() == "saved_model.pb" {
pbFileNum++
}
if file.IsDir() && file.Name() == "variables" {
variablesDirFlag = true
variablesDirNum++
}
}
if !pbFileFlag {
return fmt.Errorf("there are no *.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "saved_model.pb", pbFileNum); e != nil {
return e
}
if !variablesDirFlag {
return fmt.Errorf("there are no variables dir in %v directory", modelPath)
if e := ValidateError(modelPath, "variables", variablesDirNum); e != nil {
return e
}
return nil
}

func (f Format) validateForONNX(modelPath string, files []os.FileInfo) error {
var onnxFileFlag bool
var onnxFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".onnx" {
onnxFileFlag = true
onnxFileNum++
}
}
if !onnxFileFlag {
return fmt.Errorf("there are no *.onnx file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.onnx", onnxFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForH5(modelPath string, files []os.FileInfo) error {
var h5FileFlag bool
var h5FileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".h5" {
h5FileFlag = true
h5FileNum++
}
}
if !h5FileFlag {
return fmt.Errorf("there are no *.h5 file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.h5", h5FileNum); e != nil {
return e
}
return nil
}
Expand All @@ -135,141 +143,140 @@ func (f Format) validateForPMML(modelPath string, files []os.FileInfo) error {
}

func (f Format) validateForCaffeModel(modelPath string, files []os.FileInfo) error {
var caffeModelFileFlag bool
var prototxtFileFlag bool
var caffeModelFileNum int32
var prototxtFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".caffemodel" {
caffeModelFileFlag = true
caffeModelFileNum++
}
if path.Ext(file.Name()) == ".prototxt" {
prototxtFileFlag = true
prototxtFileNum++
}
}
if !caffeModelFileFlag {
return fmt.Errorf("there are no *.caffemodel file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.caffemodel", caffeModelFileNum); e != nil {
return e
}
if !prototxtFileFlag {
return fmt.Errorf("there are no *.prototxt file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.prototxt", prototxtFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForNetDef(modelPath string, files []os.FileInfo) error {
var initFileFlag bool
var predictFileFlag bool
var initFileNum int32
var predictFileNum int32
for _, file := range files {
if file.Name() == "init_net.pb" {
initFileFlag = true
initFileNum++
}
if file.Name() == "predict_net.pb" {
predictFileFlag = true
predictFileNum++
}
}
if !initFileFlag {
return fmt.Errorf("there are no init_net.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "init_net.pb", initFileNum); e != nil {
return e
}
if !predictFileFlag {
return fmt.Errorf("there are no predict_net.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "predict_net.pb", predictFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForMXNETParams(modelPath string, files []os.FileInfo) error {
var jsonFileFlag bool
var paramsFileFlag bool
var jsonFileNum int32
var paramsFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".json" {
jsonFileFlag = true
if strings.HasSuffix(file.Name(), "symbol.json") {
jsonFileNum++
}
if path.Ext(file.Name()) == ".params" {
paramsFileFlag = true
paramsFileNum++
}
}
if !jsonFileFlag {
return fmt.Errorf("there are no *.json file in %v directory", modelPath)
if e := ValidateError(modelPath, "*symbol.json", jsonFileNum); e != nil {
return e
}
if !paramsFileFlag {
return fmt.Errorf("there are no *.params file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.params", paramsFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForTorchScript(modelPath string, files []os.FileInfo) error {
var ptFileFlag bool
var ptFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pt" {
ptFileFlag = true
ptFileNum++
}
}
if !ptFileFlag {
return fmt.Errorf("there are no *.pt file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.pt", ptFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForGraphDef(modelPath string, files []os.FileInfo) error {
var pbFileFlag bool
var graphdefFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pb" {
pbFileFlag = true
break
if path.Ext(file.Name()) == ".graphdef" {
graphdefFileNum++
}
}
if !pbFileFlag {
return fmt.Errorf("there are no *.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.graphdef", graphdefFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForTensorRT(modelPath string, files []os.FileInfo) error {
var tensorrtFileFlag bool
var tensorrtFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".plan" {
tensorrtFileFlag = true
if path.Ext(file.Name()) == ".plan" || path.Ext(file.Name()) == ".engine" {
tensorrtFileNum++
}
}
if !tensorrtFileFlag {
return fmt.Errorf("there are no *.plan file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.plan or *.engine", tensorrtFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForSKLearn(modelPath string, files []os.FileInfo) error {
var sklearnFileFlag bool
var sklearnFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".joblib" {
sklearnFileFlag = true
sklearnFileNum++
}
}
if !sklearnFileFlag {
return fmt.Errorf("there are no *.joblib file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.joblib", sklearnFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForXGBoost(modelPath string, files []os.FileInfo) error {
var xgboostFileFlag bool
var xgboostFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".xgboost" {
xgboostFileFlag = true
xgboostFileNum++
}
}
if !xgboostFileFlag {
return fmt.Errorf("there are no *.xgboost file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.xgboost", xgboostFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForMLflow(modelPath string, files []os.FileInfo) error {
var isMLflowFile bool
var MLflowFileNum int32
for _, file := range files {
if file.Name() == "MLmodel" {
// assuming that user would not fool the tool
isMLflowFile = true
MLflowFileNum++
}
}
if !isMLflowFile {
return fmt.Errorf("there are no MLmodel file in %v, directory", modelPath)
if e := ValidateError(modelPath, "MLmodel", MLflowFileNum); e != nil {
return e
}
return nil
}