diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/api.rs
new file mode 100644
index 0000000000..630a352a7b
--- /dev/null
+++ b/crates/wasi-nn/src/api.rs
@@ -0,0 +1,45 @@
+//! Define the Rust interface a backend must implement in order to be used by
+//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
+//! implementations to maintain backend-specific state between calls.
+
+use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor};
+use thiserror::Error;
+use wiggle::GuestError;
+
+/// A [Backend] contains the necessary state to load [BackendGraph]s.
+pub(crate) trait Backend {
+    fn name(&self) -> &str;
+    fn load(
+        &mut self,
+        builders: &GraphBuilderArray<'_>,
+        target: ExecutionTarget,
+    ) -> Result<Box<dyn BackendGraph>, BackendError>;
+}
+
+/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
+/// implementation for a [crate::witx::types::Graph].
+pub(crate) trait BackendGraph {
+    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
+}
+
+/// A [BackendExecutionContext] performs the actual inference; this is the
+/// backing implementation for a [crate::witx::types::GraphExecutionContext].
+pub(crate) trait BackendExecutionContext {
+    fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
+    fn compute(&mut self) -> Result<(), BackendError>;
+    fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
+}
+
+/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
+/// for failures interacting with the ML library.
+#[derive(Debug, Error)]
+pub(crate) enum BackendError {
+    #[error("Failed while accessing backend")]
+    BackendAccess(#[from] anyhow::Error),
+    #[error("Failed while accessing guest module")]
+    GuestAccess(#[from] GuestError),
+    #[error("The backend expects {0} buffers, passed {1}")]
+    InvalidNumberOfBuilders(u32, u32),
+    #[error("Not enough memory to copy tensor data of size: {0}")]
+    NotEnoughMemory(usize),
+}
diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs
index d66b05f5e6..7efae80f61 100644
--- a/crates/wasi-nn/src/lib.rs
+++ b/crates/wasi-nn/src/lib.rs
@@ -1,5 +1,7 @@
+mod api;
 mod ctx;
 mod r#impl;
+mod openvino;
 mod witx;
 
 pub use ctx::WasiNnCtx;
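Together, these three traits define the contract a backend fulfills: a `Backend` loads graphs, a `BackendGraph` creates execution contexts, and a `BackendExecutionContext` runs inference. To make that contract concrete, here is a minimal sketch (not part of this change) of a hypothetical `NoopBackend` whose "model" simply echoes its input tensor back as the output; it assumes the crate-internal `api` and `witx` types above are in scope.

```rust
// Hypothetical `NoopBackend`: illustrative only, NOT part of this diff.
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor};

#[derive(Default)]
struct NoopBackend;

impl Backend for NoopBackend {
    fn name(&self) -> &str {
        "noop"
    }

    fn load(
        &mut self,
        builders: &GraphBuilderArray<'_>,
        _target: ExecutionTarget,
    ) -> Result<Box<dyn BackendGraph>, BackendError> {
        // Like the OpenVINO backend below, validate the builder count first.
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }
        Ok(Box::new(NoopGraph))
    }
}

struct NoopGraph;

impl BackendGraph for NoopGraph {
    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError> {
        Ok(Box::new(NoopExecutionContext::default()))
    }
}

#[derive(Default)]
struct NoopExecutionContext {
    buffer: Vec<u8>,
}

impl BackendExecutionContext for NoopExecutionContext {
    fn set_input(&mut self, _index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> {
        // Reading guest memory can fail; `?` converts the `GuestError` into
        // `BackendError::GuestAccess` via the `#[from]` conversion above.
        self.buffer = tensor.data.as_slice()?.to_vec();
        Ok(())
    }

    fn compute(&mut self) -> Result<(), BackendError> {
        // A real backend would run inference here; this "model" is the
        // identity function, so there is nothing to do.
        Ok(())
    }

    fn get_output(&mut self, _index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
        if self.buffer.len() > destination.len() {
            return Err(BackendError::NotEnoughMemory(self.buffer.len()));
        }
        destination[..self.buffer.len()].copy_from_slice(&self.buffer);
        Ok(self.buffer.len() as u32)
    }
}
```

The `openvino.rs` implementation that follows fills in the same shape with real OpenVINO calls.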
diff --git a/crates/wasi-nn/src/openvino.rs b/crates/wasi-nn/src/openvino.rs
new file mode 100644
index 0000000000..817a0a3d6b
--- /dev/null
+++ b/crates/wasi-nn/src/openvino.rs
@@ -0,0 +1,139 @@
+//! Implements the wasi-nn API using OpenVINO.
+use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
+use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType};
+use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
+use std::sync::Arc;
+
+#[derive(Default)]
+pub(crate) struct OpenvinoBackend(Option<openvino::Core>);
+
+impl Backend for OpenvinoBackend {
+    fn name(&self) -> &str {
+        "openvino"
+    }
+
+    fn load(
+        &mut self,
+        builders: &GraphBuilderArray<'_>,
+        target: ExecutionTarget,
+    ) -> Result<Box<dyn BackendGraph>, BackendError> {
+        if builders.len() != 2 {
+            return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
+        }
+
+        // Construct the context if none is present; this is done lazily (i.e.
+        // upon actually loading a model) because it may fail to find and load
+        // the OpenVINO libraries. The laziness limits the extent of the error
+        // only to wasi-nn users, not all WASI users.
+        if self.0.is_none() {
+            self.0.replace(openvino::Core::new(None)?);
+        }
+
+        // Read the guest array.
+        let builders = builders.as_ptr();
+        let xml = builders.read()?.as_slice()?;
+        let weights = builders.add(1)?.read()?.as_slice()?;
+
+        // Construct OpenVINO graph structures: `cnn_network` contains the graph
+        // structure, `exec_network` can perform inference.
+        let core = self
+            .0
+            .as_mut()
+            .expect("openvino::Core was previously constructed");
+        let cnn_network = core.read_network_from_buffer(&xml, &weights)?;
+        let exec_network =
+            core.load_network(&cnn_network, map_execution_target_to_string(target))?;
+
+        Ok(Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network)))
+    }
+}
+
+struct OpenvinoGraph(Arc<openvino::CNNNetwork>, openvino::ExecutableNetwork);
+
+impl BackendGraph for OpenvinoGraph {
+    fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError> {
+        let infer_request = self.1.create_infer_request()?;
+        Ok(Box::new(OpenvinoExecutionContext(
+            self.0.clone(),
+            infer_request,
+        )))
+    }
+}
+
+struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest);
+
+impl BackendExecutionContext for OpenvinoExecutionContext {
+    fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> {
+        let input_name = self.0.get_input_name(index as usize)?;
+
+        // Construct the blob structure.
+        let dimensions = tensor
+            .dimensions
+            .as_slice()?
+            .iter()
+            .map(|d| *d as usize)
+            .collect::<Vec<_>>();
+        let precision = map_tensor_type_to_precision(tensor.type_);
+
+        // TODO There must be some good way to discover the layout here; this
+        // should not have to default to NHWC.
+        let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision);
+        let data = tensor.data.as_slice()?;
+        let blob = openvino::Blob::new(desc, &data)?;
+
+        // Actually assign the blob to the request.
+        self.1.set_blob(&input_name, blob)?;
+        Ok(())
+    }
+
+    fn compute(&mut self) -> Result<(), BackendError> {
+        self.1.infer()?;
+        Ok(())
+    }
+
+    fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
+        let output_name = self.0.get_output_name(index as usize)?;
+        let mut blob = self.1.get_blob(&output_name)?;
+        let blob_size = blob.byte_len()?;
+        if blob_size > destination.len() {
+            return Err(BackendError::NotEnoughMemory(blob_size));
+        }
+
+        // Copy the tensor data into the destination buffer.
+        destination[..blob_size].copy_from_slice(blob.buffer()?);
+        Ok(blob_size as u32)
+    }
+}
+
+impl From<InferenceError> for BackendError {
+    fn from(e: InferenceError) -> Self {
+        BackendError::BackendAccess(anyhow::Error::new(e))
+    }
+}
+
+impl From<SetupError> for BackendError {
+    fn from(e: SetupError) -> Self {
+        BackendError::BackendAccess(anyhow::Error::new(e))
+    }
+}
+
+/// Return the execution target string expected by OpenVINO from the
+/// `ExecutionTarget` enum provided by wasi-nn.
+fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
+    match target {
+        ExecutionTarget::Cpu => "CPU",
+        ExecutionTarget::Gpu => "GPU",
+        ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"),
+    }
+}
+
+/// Return OpenVINO's precision type for the `TensorType` enum provided by
+/// wasi-nn.
+fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision {
+    match tensor_type {
+        TensorType::F16 => Precision::FP16,
+        TensorType::F32 => Precision::FP32,
+        TensorType::U8 => Precision::U8,
+        TensorType::I32 => Precision::I32,
+    }
+}
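If a test module were added alongside this file, the two mapping functions at the bottom could be checked as below. This is a sketch, not part of the diff; it uses `matches!` for the precision checks on the assumption that `openvino::Precision` may not implement `PartialEq`/`Debug`, and it leaves `ExecutionTarget::Tpu` untested since that arm panics via `unimplemented!`.

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn execution_targets_map_to_openvino_device_names() {
        // Device names are the strings OpenVINO's `load_network` expects.
        assert_eq!(map_execution_target_to_string(ExecutionTarget::Cpu), "CPU");
        assert_eq!(map_execution_target_to_string(ExecutionTarget::Gpu), "GPU");
    }

    #[test]
    fn tensor_types_map_to_openvino_precisions() {
        assert!(matches!(
            map_tensor_type_to_precision(TensorType::F32),
            Precision::FP32
        ));
        assert!(matches!(
            map_tensor_type_to_precision(TensorType::U8),
            Precision::U8
        ));
    }
}
```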