-
Notifications
You must be signed in to change notification settings - Fork 100
/
sample.rs
72 lines (59 loc) · 2.44 KB
/
sample.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#![forbid(unsafe_code)]
use onnxruntime::{
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
LoggingLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
/// Catch-all error type for this example: any boxed `std::error::Error`, so `?` can
/// propagate errors from every library used in `run`.
type Error = Box<dyn std::error::Error + 'static>;
/// Entry point: runs the example and, if `run` returns an error, prints it to stderr
/// prefixed with `Error: ` and terminates the process with exit status 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(err) => {
            eprintln!("Error: {}", err);
            std::process::exit(1);
        }
    }
}
/// Runs SqueezeNet 1.0 on a synthetic input and prints the scores of the first five classes.
///
/// The input tensor is filled with evenly spaced values from 0.0 to 1.0 (inclusive), shaped
/// to the model's first input.
///
/// # Errors
///
/// Returns an error if:
/// - the global tracing subscriber cannot be installed;
/// - the ONNX Runtime environment or session cannot be built (for instance, when
///   `squeezenet1.0-8.onnx` cannot be loaded);
/// - the model's first input or output has a dynamic (unknown) dimension;
/// - the generated array cannot be reshaped to the input shape;
/// - inference fails.
///
/// # Panics
///
/// Panics if:
/// - the model has no inputs or outputs;
/// - the input/output shapes are not SqueezeNet 1.0's `[1, 3, 224, 224]` / `[1, 1000, 1, 1]`;
/// - the produced output tensor's shape differs from the declared output shape.
fn run() -> Result<(), Error> {
    // Setup the example's log level.
    // NOTE: ONNX Runtime's log level is controlled separately when building the environment.
    let subscriber = FmtSubscriber::builder()
        .with_max_level(Level::TRACE)
        .finish();
    // Propagate rather than panic so the failure is reported through `main`.
    tracing::subscriber::set_global_default(subscriber)?;

    let environment = Environment::builder()
        .with_name("test")
        // The ONNX Runtime's log level can be different than the one of the wrapper crate or the application.
        .with_log_level(LoggingLevel::Info)
        .build()?;

    let mut session = environment
        .new_session_builder()?
        .with_optimization_level(GraphOptimizationLevel::Basic)?
        .with_number_threads(1)?
        // NOTE: The example uses SqueezeNet 1.0 (ONNX version: 1.3, Opset version: 8),
        // _not_ SqueezeNet 1.1 as downloaded by '.with_model_downloaded(ImageClassification::SqueezeNet)'
        // Obtain it with:
        // curl -LO "https://github.com/onnx/models/raw/master/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
        .with_model_from_file("squeezenet1.0-8.onnx")?;

    // Unknown (dynamic) dimensions are reported as `None`. This example needs fully static
    // shapes, so turn a `None` into an error instead of panicking on `unwrap`.
    let input0_shape: Vec<usize> = session.inputs[0]
        .dimensions()
        .collect::<Option<Vec<usize>>>()
        .ok_or("model input 0 has a dynamic dimension")?;
    let output0_shape: Vec<usize> = session.outputs[0]
        .dimensions()
        .collect::<Option<Vec<usize>>>()
        .ok_or("model output 0 has a dynamic dimension")?;

    assert_eq!(input0_shape, [1, 3, 224, 224]);
    assert_eq!(output0_shape, [1, 1000, 1, 1]);

    // Initialize input data with values in [0.0, 1.0]; the element count is derived from
    // the already-validated shape instead of re-reading and casting the raw dimensions.
    let n: usize = input0_shape.iter().product();
    let array = Array::linspace(0.0_f32, 1.0, n).into_shape(input0_shape)?;
    let input_tensor_values = vec![array];

    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
    assert_eq!(outputs[0].shape(), output0_shape.as_slice());

    for i in 0..5 {
        println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
    }

    Ok(())
}