-
Notifications
You must be signed in to change notification settings - Fork 2
/
model_tflite.go
76 lines (57 loc) · 1.49 KB
/
model_tflite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package precise
import (
	"errors"
	"fmt"
	"sync"

	"github.com/mattn/go-tflite"
	"gorgonia.org/tensor"
)
// NewTFLiteModel loads the TensorFlow Lite model stored at modelPath and
// returns it wrapped in a TFLiteModel whose interpreter already has its
// tensors allocated.
//
// It returns an error if the model file cannot be loaded, if the interpreter
// options or the interpreter cannot be created, or if tensor allocation
// fails. On every error path the native resources acquired so far are
// released before returning, so a failed call leaks nothing.
func NewTFLiteModel(modelPath string) (Model, error) {
	model := tflite.NewModelFromFile(modelPath)
	if model == nil {
		return nil, errors.New("cannot load model")
	}

	options := tflite.NewInterpreterOptions()
	if options == nil {
		// Defensive: a nil options value is stored on the model and later
		// passed to Delete by Close, so refuse to continue without one.
		model.Delete()
		return nil, errors.New("cannot create interpreter options")
	}

	interpreter := tflite.NewInterpreter(model, options)
	if interpreter == nil {
		options.Delete()
		model.Delete()
		return nil, errors.New("cannot create interpreter")
	}

	// Predict reads and writes tensor buffers, which only exist after a
	// successful allocation; surface the failure here instead of later.
	if status := interpreter.AllocateTensors(); status != tflite.OK {
		interpreter.Delete()
		options.Delete()
		model.Delete()
		return nil, errors.New("cannot allocate tensors")
	}

	return &TFLiteModel{
		model:       model,
		interpreter: interpreter,
		options:     options,
		lock:        new(sync.Mutex),
	}, nil
}
// TFLiteModel is a Model backed by a TensorFlow Lite model and interpreter.
//
// Values should be created with NewTFLiteModel. lock is a pointer, so a zero
// TFLiteModel has a nil lock and calling Close on it would panic.
type TFLiteModel struct {
	// model is the loaded TensorFlow Lite model. Close sets it to nil, and
	// the methods treat a nil model as "already closed".
	model *tflite.Model

	// options are the interpreter options the interpreter was created with.
	options *tflite.InterpreterOptions

	// interpreter runs inference; Predict writes its first input tensor and
	// reads its first output tensor.
	interpreter *tflite.Interpreter

	// lock is held by Close for the duration of the teardown.
	lock *sync.Mutex
}
// Predict copies inputData into the model's first input tensor, runs the
// interpreter, and returns the first value of the first output tensor.
//
// inputData must carry []float32 data whose length equals the number of
// elements in the input tensor. Predict holds m.lock for the whole call.
// The interpreter's tensors are shared state written and read on every
// call, so concurrent callers must be serialized. Holding the lock also
// keeps Close from freeing the interpreter mid-inference.
//
// It returns -1 together with:
//   - ErrModelClosed if the model has been closed;
//   - ErrUnexpectedType if inputData is not []float32 or if the input or
//     output tensor is not float32;
//   - a descriptive error if a tensor is missing, the input length does not
//     match the tensor size, invocation fails, or the output is empty.
func (m *TFLiteModel) Predict(inputData tensor.Tensor) (float32, error) {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.model == nil {
		return -1, ErrModelClosed
	}

	values, ok := inputData.Data().([]float32)
	if !ok {
		return -1, ErrUnexpectedType
	}

	input := m.interpreter.GetInputTensor(0)
	if input == nil {
		return -1, errors.New("model has no input tensor")
	}
	if input.Type() != tflite.Float32 {
		return -1, ErrUnexpectedType
	}
	dst := input.Float32s()
	// A partial copy would silently mix in values left over from the
	// previous call (or drop trailing inputs), so require an exact fit.
	if len(values) != len(dst) {
		return -1, fmt.Errorf("input has %d values, model expects %d", len(values), len(dst))
	}
	copy(dst, values)

	if status := m.interpreter.Invoke(); status != tflite.OK {
		return -1, errors.New("model invocation failed")
	}

	output := m.interpreter.GetOutputTensor(0)
	if output == nil {
		return -1, errors.New("model has no output tensor")
	}
	if output.Type() != tflite.Float32 {
		return -1, ErrUnexpectedType
	}
	result := output.Float32s()
	if len(result) == 0 {
		return -1, errors.New("model output tensor is empty")
	}
	return result[0], nil
}
// Close releases the native TensorFlow Lite resources held by the model.
//
// It holds m.lock while tearing down, so it is serialized with any other
// method that takes the same lock. Resources are freed in the reverse order
// of creation (interpreter, options, model). Every handle is then cleared so
// no pointer to freed native memory remains reachable from m. Calling Close
// more than once is safe: later calls see a nil model and return nil.
func (m *TFLiteModel) Close() error {
	m.lock.Lock()
	defer m.lock.Unlock()

	if m.model == nil {
		return nil
	}

	m.interpreter.Delete()
	m.options.Delete()
	m.model.Delete()

	m.interpreter = nil
	m.options = nil
	m.model = nil
	return nil
}