diff --git a/ci/run-wasi-nn-example.sh b/ci/run-wasi-nn-example.sh index 5a77c06fa5..e24ffa75ac 100755 --- a/ci/run-wasi-nn-example.sh +++ b/ci/run-wasi-nn-example.sh @@ -7,7 +7,7 @@ # executed with the Wasmtime CLI. set -e WASMTIME_DIR=$(dirname "$0" | xargs dirname) -FIXTURE=https://gist.github.com/abrown/c7847bf3701f9efbb2070da1878542c1/raw/07a9f163994b0ff8f0d7c5a5c9645ec3d8b24024 +FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/alexnet if [ -z "${1+x}" ]; then # If no temporary directory is specified, create one. TMP_DIR=$(mktemp -d -t ci-XXXXXXXXXX) @@ -26,9 +26,9 @@ source /opt/intel/openvino/bin/setupvars.sh OPENVINO_INSTALL_DIR=/opt/intel/openvino cargo build -p wasmtime-cli --features wasi-nn # Download all necessary test fixtures to the temporary directory. -wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/frozen_inference_graph.bin -wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/frozen_inference_graph.xml -wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/tensor-1x3x300x300-f32.bgr +wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/alexnet.bin +wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/alexnet.xml +wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/tensor-1x3x227x227-f32.bgr # Now build an example that uses the wasi-nn API. pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example diff --git a/crates/wasi-nn/examples/classification-example/src/main.rs b/crates/wasi-nn/examples/classification-example/src/main.rs index 898a4bfff3..3465de5cae 100644 --- a/crates/wasi-nn/examples/classification-example/src/main.rs +++ b/crates/wasi-nn/examples/classification-example/src/main.rs @@ -3,11 +3,11 @@ use std::fs; use wasi_nn; pub fn main() { - let xml = fs::read_to_string("fixture/frozen_inference_graph.xml").unwrap(); - println!("First 50 characters of graph: {}", &xml[..50]); + let xml = fs::read_to_string("fixture/alexnet.xml").unwrap(); + println!("Read graph XML, first 50 characters: {}", &xml[..50]); - let weights = fs::read("fixture/frozen_inference_graph.bin").unwrap(); - println!("Size of weights: {}", weights.len()); + let weights = fs::read("fixture/alexnet.bin").unwrap(); + println!("Read graph weights, size in bytes: {}", weights.len()); let graph = unsafe { wasi_nn::load( @@ -17,17 +17,17 @@ pub fn main() { ) .unwrap() }; - println!("Graph handle ID: {}", graph); + println!("Loaded graph into wasi-nn with ID: {}", graph); let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() }; - println!("Execution context ID: {}", context); + println!("Created wasi-nn execution context with ID: {}", context); // Load a tensor that precisely matches the graph input tensor (see // `fixture/frozen_inference_graph.xml`). - let tensor_data = fs::read("fixture/tensor-1x3x300x300-f32.bgr").unwrap(); - println!("Tensor bytes: {}", tensor_data.len()); + let tensor_data = fs::read("fixture/tensor-1x3x227x227-f32.bgr").unwrap(); + println!("Read input tensor, size in bytes: {}", tensor_data.len()); let tensor = wasi_nn::Tensor { - dimensions: &[1, 3, 300, 300], + dimensions: &[1, 3, 227, 227], r#type: wasi_nn::TENSOR_TYPE_F32, data: &tensor_data, }; @@ -39,9 +39,10 @@ pub fn main() { unsafe { wasi_nn::compute(context).unwrap(); } + println!("Executed graph inference"); - // Retrieve the output (TODO output looks incorrect). - let mut output_buffer = vec![0f32; 1 << 20]; + // Retrieve the output. + let mut output_buffer = vec![0f32; 1000]; unsafe { wasi_nn::get_output( context, @@ -50,5 +51,25 @@ pub fn main() { (output_buffer.len() * 4).try_into().unwrap(), ); } - println!("output tensor: {:?}", &output_buffer[..1000]) + println!( + "Found results, sorted top 5: {:?}", + &sort_results(&output_buffer)[..5] + ) } + +// Sort the buffer of probabilities. The graph places the match probability for each class at the +// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert +// to a wrapping InferenceResult and sort the results. +fn sort_results(buffer: &[f32]) -> Vec { + let mut results: Vec = buffer + .iter() + .enumerate() + .map(|(c, p)| InferenceResult(c, *p)) + .collect(); + results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + results +} + +// A wrapper for class ID and match probabilities. +#[derive(Debug, PartialEq)] +struct InferenceResult(usize, f32);