Add an initial wasi-nn implementation for Wasmtime (#2208)

* Add an initial wasi-nn implementation for Wasmtime

This change adds a crate, `wasmtime-wasi-nn`, that uses `wiggle` to expose the current state of the wasi-nn API and `openvino` to implement the exposed functions. It includes an end-to-end test demonstrating how to do classification using wasi-nn:
 - `crates/wasi-nn/tests/classification-example` contains Rust code that is compiled to the `wasm32-wasi` target and run with a Wasmtime embedding that exposes the wasi-nn calls
 - the example uses Rust bindings for wasi-nn contained in `crates/wasi-nn/tests/wasi-nn-rust-bindings`; this crate contains code generated by `witx-bindgen` and eventually should be its own standalone crate

* Test wasi-nn as a CI step

This change adds:
 - a GitHub action for installing OpenVINO
 - a script, `ci/run-wasi-nn-example.sh`, to run the classification example
This commit is contained in:
Andrew Brown
2020-11-16 10:54:00 -08:00
committed by GitHub
parent 61a0bcbdc6
commit a61f068c64
33 changed files with 1554 additions and 1 deletions

View File

@@ -0,0 +1,76 @@
use super::NnErrno;
use core::fmt;
use core::num::NonZeroU16;
/// A raw error returned by wasi-nn APIs, internally containing a 16-bit error
/// code.
///
/// The code is stored as a `NonZeroU16`, so a value of this type can never
/// carry the zero code; `Error::from_raw_error` maps zero to `None` instead.
#[derive(Copy, Clone, PartialEq, Eq, Ord, PartialOrd)]
pub struct Error {
// The raw error code; the field type rules out zero.
code: NonZeroU16,
}
impl Error {
/// Constructs a new error from a raw error code, returning `None` if the
/// error code is zero (which means success).
pub fn from_raw_error(error: NnErrno) -> Option<Error> {
Some(Error {
code: NonZeroU16::new(error)?,
})
}
/// Returns the raw error code that this error represents.
pub fn raw_error(&self) -> u16 {
self.code.get()
}
}
impl fmt::Display for Error {
    /// Writes the error as `<message> (error <code>)`, where the message is
    /// looked up through `strerror`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = strerror(self.code.get());
        write!(f, "{} (error {})", message, self.code)
    }
}
impl fmt::Debug for Error {
    /// Renders the error as a struct with two fields: the raw `code` and the
    /// `message` that `strerror` associates with it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = strerror(self.code.get());
        let mut builder = f.debug_struct("Error");
        builder.field("code", &self.code);
        builder.field("message", &message);
        builder.finish()
    }
}
/// This should be generated automatically by witx-bindgen but is not yet for enums other than
/// `Errno` (this API uses `NnErrno` to avoid naming conflicts). TODO: https://github.com/bytecodealliance/wasi/issues/52.
fn strerror(code: u16) -> &'static str {
match code {
super::NN_ERRNO_SUCCESS => "No error occurred.",
super::NN_ERRNO_INVALID_ARGUMENT => "Caller module passed an invalid argument.",
super::NN_ERRNO_MISSING_MEMORY => "Caller module is missing a memory export.",
super::NN_ERRNO_BUSY => "Device or resource busy.",
_ => "Unknown error.",
}
}
// Both items below are compiled only when the `std` cargo feature is enabled:
// `std` is named explicitly as an extern crate, and `Error` then implements
// the standard error trait.
#[cfg(feature = "std")]
extern crate std;
// Empty impl: only the trait's default methods are used, on top of the
// `Debug` and `Display` impls defined earlier in this file.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The zero code means success, so no `Error` is produced for it.
    #[test]
    fn error_from_success_code() {
        let converted = Error::from_raw_error(0);
        assert_eq!(None, converted);
    }

    /// Code 1 is displayed as its message followed by the numeric code.
    #[test]
    fn error_from_invalid_argument_code() {
        let rendered = Error::from_raw_error(1).unwrap().to_string();
        assert_eq!(
            "Caller module passed an invalid argument. (error 1)",
            rendered
        );
    }
}

View File

