Use AlexNet for wasi-nn example (#2474)
This commit is contained in:
@@ -7,7 +7,7 @@
|
|||||||
# executed with the Wasmtime CLI.
|
# executed with the Wasmtime CLI.
|
||||||
set -e
|
set -e
|
||||||
WASMTIME_DIR=$(dirname "$0" | xargs dirname)
|
WASMTIME_DIR=$(dirname "$0" | xargs dirname)
|
||||||
FIXTURE=https://gist.github.com/abrown/c7847bf3701f9efbb2070da1878542c1/raw/07a9f163994b0ff8f0d7c5a5c9645ec3d8b24024
|
FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/alexnet
|
||||||
if [ -z "${1+x}" ]; then
|
if [ -z "${1+x}" ]; then
|
||||||
# If no temporary directory is specified, create one.
|
# If no temporary directory is specified, create one.
|
||||||
TMP_DIR=$(mktemp -d -t ci-XXXXXXXXXX)
|
TMP_DIR=$(mktemp -d -t ci-XXXXXXXXXX)
|
||||||
@@ -26,9 +26,9 @@ source /opt/intel/openvino/bin/setupvars.sh
|
|||||||
OPENVINO_INSTALL_DIR=/opt/intel/openvino cargo build -p wasmtime-cli --features wasi-nn
|
OPENVINO_INSTALL_DIR=/opt/intel/openvino cargo build -p wasmtime-cli --features wasi-nn
|
||||||
|
|
||||||
# Download all necessary test fixtures to the temporary directory.
|
# Download all necessary test fixtures to the temporary directory.
|
||||||
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/frozen_inference_graph.bin
|
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/alexnet.bin
|
||||||
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/frozen_inference_graph.xml
|
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/alexnet.xml
|
||||||
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/tensor-1x3x300x300-f32.bgr
|
wget --no-clobber --directory-prefix=$TMP_DIR $FIXTURE/tensor-1x3x227x227-f32.bgr
|
||||||
|
|
||||||
# Now build an example that uses the wasi-nn API.
|
# Now build an example that uses the wasi-nn API.
|
||||||
pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example
|
pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ use std::fs;
|
|||||||
use wasi_nn;
|
use wasi_nn;
|
||||||
|
|
||||||
pub fn main() {
|
pub fn main() {
|
||||||
let xml = fs::read_to_string("fixture/frozen_inference_graph.xml").unwrap();
|
let xml = fs::read_to_string("fixture/alexnet.xml").unwrap();
|
||||||
println!("First 50 characters of graph: {}", &xml[..50]);
|
println!("Read graph XML, first 50 characters: {}", &xml[..50]);
|
||||||
|
|
||||||
let weights = fs::read("fixture/frozen_inference_graph.bin").unwrap();
|
let weights = fs::read("fixture/alexnet.bin").unwrap();
|
||||||
println!("Size of weights: {}", weights.len());
|
println!("Read graph weights, size in bytes: {}", weights.len());
|
||||||
|
|
||||||
let graph = unsafe {
|
let graph = unsafe {
|
||||||
wasi_nn::load(
|
wasi_nn::load(
|
||||||
@@ -17,17 +17,17 @@ pub fn main() {
|
|||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
println!("Graph handle ID: {}", graph);
|
println!("Loaded graph into wasi-nn with ID: {}", graph);
|
||||||
|
|
||||||
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
|
let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
|
||||||
println!("Execution context ID: {}", context);
|
println!("Created wasi-nn execution context with ID: {}", context);
|
||||||
|
|
||||||
// Load a tensor that precisely matches the graph input tensor (see
|
// Load a tensor that precisely matches the graph input tensor (see
|
||||||
// `fixture/frozen_inference_graph.xml`).
|
// `fixture/frozen_inference_graph.xml`).
|
||||||
let tensor_data = fs::read("fixture/tensor-1x3x300x300-f32.bgr").unwrap();
|
let tensor_data = fs::read("fixture/tensor-1x3x227x227-f32.bgr").unwrap();
|
||||||
println!("Tensor bytes: {}", tensor_data.len());
|
println!("Read input tensor, size in bytes: {}", tensor_data.len());
|
||||||
let tensor = wasi_nn::Tensor {
|
let tensor = wasi_nn::Tensor {
|
||||||
dimensions: &[1, 3, 300, 300],
|
dimensions: &[1, 3, 227, 227],
|
||||||
r#type: wasi_nn::TENSOR_TYPE_F32,
|
r#type: wasi_nn::TENSOR_TYPE_F32,
|
||||||
data: &tensor_data,
|
data: &tensor_data,
|
||||||
};
|
};
|
||||||
@@ -39,9 +39,10 @@ pub fn main() {
|
|||||||
unsafe {
|
unsafe {
|
||||||
wasi_nn::compute(context).unwrap();
|
wasi_nn::compute(context).unwrap();
|
||||||
}
|
}
|
||||||
|
println!("Executed graph inference");
|
||||||
|
|
||||||
// Retrieve the output (TODO output looks incorrect).
|
// Retrieve the output.
|
||||||
let mut output_buffer = vec![0f32; 1 << 20];
|
let mut output_buffer = vec![0f32; 1000];
|
||||||
unsafe {
|
unsafe {
|
||||||
wasi_nn::get_output(
|
wasi_nn::get_output(
|
||||||
context,
|
context,
|
||||||
@@ -50,5 +51,25 @@ pub fn main() {
|
|||||||
(output_buffer.len() * 4).try_into().unwrap(),
|
(output_buffer.len() * 4).try_into().unwrap(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
println!("output tensor: {:?}", &output_buffer[..1000])
|
println!(
|
||||||
|
"Found results, sorted top 5: {:?}",
|
||||||
|
&sort_results(&output_buffer)[..5]
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort the buffer of probabilities. The graph places the match probability for each class at the
|
||||||
|
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
|
||||||
|
// to a wrapping InferenceResult and sort the results.
|
||||||
|
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
|
||||||
|
let mut results: Vec<InferenceResult> = buffer
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(c, p)| InferenceResult(c, *p))
|
||||||
|
.collect();
|
||||||
|
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
// A wrapper for class ID and match probabilities.
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
struct InferenceResult(usize, f32);
|
||||||
|
|||||||
Reference in New Issue
Block a user