The MobileNet model is significantly smaller in size (14MB) than the AlexNet model (233MB); this change should reduce bandwidth used during CI.
80 lines
2.6 KiB
Rust
80 lines
2.6 KiB
Rust
use std::convert::TryInto;
|
|
use std::fs;
|
|
use wasi_nn;
|
|
|
|
pub fn main() {
|
|
let xml = fs::read_to_string("fixture/model.xml").unwrap();
|
|
println!("Read graph XML, first 50 characters: {}", &xml[..50]);
|
|
|
|
let weights = fs::read("fixture/model.bin").unwrap();
|
|
println!("Read graph weights, size in bytes: {}", weights.len());
|
|
|
|
let graph = unsafe {
|
|
wasi_nn::load(
|
|
&[&xml.into_bytes(), &weights],
|
|
wasi_nn::GRAPH_ENCODING_OPENVINO,
|
|
wasi_nn::EXECUTION_TARGET_CPU,
|
|
)
|
|
.unwrap()
|
|
};
|
|
println!("Loaded graph into wasi-nn with ID: {}", graph);
|
|
|
|
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
|
|
println!("Created wasi-nn execution context with ID: {}", context);
|
|
|
|
// Load a tensor that precisely matches the graph input tensor (see
|
|
// `fixture/frozen_inference_graph.xml`).
|
|
let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
|
|
println!("Read input tensor, size in bytes: {}", tensor_data.len());
|
|
let tensor = wasi_nn::Tensor {
|
|
dimensions: &[1, 3, 224, 224],
|
|
r#type: wasi_nn::TENSOR_TYPE_F32,
|
|
data: &tensor_data,
|
|
};
|
|
unsafe {
|
|
wasi_nn::set_input(context, 0, tensor).unwrap();
|
|
}
|
|
|
|
// Execute the inference.
|
|
unsafe {
|
|
wasi_nn::compute(context).unwrap();
|
|
}
|
|
println!("Executed graph inference");
|
|
|
|
// Retrieve the output.
|
|
let mut output_buffer = vec![0f32; 1001];
|
|
unsafe {
|
|
wasi_nn::get_output(
|
|
context,
|
|
0,
|
|
&mut output_buffer[..] as *mut [f32] as *mut u8,
|
|
(output_buffer.len() * 4).try_into().unwrap(),
|
|
)
|
|
.unwrap();
|
|
}
|
|
println!(
|
|
"Found results, sorted top 5: {:?}",
|
|
&sort_results(&output_buffer)[..5]
|
|
)
|
|
}
|
|
|
|
/// Sort the buffer of probabilities, highest first.
///
/// The graph places the match probability for each class at the index for that class (e.g. the
/// probability of class 42 is placed at `buffer[42]`). Each probability is wrapped in an
/// [`InferenceResult`] together with its class ID and the results are sorted in descending order
/// of probability.
///
/// It is unclear why the MobileNet output indices are "off by one", but the `.skip(1)` below
/// seems necessary to get results that make sense (e.g. 763 = "revolver" vs 762 = "restaurant").
/// As a consequence `buffer[0]` is dropped and `buffer[i + 1]` is reported as class `i`.
/// NOTE(review): presumably entry 0 is a background/placeholder class -- confirm against the
/// model's label file.
///
/// NaN values are ordered after every real probability instead of aborting the sort; the
/// relative order of all non-NaN entries is the plain descending numeric order, and ties keep
/// their original (ascending class ID) order because the sort is stable.
///
/// Returns an empty vector when `buffer` has fewer than two elements.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(class, probability)| InferenceResult(class, *probability))
        .collect();
    // `partial_cmp` returns `None` when either side is NaN, so NaN is handled explicitly to keep
    // the comparator a total order (which `sort_by` requires) and to avoid panicking.
    results.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        // Neither side is NaN here, so `partial_cmp` always yields `Some`; the fallback is
        // never taken. The operands are swapped to sort in descending order.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
    });
    results
}

/// A wrapper for class ID (field 0) and match probability (field 1).
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
|