@@ -0,0 +1,199 @@
// This file is automatically generated, DO NOT EDIT
//
// To regenerate this file run the `crates/witx-bindgen` command
use core::mem::MaybeUninit;
pub use crate::error::Error;
/// Result type used by the wrappers in this file; the error type defaults to `Error`.
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// Integer type used for buffer sizes and byte counts (see `get_output`).
pub type BufferSize = u32;
/// Raw status code returned by the imported `wasi_ephemeral_nn` functions.
pub type NnErrno = u16;
/// No error occurred.
pub const NN_ERRNO_SUCCESS: NnErrno = 0;
/// Caller module passed an invalid argument.
pub const NN_ERRNO_INVALID_ARGUMENT: NnErrno = 1;
/// Caller module is missing a memory export.
pub const NN_ERRNO_MISSING_MEMORY: NnErrno = 2;
/// Device or resource busy.
pub const NN_ERRNO_BUSY: NnErrno = 3;
/// Tensor dimensions as a borrowed slice of `u32` (used by `Tensor::dimensions`).
pub type TensorDimensions<'a> = &'a [u32];
/// Code identifying a tensor's element type; see the `TENSOR_TYPE_*` constants.
pub type TensorType = u8;
/// NOTE(review): presumably 16-bit floating-point elements, going by the name; confirm against the wasi-nn spec.
pub const TENSOR_TYPE_F16: TensorType = 0;
/// NOTE(review): presumably 32-bit floating-point elements, going by the name; confirm against the wasi-nn spec.
pub const TENSOR_TYPE_F32: TensorType = 1;
/// NOTE(review): presumably unsigned 8-bit integer elements, going by the name; confirm against the wasi-nn spec.
pub const TENSOR_TYPE_U8: TensorType = 2;
/// NOTE(review): presumably signed 32-bit integer elements, going by the name; confirm against the wasi-nn spec.
pub const TENSOR_TYPE_I32: TensorType = 3;
/// Tensor contents as a borrowed byte slice (used by `Tensor::data`).
pub type TensorData<'a> = &'a [u8];
/// A tensor value: its dimensions, an element-type code, and its raw bytes.
///
/// All data is borrowed for `'a`. The struct is `#[repr(C)]`; `set_input` passes
/// a pointer to it to the imported `wasi_ephemeral_nn::set_input` function.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct Tensor<'a> {
/// Describe the size of the tensor (e.g. 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor containing a single value,
/// use `[1]` for the tensor dimensions.
pub dimensions: TensorDimensions<'a>,
/// Element-type code of the tensor; see the `TENSOR_TYPE_*` constants.
pub r#type: TensorType,
/// Contains the tensor data.
pub data: TensorData<'a>,
}
/// One borrowed byte buffer that is part of a graph definition (see `load`).
pub type GraphBuilder<'a> = &'a [u8];
/// A borrowed sequence of `GraphBuilder` buffers, passed as a whole to `load`.
pub type GraphBuilderArray<'a> = &'a [GraphBuilder<'a>];
/// Integer handle for a loaded graph, as produced by `load`.
pub type Graph = u32;
/// Code identifying the encoding of the graph bytes; see the `GRAPH_ENCODING_*` constants.
pub type GraphEncoding = u8;
/// TODO document buffer order
pub const GRAPH_ENCODING_OPENVINO: GraphEncoding = 0;
/// Code identifying where a graph should execute; see the `EXECUTION_TARGET_*` constants.
pub type ExecutionTarget = u8;
/// Execution target code 0 (named for the CPU).
pub const EXECUTION_TARGET_CPU: ExecutionTarget = 0;
/// Execution target code 1 (named for the GPU).
pub const EXECUTION_TARGET_GPU: ExecutionTarget = 1;
/// Execution target code 2 (named for the TPU).
pub const EXECUTION_TARGET_TPU: ExecutionTarget = 2;
/// Integer handle for an execution context, as produced by `init_execution_context`.
pub type GraphExecutionContext = u32;
/// Load an opaque sequence of bytes to use for inference.
///
/// This allows runtime implementations to support multiple graph encoding formats. For unsupported graph encodings,
/// return `errno::inval`.
///
/// ## Parameters
///
/// * `builder` - The bytes necessary to build the graph.
/// * `encoding` - The encoding of the graph.
/// * `target` - Where to execute the graph.
pub unsafe fn load(
builder: GraphBuilderArray,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
let mut graph = MaybeUninit::uninit();
let rc = wasi_ephemeral_nn::load(
builder.as_ptr(),
builder.len(),
encoding,
target,
graph.as_mut_ptr(),
);
if let Some(err) = Error::from_raw_error(rc) {
Err(err)
} else {
Ok(graph.assume_init())
}
}
/// TODO Functions like `describe_graph_inputs` and `describe_graph_outputs` (returning
/// an array of `$tensor_description`s) might be useful for introspecting the graph but are not yet included here.
/// Create an execution instance of a loaded graph.
/// TODO this may need to accept flags that might affect the compilation or execution of the graph.
pub unsafe fn init_execution_context(graph: Graph) -> Result<GraphExecutionContext> {
let mut context = MaybeUninit::uninit();
let rc = wasi_ephemeral_nn::init_execution_context(graph, context.as_mut_ptr());
if let Some(err) = Error::from_raw_error(rc) {
Err(err)
} else {
Ok(context.assume_init())
}
}
/// Define the inputs to use for inference.
///
/// This should return an $nn_errno (TODO define) if the input tensor does not match the expected dimensions and type.
///
/// ## Parameters
///
/// * `index` - The index of the input to change.
/// * `tensor` - The tensor to set as the input.
pub unsafe fn set_input(context: GraphExecutionContext, index: u32, tensor: Tensor) -> Result<()> {
let rc = wasi_ephemeral_nn::set_input(context, index, &tensor as *const _ as *mut _);
if let Some(err) = Error::from_raw_error(rc) {
Err(err)
} else {
Ok(())
}
}
/// Extract the outputs after inference.
///
/// This should return an $nn_errno (TODO define) if the inference has not yet run.
///
/// ## Parameters
///
/// * `index` - The index of the output to retrieve.
/// * `out_buffer` - An out parameter to which to copy the tensor data. The caller is responsible for allocating enough memory for
/// the tensor data or an error will be returned. Currently there is no dynamic way to extract the additional
/// tensor metadata (i.e. dimension, element type) but this should be added at some point.
///
/// ## Return
///
/// * `bytes_written` - The number of bytes of tensor data written to the `$out_buffer`.
pub unsafe fn get_output(
context: GraphExecutionContext,
index: u32,
out_buffer: *mut u8,
out_buffer_max_size: BufferSize,
) -> Result<BufferSize> {
let mut bytes_written = MaybeUninit::uninit();
let rc = wasi_ephemeral_nn::get_output(
context,
index,
out_buffer,
out_buffer_max_size,
bytes_written.as_mut_ptr(),
);
if let Some(err) = Error::from_raw_error(rc) {
Err(err)
} else {
Ok(bytes_written.assume_init())
}
}
/// Compute the inference on the given inputs (see `set_input`).
///
/// This should return an $nn_errno (TODO define) if the inputs are not all defined.
pub unsafe fn compute(context: GraphExecutionContext) -> Result<()> {
let rc = wasi_ephemeral_nn::compute(context);
if let Some(err) = Error::from_raw_error(rc) {
Err(err)
} else {
Ok(())
}
}
/// Raw declarations of the functions imported from the `wasi_ephemeral_nn` wasm module.
///
/// Each function returns an `NnErrno` status code; the wrappers of the same name in the
/// parent module convert that code into a `Result`.
pub mod wasi_ephemeral_nn {
use super::*;
// Names the wasm import module that the declarations below are imported from.
#[link(wasm_import_module = "wasi_ephemeral_nn")]
extern "C" {
/// Load an opaque sequence of bytes to use for inference.
///
/// This allows runtime implementations to support multiple graph encoding formats. For unsupported graph encodings,
/// return `errno::inval`.
///
/// `builder_ptr` and `builder_len` describe the array of graph buffers; the
/// wrapper reads the graph handle from `graph` when the returned status is zero.
pub fn load(
builder_ptr: *const GraphBuilder,
builder_len: usize,
encoding: GraphEncoding,
target: ExecutionTarget,
graph: *mut Graph,
) -> NnErrno;
/// TODO Functions like `describe_graph_inputs` and `describe_graph_outputs` (returning
/// an array of `$tensor_description`s) might be useful for introspecting the graph but are not yet included here.
/// Create an execution instance of a loaded graph.
/// TODO this may need to accept flags that might affect the compilation or execution of the graph.
///
/// The wrapper reads the context handle from `context` when the returned status is zero.
pub fn init_execution_context(graph: Graph, context: *mut GraphExecutionContext)
-> NnErrno;
/// Define the inputs to use for inference.
///
/// This should return an $nn_errno (TODO define) if the input tensor does not match the expected dimensions and type.
pub fn set_input(
context: GraphExecutionContext,
index: u32,
tensor: *mut Tensor,
) -> NnErrno;
/// Extract the outputs after inference.
///
/// This should return an $nn_errno (TODO define) if the inference has not yet run.
///
/// The wrapper reads the byte count from `bytes_written` when the returned status is zero.
pub fn get_output(
context: GraphExecutionContext,
index: u32,
out_buffer: *mut u8,
out_buffer_max_size: BufferSize,
bytes_written: *mut BufferSize,
) -> NnErrno;
/// Compute the inference on the given inputs (see `set_input`).
///
/// This should return an $nn_errno (TODO define) if the inputs are not all defined.
pub fn compute(context: GraphExecutionContext) -> NnErrno;
}
}

View File

@@ -0,0 +1,3 @@
// `error` defines the `Error` type wrapping raw wasi-nn error codes.
mod error;
// `generated` holds the witx-bindgen output (types, constants and wrappers).
mod generated;
// Re-export everything from `generated` at this level; `error` refers to
// `NnErrno` and the `NN_ERRNO_*` constants through `super::`.
pub use generated::*;