wasi-nn: use the MobileNet model instead of AlexNet
The MobileNet model (14 MB) is significantly smaller than the AlexNet model (233 MB); this change should reduce the bandwidth used during CI.
@@ -3,10 +3,10 @@ use std::fs;
 use wasi_nn;
 
 pub fn main() {
-    let xml = fs::read_to_string("fixture/alexnet.xml").unwrap();
+    let xml = fs::read_to_string("fixture/model.xml").unwrap();
     println!("Read graph XML, first 50 characters: {}", &xml[..50]);
 
-    let weights = fs::read("fixture/alexnet.bin").unwrap();
+    let weights = fs::read("fixture/model.bin").unwrap();
     println!("Read graph weights, size in bytes: {}", weights.len());
 
     let graph = unsafe {
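
For context, these two buffers feed straight into the `load` call that opens at the end of this hunk and is unchanged by the diff. A minimal sketch of that call against the wasi-nn 0.1 bindings; the OpenVINO encoding and CPU execution target are assumptions, since the hunk cuts off at the opening `let graph = unsafe {`:

    // Sketch: build the graph from the IR description and weights read above.
    // The encoding/target constants are assumptions not shown in the hunk.
    let graph = unsafe {
        wasi_nn::load(
            &[&xml.into_bytes(), &weights],
            wasi_nn::GRAPH_ENCODING_OPENVINO,
            wasi_nn::EXECUTION_TARGET_CPU,
        )
        .unwrap()
    };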
@@ -24,10 +24,10 @@ pub fn main() {
 
     // Load a tensor that precisely matches the graph input tensor (see
     // `fixture/frozen_inference_graph.xml`).
-    let tensor_data = fs::read("fixture/tensor-1x3x227x227-f32.bgr").unwrap();
+    let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
     println!("Read input tensor, size in bytes: {}", tensor_data.len());
     let tensor = wasi_nn::Tensor {
-        dimensions: &[1, 3, 227, 227],
+        dimensions: &[1, 3, 224, 224],
        r#type: wasi_nn::TENSOR_TYPE_F32,
        data: &tensor_data,
    };
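
The dimension change from 227x227 (AlexNet's input size) to 224x224 (MobileNet's) also changes the expected byte length of the raw tensor fixture. A hypothetical sanity check, not part of the diff, that makes the arithmetic explicit:

    // Sketch: the fixture holds one f32 per element, so its byte length must
    // equal the product of the dimensions times size_of::<f32>().
    let dims = [1usize, 3, 224, 224];
    let expected = dims.iter().product::<usize>() * std::mem::size_of::<f32>();
    assert_eq!(tensor_data.len(), expected); // 1 * 3 * 224 * 224 * 4 = 602,112 bytes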
@@ -42,7 +42,7 @@ pub fn main() {
     println!("Executed graph inference");
 
     // Retrieve the output.
-    let mut output_buffer = vec![0f32; 1000];
+    let mut output_buffer = vec![0f32; 1001];
     unsafe {
        wasi_nn::get_output(
            context,
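
The buffer grows from 1000 to 1001 entries because MobileNet graphs exported from TensorFlow typically emit a probability for each of the 1000 ImageNet classes plus a leading "background" class. For reference, the retrieval call this hunk truncates looks roughly like the sketch below; the remaining arguments are an assumption based on the wasi-nn 0.1 API, which works in raw byte buffers:

    // Sketch: copy the f32 probabilities out of the execution context.
    // Argument order past `context` is assumed, since the hunk cuts off there.
    let mut output_buffer = vec![0f32; 1001];
    unsafe {
        wasi_nn::get_output(
            context,
            0,
            &mut output_buffer[..] as *mut [f32] as *mut u8,
            (output_buffer.len() * std::mem::size_of::<f32>()) as u32,
        )
        .unwrap();
    }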
@@ -60,10 +60,13 @@ pub fn main() {
 
 // Sort the buffer of probabilities. The graph places the match probability for each class at the
 // index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
-// to a wrapping InferenceResult and sort the results.
+// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
+// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
+// (e.g. 763 = "revolver" vs 762 = "restaurant")
 fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
     let mut results: Vec<InferenceResult> = buffer
         .iter()
+        .skip(1)
         .enumerate()
         .map(|(c, p)| InferenceResult(c, *p))
         .collect();
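
Putting the hunk together: `.skip(1)` drops the extra leading class before `.enumerate()`, so result index `c` carries the probability from `buffer[c + 1]`. A self-contained sketch of the function after this change; the `InferenceResult` definition and the descending sort are assumptions reconstructed from context, since the hunk ends at `.collect()`:

    // Sketch: wrapper type and full function after this change.
    #[derive(Debug)]
    struct InferenceResult(usize, f32); // (class index, probability) -- assumed shape

    fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
        let mut results: Vec<InferenceResult> = buffer
            .iter()
            .skip(1) // drop the extra "background" class at index 0
            .enumerate()
            .map(|(c, p)| InferenceResult(c, *p))
            .collect();
        // Highest probability first; assumed, as the sort line is outside the hunk.
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        results
    }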