Add an initial wasi-nn implementation for Wasmtime (#2208)
* Add an initial wasi-nn implementation for Wasmtime This change adds a crate, `wasmtime-wasi-nn`, that uses `wiggle` to expose the current state of the wasi-nn API and `openvino` to implement the exposed functions. It includes an end-to-end test demonstrating how to do classification using wasi-nn: - `crates/wasi-nn/tests/classification-example` contains Rust code that is compiled to the `wasm32-wasi` target and run with a Wasmtime embedding that exposes the wasi-nn calls - the example uses Rust bindings for wasi-nn contained in `crates/wasi-nn/tests/wasi-nn-rust-bindings`; this crate contains code generated by `witx-bindgen` and eventually should be its own standalone crate * Test wasi-nn as a CI step This change adds: - a GitHub action for installing OpenVINO - a script, `ci/run-wasi-nn-example.sh`, to run the classification example
This commit is contained in:
125
crates/wasi-nn/src/ctx.rs
Normal file
125
crates/wasi-nn/src/ctx.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the implementation of the
|
||||
//! wasi-nn API.
|
||||
use crate::r#impl::UsageError;
|
||||
use crate::witx::types::{Graph, GraphExecutionContext};
|
||||
use openvino::InferenceError;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use thiserror::Error;
|
||||
use wiggle::GuestError;
|
||||
|
||||
/// Possible errors for interacting with [WasiNnCtx].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WasiNnError {
|
||||
#[error("guest error")]
|
||||
GuestError(#[from] GuestError),
|
||||
#[error("openvino error")]
|
||||
OpenvinoError(#[from] InferenceError),
|
||||
#[error("usage error")]
|
||||
UsageError(#[from] UsageError),
|
||||
}
|
||||
|
||||
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
|
||||
|
||||
/// A handle table: stores values of type `V` and hands out keys of type `K` for them.
///
/// Keys are built from an internal `u32` counter through `K::from`, so callers never choose
/// their own keys.
pub struct Table<K, V> {
    /// The live entries, indexed by the key handed out on insertion.
    entries: HashMap<K, V>,
    /// The counter value that the next key will be derived from.
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    /// Create an empty table whose first key is derived from `0`.
    fn default() -> Self {
        Self {
            entries: HashMap::new(),
            next_key: 0,
        }
    }
}

impl<K, V> Table<K, V>
where
    K: Eq + Hash + From<u32> + Copy,
{
    /// Store `value` in the table and return the freshly allocated key that identifies it.
    pub fn insert(&mut self, value: V) -> K {
        let key = self.use_next_key();
        self.entries.insert(key, value);
        key
    }

    /// Remove the entry stored under `key`, returning its value if it was present.
    pub fn remove(&mut self, key: K) -> Option<V> {
        self.entries.remove(&key)
    }

    /// Borrow the value stored under `key`, if any.
    pub fn get(&self, key: K) -> Option<&V> {
        self.entries.get(&key)
    }

    /// Mutably borrow the value stored under `key`, if any.
    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }

    /// Return the number of entries currently stored.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Return `true` when the table holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Allocate the next unused key.
    ///
    /// The counter wraps around at `u32::MAX` instead of overflowing, and any key that is still
    /// present in the table is skipped, so a wrapped counter can never cause `insert` to
    /// overwrite a live entry. (The search only fails to terminate if every possible counter
    /// value maps to an occupied key.)
    fn use_next_key(&mut self) -> K {
        loop {
            let candidate = K::from(self.next_key);
            self.next_key = self.next_key.wrapping_add(1);
            if !self.entries.contains_key(&candidate) {
                return candidate;
            }
        }
    }
}
|
||||
|
||||
pub struct ExecutionContext {
|
||||
pub(crate) graph: Graph,
|
||||
pub(crate) request: openvino::InferRequest,
|
||||
}
|
||||
|
||||
impl ExecutionContext {
|
||||
pub(crate) fn new(graph: Graph, request: openvino::InferRequest) -> Self {
|
||||
Self { graph, request }
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture the state necessary for calling into `openvino`.
|
||||
pub struct Ctx {
|
||||
pub(crate) core: openvino::Core,
|
||||
pub(crate) graphs: Table<Graph, (openvino::CNNNetwork, openvino::ExecutableNetwork)>,
|
||||
pub(crate) executions: Table<GraphExecutionContext, ExecutionContext>,
|
||||
}
|
||||
|
||||
impl Ctx {
|
||||
/// Make a new `WasiNnCtx` with the default settings.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
Ok(Self {
|
||||
core: openvino::Core::new(None)?,
|
||||
graphs: Table::default(),
|
||||
executions: Table::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// This structure provides the Rust-side context necessary for implementing the wasi-nn API. At the
|
||||
/// moment, it is specialized for a single inference implementation (i.e. OpenVINO) but conceivably
|
||||
/// this could support more than one backing implementation.
|
||||
pub struct WasiNnCtx {
|
||||
pub(crate) ctx: RefCell<Ctx>,
|
||||
}
|
||||
|
||||
impl WasiNnCtx {
|
||||
/// Make a new `WasiNnCtx` with the default settings.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
Ok(Self {
|
||||
ctx: RefCell::new(Ctx::new()?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: a context with the default settings can be constructed.
    #[test]
    fn instantiate() {
        let _ = WasiNnCtx::new().unwrap();
    }
}
|
||||
176
crates/wasi-nn/src/impl.rs
Normal file
176
crates/wasi-nn/src/impl.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Implements the wasi-nn API.
|
||||
use crate::ctx::{ExecutionContext, WasiNnResult as Result};
|
||||
use crate::witx::types::{
|
||||
ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor,
|
||||
TensorType,
|
||||
};
|
||||
use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn;
|
||||
use crate::WasiNnCtx;
|
||||
use openvino::{Layout, Precision, TensorDesc};
|
||||
use thiserror::Error;
|
||||
use wiggle::GuestPtr;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum UsageError {
|
||||
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0}")]
|
||||
InvalidEncoding(GraphEncoding),
|
||||
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
|
||||
InvalidNumberOfBuilders(u32),
|
||||
#[error("Invalid graph handle; has it been loaded?")]
|
||||
InvalidGraphHandle,
|
||||
#[error("Invalid execution context handle; has it been initialized?")]
|
||||
InvalidExecutionContextHandle,
|
||||
#[error("Not enough memory to copy tensor data of size: {0}")]
|
||||
NotEnoughMemory(u32),
|
||||
}
|
||||
|
||||
impl<'a> WasiEphemeralNn for WasiNnCtx {
|
||||
fn load<'b>(
|
||||
&self,
|
||||
builders: &GraphBuilderArray<'_>,
|
||||
encoding: GraphEncoding,
|
||||
target: ExecutionTarget,
|
||||
) -> Result<Graph> {
|
||||
if encoding != GraphEncoding::Openvino {
|
||||
return Err(UsageError::InvalidEncoding(encoding).into());
|
||||
}
|
||||
if builders.len() != 2 {
|
||||
return Err(UsageError::InvalidNumberOfBuilders(builders.len()).into());
|
||||
}
|
||||
let builders = builders.as_ptr();
|
||||
let xml = builders.read()?.as_slice()?;
|
||||
let weights = builders.add(1)?.read()?.as_slice()?;
|
||||
let graph = self
|
||||
.ctx
|
||||
.borrow_mut()
|
||||
.core
|
||||
.read_network_from_buffer(&xml, &weights)?;
|
||||
let executable_graph = self
|
||||
.ctx
|
||||
.borrow_mut()
|
||||
.core
|
||||
.load_network(&graph, map_execution_target_to_string(target))?;
|
||||
let id = self
|
||||
.ctx
|
||||
.borrow_mut()
|
||||
.graphs
|
||||
.insert((graph, executable_graph));
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
fn init_execution_context(&self, graph: Graph) -> Result<GraphExecutionContext> {
|
||||
let request =
|
||||
if let Some((_, executable_graph)) = self.ctx.borrow_mut().graphs.get_mut(graph) {
|
||||
executable_graph.create_infer_request()?
|
||||
} else {
|
||||
return Err(UsageError::InvalidGraphHandle.into());
|
||||
};
|
||||
|
||||
let execution_context = ExecutionContext::new(graph, request);
|
||||
let handle = self.ctx.borrow_mut().executions.insert(execution_context);
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
fn set_input<'b>(
|
||||
&self,
|
||||
context: GraphExecutionContext,
|
||||
index: u32,
|
||||
tensor: &Tensor<'b>,
|
||||
) -> Result<()> {
|
||||
let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
||||
execution.graph
|
||||
} else {
|
||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
||||
};
|
||||
|
||||
let input_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) {
|
||||
graph.get_input_name(index as usize)?
|
||||
} else {
|
||||
unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.")
|
||||
};
|
||||
|
||||
// Construct the blob structure.
|
||||
let dimensions = tensor
|
||||
.dimensions
|
||||
.as_slice()?
|
||||
.iter()
|
||||
.map(|d| *d as u64)
|
||||
.collect::<Vec<_>>();
|
||||
let precision = match tensor.type_ {
|
||||
TensorType::F16 => Precision::FP16,
|
||||
TensorType::F32 => Precision::FP32,
|
||||
TensorType::U8 => Precision::U8,
|
||||
TensorType::I32 => Precision::I32,
|
||||
};
|
||||
// TODO There must be some good way to discover the layout here; this should not have to default to NHWC.
|
||||
let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision);
|
||||
let data = tensor.data.as_slice()?;
|
||||
let blob = openvino::Blob::new(desc, &data)?;
|
||||
|
||||
// Actually assign the blob to the request (TODO avoid duplication with the borrow above).
|
||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
||||
execution.request.set_blob(&input_name, blob)?;
|
||||
} else {
|
||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute(&self, context: GraphExecutionContext) -> Result<()> {
|
||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
||||
Ok(execution.request.infer()?)
|
||||
} else {
|
||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
||||
}
|
||||
}
|
||||
|
||||
fn get_output<'b>(
|
||||
&self,
|
||||
context: GraphExecutionContext,
|
||||
index: u32,
|
||||
out_buffer: &GuestPtr<'_, u8>,
|
||||
out_buffer_max_size: u32,
|
||||
) -> Result<u32> {
|
||||
let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
||||
execution.graph
|
||||
} else {
|
||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
||||
};
|
||||
|
||||
let output_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) {
|
||||
graph.get_output_name(index as usize)?
|
||||
} else {
|
||||
unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.")
|
||||
};
|
||||
|
||||
// Retrieve the tensor data.
|
||||
let (mut blob, blob_size) =
|
||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
||||
let mut blob = execution.request.get_blob(&output_name)?; // TODO shouldn't need to be mut
|
||||
let blob_size = blob.byte_len()? as u32;
|
||||
if blob_size > out_buffer_max_size {
|
||||
return Err(UsageError::NotEnoughMemory(blob_size).into());
|
||||
}
|
||||
(blob, blob_size)
|
||||
} else {
|
||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
||||
};
|
||||
|
||||
// Copy the tensor data over to the `out_buffer`.
|
||||
let mut out_slice = out_buffer.as_array(out_buffer_max_size).as_slice()?;
|
||||
(&mut out_slice[..blob_size as usize]).copy_from_slice(blob.buffer()?);
|
||||
|
||||
Ok(blob_size)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the execution target string expected by OpenVINO from the `ExecutionTarget` enum provided
|
||||
/// by wasi-nn.
|
||||
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
|
||||
match target {
|
||||
ExecutionTarget::Cpu => "CPU",
|
||||
ExecutionTarget::Gpu => "GPU",
|
||||
ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"),
|
||||
}
|
||||
}
|
||||
26
crates/wasi-nn/src/lib.rs
Normal file
26
crates/wasi-nn/src/lib.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
mod ctx;
|
||||
mod r#impl;
|
||||
mod witx;
|
||||
|
||||
pub use ctx::WasiNnCtx;
|
||||
|
||||
// Defines a `struct WasiNn` with member fields and appropriate APIs for dealing with all the
|
||||
// various WASI exports.
|
||||
wasmtime_wiggle::wasmtime_integration!({
|
||||
// The wiggle code to integrate with lives here:
|
||||
target: witx,
|
||||
// This must be the same witx document as used above:
|
||||
witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"],
|
||||
// This must be the same ctx type as used for the target:
|
||||
ctx: WasiNnCtx,
|
||||
// This macro will emit a struct to represent the instance, with this name and docs:
|
||||
modules: {
|
||||
wasi_ephemeral_nn => {
|
||||
name: WasiNn,
|
||||
docs: "An instantiated instance of the wasi-nn exports.",
|
||||
function_override: {}
|
||||
}
|
||||
},
|
||||
// Error to return when caller module is missing memory export:
|
||||
missing_memory: { witx::types::Errno::MissingMemory },
|
||||
});
|
||||
40
crates/wasi-nn/src/witx.rs
Normal file
40
crates/wasi-nn/src/witx.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
//! Contains the macro-generated implementation of wasi-nn from its witx definition file.
|
||||
use crate::ctx::WasiNnCtx;
|
||||
use crate::ctx::WasiNnError;
|
||||
|
||||
// Generate the traits and types of wasi-nn in several Rust modules (e.g. `types`).
|
||||
wiggle::from_witx!({
|
||||
witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"],
|
||||
ctx: WasiNnCtx,
|
||||
errors: { errno => WasiNnError }
|
||||
});
|
||||
|
||||
use types::Errno;
|
||||
|
||||
/// Wiggle generates code that performs some input validation on the arguments passed in by users of
|
||||
/// wasi-nn. Here we convert the validation error into one (or more, eventually) of the error
|
||||
/// variants defined in the witx.
|
||||
impl types::GuestErrorConversion for WasiNnCtx {
|
||||
fn into_errno(&self, e: wiggle::GuestError) -> Errno {
|
||||
eprintln!("Guest error: {:?}", e);
|
||||
Errno::InvalidArgument
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> types::UserErrorConversion for WasiNnCtx {
|
||||
fn errno_from_wasi_nn_error(&self, e: WasiNnError) -> Errno {
|
||||
eprintln!("Host error: {:?}", e);
|
||||
match e {
|
||||
WasiNnError::OpenvinoError(_) => unimplemented!(),
|
||||
WasiNnError::GuestError(_) => unimplemented!(),
|
||||
WasiNnError::UsageError(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Additionally, we must let Wiggle know which of our error codes represents a successful operation.
|
||||
impl wiggle::GuestErrorType for Errno {
|
||||
fn success() -> Self {
|
||||
Self::Success
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user