Add an initial wasi-nn implementation for Wasmtime (#2208)
* Add an initial wasi-nn implementation for Wasmtime. This change adds a crate, `wasmtime-wasi-nn`, that uses `wiggle` to expose the current state of the wasi-nn API and `openvino` to implement the exposed functions. It includes an end-to-end test demonstrating how to do classification using wasi-nn: (1) `crates/wasi-nn/tests/classification-example` contains Rust code that is compiled to the `wasm32-wasi` target and run with a Wasmtime embedding that exposes the wasi-nn calls; (2) the example uses Rust bindings for wasi-nn contained in `crates/wasi-nn/tests/wasi-nn-rust-bindings`; this crate contains code generated by `witx-bindgen` and eventually should be its own standalone crate. * Test wasi-nn as a CI step. This change adds: (1) a GitHub action for installing OpenVINO; (2) a script, `ci/run-wasi-nn-example.sh`, to run the classification example.
This commit is contained in:
76
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/error.rs
Normal file
76
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/error.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use super::NnErrno;
|
||||
use core::fmt;
|
||||
use core::num::NonZeroU16;
|
||||
|
||||
/// A raw error returned by wasi-nn APIs, internally containing a 16-bit error
/// code.
///
/// The code is stored as a `NonZeroU16`, so a value of this type can never
/// hold the "success" code `0`. `Hash` is derived alongside the comparison
/// traits so that errors can be used as keys in hashed collections.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct Error {
    /// The non-zero raw error code.
    code: NonZeroU16,
}
|
||||
|
||||
impl Error {
|
||||
/// Constructs a new error from a raw error code, returning `None` if the
|
||||
/// error code is zero (which means success).
|
||||
pub fn from_raw_error(error: NnErrno) -> Option<Error> {
|
||||
Some(Error {
|
||||
code: NonZeroU16::new(error)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the raw error code that this error represents.
|
||||
pub fn raw_error(&self) -> u16 {
|
||||
self.code.get()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{} (error {})", strerror(self.code.get()), self.code)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Error")
|
||||
.field("code", &self.code)
|
||||
.field("message", &strerror(self.code.get()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// This should be generated automatically by witx-bindgen but is not yet for enums other than
|
||||
/// `Errno` (this API uses `NnErrno` to avoid naming conflicts). TODO: https://github.com/bytecodealliance/wasi/issues/52.
|
||||
fn strerror(code: u16) -> &'static str {
|
||||
match code {
|
||||
super::NN_ERRNO_SUCCESS => "No error occurred.",
|
||||
super::NN_ERRNO_INVALID_ARGUMENT => "Caller module passed an invalid argument.",
|
||||
super::NN_ERRNO_MISSING_MEMORY => "Caller module is missing a memory export.",
|
||||
super::NN_ERRNO_BUSY => "Device or resource busy.",
|
||||
_ => "Unknown error.",
|
||||
}
|
||||
}
|
||||
|
||||
// `std` is only linked in when the crate's `std` feature is enabled.
#[cfg(feature = "std")]
|
||||
extern crate std;
|
||||
// With the `std` feature, `Error` also implements `std::error::Error`,
// relying entirely on the trait's default methods.
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The success code (zero) must not produce an `Error`.
    #[test]
    fn error_from_success_code() {
        let converted = Error::from_raw_error(0);
        assert_eq!(None, converted);
    }

    /// Code 1 renders as the invalid-argument message followed by the code.
    #[test]
    fn error_from_invalid_argument_code() {
        let rendered = Error::from_raw_error(1).unwrap().to_string();
        assert_eq!(
            "Caller module passed an invalid argument. (error 1)",
            rendered
        );
    }
}
|
||||
199
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/generated.rs
Normal file
199
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/generated.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
// This file is automatically generated, DO NOT EDIT
|
||||
//
|
||||
// To regenerate this file run the `crates/witx-bindgen` command
|
||||
|
||||
use core::mem::MaybeUninit;
|
||||
|
||||
pub use crate::error::Error;
|
||||
/// Result alias whose error type defaults to this crate's `Error`.
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||
/// A buffer size expressed as a 32-bit unsigned integer.
pub type BufferSize = u32;
|
||||
/// Raw 16-bit wasi-nn status code; see the `NN_ERRNO_*` constants.
pub type NnErrno = u16;
|
||||
/// No error occurred.
|
||||
pub const NN_ERRNO_SUCCESS: NnErrno = 0;
|
||||
/// Caller module passed an invalid argument.
|
||||
pub const NN_ERRNO_INVALID_ARGUMENT: NnErrno = 1;
|
||||
/// Caller module is missing a memory export.
|
||||
pub const NN_ERRNO_MISSING_MEMORY: NnErrno = 2;
|
||||
/// Device or resource busy.
|
||||
pub const NN_ERRNO_BUSY: NnErrno = 3;
|
||||
/// Borrowed slice of `u32` dimension sizes (see `Tensor::dimensions`).
pub type TensorDimensions<'a> = &'a [u32];
|
||||
/// One-byte tag stored in `Tensor::type`; see the `TENSOR_TYPE_*` constants.
pub type TensorType = u8;
|
||||
/// Presumably 16-bit floating-point elements (inferred from the name; confirm against the wasi-nn spec).
pub const TENSOR_TYPE_F16: TensorType = 0;
|
||||
/// Presumably 32-bit floating-point elements (inferred from the name; confirm against the wasi-nn spec).
pub const TENSOR_TYPE_F32: TensorType = 1;
|
||||
/// Presumably unsigned 8-bit integer elements (inferred from the name; confirm against the wasi-nn spec).
pub const TENSOR_TYPE_U8: TensorType = 2;
|
||||
/// Presumably signed 32-bit integer elements (inferred from the name; confirm against the wasi-nn spec).
pub const TENSOR_TYPE_I32: TensorType = 3;
|
||||
/// Borrowed bytes holding tensor data (see `Tensor::data`).
pub type TensorData<'a> = &'a [u8];
|
||||
/// A tensor as passed to `set_input`: its dimensions, a type tag and its raw
/// data bytes. The struct is `#[repr(C)]` because a pointer to it is handed
/// to the imported `set_input` function, so the field order must not change.
#[repr(C)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Tensor<'a> {
|
||||
/// Describe the size of the tensor (e.g. 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor containing a single value,
|
||||
/// use `[1]` for the tensor dimensions.
|
||||
pub dimensions: TensorDimensions<'a>,
|
||||
/// Type tag for the tensor; one of the `TENSOR_TYPE_*` constants
/// (presumably the element type of `data`; confirm against the wasi-nn spec).
pub r#type: TensorType,
|
||||
/// Contains the tensor data.
|
||||
pub data: TensorData<'a>,
|
||||
}
|
||||
/// A borrowed byte buffer used to build a graph (see `load`).
pub type GraphBuilder<'a> = &'a [u8];
|
||||
/// A borrowed slice of `GraphBuilder` buffers; `load` passes it on as a pointer/length pair.
pub type GraphBuilderArray<'a> = &'a [GraphBuilder<'a>];
|
||||
/// 32-bit handle to a loaded graph, as returned by `load`.
pub type Graph = u32;
|
||||
/// One-byte tag naming the encoding of the graph bytes given to `load`.
pub type GraphEncoding = u8;
|
||||
/// TODO document buffer order
|
||||
pub const GRAPH_ENCODING_OPENVINO: GraphEncoding = 0;
|
||||
/// One-byte tag naming where `load` should execute the graph.
pub type ExecutionTarget = u8;
|
||||
/// Presumably selects CPU execution (inferred from the name).
pub const EXECUTION_TARGET_CPU: ExecutionTarget = 0;
|
||||
/// Presumably selects GPU execution (inferred from the name).
pub const EXECUTION_TARGET_GPU: ExecutionTarget = 1;
|
||||
/// Presumably selects TPU execution (inferred from the name).
pub const EXECUTION_TARGET_TPU: ExecutionTarget = 2;
|
||||
/// 32-bit handle to an execution context, as returned by `init_execution_context`.
pub type GraphExecutionContext = u32;
|
||||
/// Load an opaque sequence of bytes to use for inference.
|
||||
///
|
||||
/// This allows runtime implementations to support multiple graph encoding formats. For unsupported graph encodings,
|
||||
/// return `errno::inval`.
|
||||
///
|
||||
/// ## Parameters
|
||||
///
|
||||
/// * `builder` - The bytes necessary to build the graph.
|
||||
/// * `encoding` - The encoding of the graph.
|
||||
/// * `target` - Where to execute the graph.
|
||||
pub unsafe fn load(
|
||||
builder: GraphBuilderArray,
|
||||
encoding: GraphEncoding,
|
||||
target: ExecutionTarget,
|
||||
) -> Result<Graph> {
|
||||
let mut graph = MaybeUninit::uninit();
|
||||
let rc = wasi_ephemeral_nn::load(
|
||||
builder.as_ptr(),
|
||||
builder.len(),
|
||||
encoding,
|
||||
target,
|
||||
graph.as_mut_ptr(),
|
||||
);
|
||||
if let Some(err) = Error::from_raw_error(rc) {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(graph.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO Functions like `describe_graph_inputs` and `describe_graph_outputs` (returning
|
||||
/// an array of `$tensor_description`s) might be useful for introspecting the graph but are not yet included here.
|
||||
/// Create an execution instance of a loaded graph.
|
||||
/// TODO this may need to accept flags that might affect the compilation or execution of the graph.
|
||||
pub unsafe fn init_execution_context(graph: Graph) -> Result<GraphExecutionContext> {
|
||||
let mut context = MaybeUninit::uninit();
|
||||
let rc = wasi_ephemeral_nn::init_execution_context(graph, context.as_mut_ptr());
|
||||
if let Some(err) = Error::from_raw_error(rc) {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(context.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Define the inputs to use for inference.
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the input tensor does not match the expected dimensions and type.
|
||||
///
|
||||
/// ## Parameters
|
||||
///
|
||||
/// * `index` - The index of the input to change.
|
||||
/// * `tensor` - The tensor to set as the input.
|
||||
pub unsafe fn set_input(context: GraphExecutionContext, index: u32, tensor: Tensor) -> Result<()> {
|
||||
let rc = wasi_ephemeral_nn::set_input(context, index, &tensor as *const _ as *mut _);
|
||||
if let Some(err) = Error::from_raw_error(rc) {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the outputs after inference.
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the inference has not yet run.
|
||||
///
|
||||
/// ## Parameters
|
||||
///
|
||||
/// * `index` - The index of the output to retrieve.
|
||||
/// * `out_buffer` - An out parameter to which to copy the tensor data. The caller is responsible for allocating enough memory for
|
||||
/// the tensor data or an error will be returned. Currently there is no dynamic way to extract the additional
|
||||
/// tensor metadata (i.e. dimension, element type) but this should be added at some point.
|
||||
///
|
||||
/// ## Return
|
||||
///
|
||||
/// * `bytes_written` - The number of bytes of tensor data written to the `$out_buffer`.
|
||||
pub unsafe fn get_output(
|
||||
context: GraphExecutionContext,
|
||||
index: u32,
|
||||
out_buffer: *mut u8,
|
||||
out_buffer_max_size: BufferSize,
|
||||
) -> Result<BufferSize> {
|
||||
let mut bytes_written = MaybeUninit::uninit();
|
||||
let rc = wasi_ephemeral_nn::get_output(
|
||||
context,
|
||||
index,
|
||||
out_buffer,
|
||||
out_buffer_max_size,
|
||||
bytes_written.as_mut_ptr(),
|
||||
);
|
||||
if let Some(err) = Error::from_raw_error(rc) {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(bytes_written.assume_init())
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the inference on the given inputs (see `set_input`).
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the inputs are not all defined.
|
||||
pub unsafe fn compute(context: GraphExecutionContext) -> Result<()> {
|
||||
let rc = wasi_ephemeral_nn::compute(context);
|
||||
if let Some(err) = Error::from_raw_error(rc) {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub mod wasi_ephemeral_nn {
|
||||
use super::*;
|
||||
#[link(wasm_import_module = "wasi_ephemeral_nn")]
|
||||
extern "C" {
|
||||
/// Load an opaque sequence of bytes to use for inference.
|
||||
///
|
||||
/// This allows runtime implementations to support multiple graph encoding formats. For unsupported graph encodings,
|
||||
/// return `errno::inval`.
|
||||
pub fn load(
|
||||
builder_ptr: *const GraphBuilder,
|
||||
builder_len: usize,
|
||||
encoding: GraphEncoding,
|
||||
target: ExecutionTarget,
|
||||
graph: *mut Graph,
|
||||
) -> NnErrno;
|
||||
/// TODO Functions like `describe_graph_inputs` and `describe_graph_outputs` (returning
|
||||
/// an array of `$tensor_description`s) might be useful for introspecting the graph but are not yet included here.
|
||||
/// Create an execution instance of a loaded graph.
|
||||
/// TODO this may need to accept flags that might affect the compilation or execution of the graph.
|
||||
pub fn init_execution_context(graph: Graph, context: *mut GraphExecutionContext)
|
||||
-> NnErrno;
|
||||
/// Define the inputs to use for inference.
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the input tensor does not match the expected dimensions and type.
|
||||
pub fn set_input(
|
||||
context: GraphExecutionContext,
|
||||
index: u32,
|
||||
tensor: *mut Tensor,
|
||||
) -> NnErrno;
|
||||
/// Extract the outputs after inference.
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the inference has not yet run.
|
||||
pub fn get_output(
|
||||
context: GraphExecutionContext,
|
||||
index: u32,
|
||||
out_buffer: *mut u8,
|
||||
out_buffer_max_size: BufferSize,
|
||||
bytes_written: *mut BufferSize,
|
||||
) -> NnErrno;
|
||||
/// Compute the inference on the given inputs (see `set_input`).
|
||||
///
|
||||
/// This should return an $nn_errno (TODO define) if the inputs are not all defined.
|
||||
pub fn compute(context: GraphExecutionContext) -> NnErrno;
|
||||
}
|
||||
}
|
||||
3
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/lib.rs
Normal file
3
crates/wasi-nn/examples/wasi-nn-rust-bindings/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod error;
|
||||
mod generated;
|
||||
pub use generated::*;
|
||||
Reference in New Issue
Block a user