diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/api.rs index 630a352a7b..94f94526ae 100644 --- a/crates/wasi-nn/src/api.rs +++ b/crates/wasi-nn/src/api.rs @@ -33,7 +33,7 @@ pub(crate) trait BackendExecutionContext { /// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all /// for failures interacting with the ML library. #[derive(Debug, Error)] -pub(crate) enum BackendError { +pub enum BackendError { #[error("Failed while accessing backend")] BackendAccess(#[from] anyhow::Error), #[error("Failed while accessing guest module")] diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index dbe93c7224..c63e05ed28 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -1,29 +1,68 @@ -//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the implementation of the -//! wasi-nn API. +//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the +//! implementation of the wasi-nn API. +use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; +use crate::openvino::OpenvinoBackend; use crate::r#impl::UsageError; -use crate::witx::types::{Graph, GraphExecutionContext}; -use openvino::{InferenceError, SetupError}; +use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext}; use std::cell::RefCell; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; use wiggle::GuestError; +/// Capture the state necessary for calling into the backend ML libraries. +pub struct Ctx { + pub(crate) backends: HashMap>, + pub(crate) graphs: Table>, + pub(crate) executions: Table>, +} + +impl Ctx { + /// Make a new context from the default state. + pub fn new() -> WasiNnResult { + let mut backends = HashMap::new(); + backends.insert( + // This is necessary because Wiggle's variant types do not derive + // `Hash` and `Eq`. + GraphEncoding::Openvino.into(), + Box::new(OpenvinoBackend::default()) as Box, + ); + Ok(Self { + backends, + graphs: Table::default(), + executions: Table::default(), + }) + } +} + +/// This struct solely wraps [Ctx] in a `RefCell`. +pub struct WasiNnCtx { + pub(crate) ctx: RefCell, +} + +impl WasiNnCtx { + /// Make a new `WasiNnCtx` with the default settings. + pub fn new() -> WasiNnResult { + Ok(Self { + ctx: RefCell::new(Ctx::new()?), + }) + } +} + /// Possible errors while interacting with [WasiNnCtx]. #[derive(Debug, Error)] pub enum WasiNnError { + #[error("backend error")] + BackendError(#[from] BackendError), #[error("guest error")] GuestError(#[from] GuestError), - #[error("openvino inference error")] - OpenvinoInferenceError(#[from] InferenceError), - #[error("openvino setup error")] - OpenvinoSetupError(#[from] SetupError), #[error("usage error")] UsageError(#[from] UsageError), } pub(crate) type WasiNnResult = std::result::Result; +/// Record handle entries in a table. pub struct Table { entries: HashMap, next_key: u32, @@ -48,10 +87,6 @@ where key } - pub fn get(&self, key: K) -> Option<&V> { - self.entries.get(&key) - } - pub fn get_mut(&mut self, key: K) -> Option<&mut V> { self.entries.get_mut(&key) } @@ -63,51 +98,6 @@ where } } -pub struct ExecutionContext { - pub(crate) graph: Graph, - pub(crate) request: openvino::InferRequest, -} - -impl ExecutionContext { - pub(crate) fn new(graph: Graph, request: openvino::InferRequest) -> Self { - Self { graph, request } - } -} - -/// Capture the state necessary for calling into `openvino`. -pub struct Ctx { - pub(crate) core: Option, - pub(crate) graphs: Table, - pub(crate) executions: Table, -} - -impl Ctx { - /// Make a new `WasiNnCtx` with the default settings. - pub fn new() -> WasiNnResult { - Ok(Self { - core: Option::default(), - graphs: Table::default(), - executions: Table::default(), - }) - } -} - -/// This structure provides the Rust-side context necessary for implementing the wasi-nn API. At the -/// moment, it is specialized for a single inference implementation (i.e. OpenVINO) but conceivably -/// this could support more than one backing implementation. -pub struct WasiNnCtx { - pub(crate) ctx: RefCell, -} - -impl WasiNnCtx { - /// Make a new `WasiNnCtx` with the default settings. - pub fn new() -> WasiNnResult { - Ok(Self { - ctx: RefCell::new(Ctx::new()?), - }) - } -} - #[cfg(test)] mod test { use super::*; diff --git a/crates/wasi-nn/src/impl.rs b/crates/wasi-nn/src/impl.rs index 8152de77b9..1c1c558970 100644 --- a/crates/wasi-nn/src/impl.rs +++ b/crates/wasi-nn/src/impl.rs @@ -1,12 +1,10 @@ //! Implements the wasi-nn API. -use crate::ctx::{ExecutionContext, WasiNnResult as Result}; +use crate::ctx::WasiNnResult as Result; use crate::witx::types::{ ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor, - TensorType, }; use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn; use crate::WasiNnCtx; -use openvino::{Layout, Precision, TensorDesc}; use thiserror::Error; use wiggle::GuestPtr; @@ -33,162 +31,60 @@ impl<'a> WasiEphemeralNn for WasiNnCtx { encoding: GraphEncoding, target: ExecutionTarget, ) -> Result { - if encoding != GraphEncoding::Openvino { + let encoding_id: u8 = encoding.into(); + let graph = if let Some(backend) = self.ctx.borrow_mut().backends.get_mut(&encoding_id) { + backend.load(builders, target)? + } else { return Err(UsageError::InvalidEncoding(encoding).into()); - } - - if builders.len() != 2 { - return Err(UsageError::InvalidNumberOfBuilders(builders.len()).into()); - } - - // Construct the context if none is present; this is done lazily (i.e. upon actually loading - // a model) because it may fail to find and load the OpenVINO libraries. The laziness limits - // the extent of the error only to wasi-nn users, not all WASI users. - if self.ctx.borrow().core.is_none() { - self.ctx - .borrow_mut() - .core - .replace(openvino::Core::new(None)?); - } - - let builders = builders.as_ptr(); - let xml = builders.read()?.as_slice()?; - let weights = builders.add(1)?.read()?.as_slice()?; - let graph = self - .ctx - .borrow_mut() - .core - .as_mut() - .ok_or(UsageError::InvalidContext)? - .read_network_from_buffer(&xml, &weights)?; - let executable_graph = self - .ctx - .borrow_mut() - .core - .as_mut() - .ok_or(UsageError::InvalidContext)? - .load_network(&graph, map_execution_target_to_string(target))?; - let id = self - .ctx - .borrow_mut() - .graphs - .insert((graph, executable_graph)); - Ok(id) + }; + let graph_id = self.ctx.borrow_mut().graphs.insert(graph); + Ok(graph_id) } - fn init_execution_context(&mut self, graph: Graph) -> Result { - let request = - if let Some((_, executable_graph)) = self.ctx.borrow_mut().graphs.get_mut(graph) { - executable_graph.create_infer_request()? - } else { - return Err(UsageError::InvalidGraphHandle.into()); - }; + fn init_execution_context(&mut self, graph_id: Graph) -> Result { + let exec_context = if let Some(graph) = self.ctx.borrow_mut().graphs.get_mut(graph_id) { + graph.init_execution_context()? + } else { + return Err(UsageError::InvalidGraphHandle.into()); + }; - let execution_context = ExecutionContext::new(graph, request); - let handle = self.ctx.borrow_mut().executions.insert(execution_context); - Ok(handle) + let exec_context_id = self.ctx.borrow_mut().executions.insert(exec_context); + Ok(exec_context_id) } fn set_input<'b>( &mut self, - context: GraphExecutionContext, + exec_context_id: GraphExecutionContext, index: u32, tensor: &Tensor<'b>, ) -> Result<()> { - let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) { - execution.graph + if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) { + Ok(exec_context.set_input(index, tensor)?) } else { - return Err(UsageError::InvalidExecutionContextHandle.into()); - }; - - let input_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) { - graph.get_input_name(index as usize)? - } else { - unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.") - }; - - // Construct the blob structure. - let dimensions = tensor - .dimensions - .as_slice()? - .iter() - .map(|d| *d as usize) - .collect::>(); - let precision = match tensor.type_ { - TensorType::F16 => Precision::FP16, - TensorType::F32 => Precision::FP32, - TensorType::U8 => Precision::U8, - TensorType::I32 => Precision::I32, - }; - // TODO There must be some good way to discover the layout here; this should not have to default to NHWC. - let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision); - let data = tensor.data.as_slice()?; - let blob = openvino::Blob::new(desc, &data)?; - - // Actually assign the blob to the request (TODO avoid duplication with the borrow above). - if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) { - execution.request.set_blob(&input_name, blob)?; - } else { - return Err(UsageError::InvalidExecutionContextHandle.into()); + Err(UsageError::InvalidGraphHandle.into()) } - - Ok(()) } - fn compute(&mut self, context: GraphExecutionContext) -> Result<()> { - if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) { - Ok(execution.request.infer()?) + fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> { + if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) { + Ok(exec_context.compute()?) } else { - return Err(UsageError::InvalidExecutionContextHandle.into()); + Err(UsageError::InvalidExecutionContextHandle.into()) } } fn get_output<'b>( &mut self, - context: GraphExecutionContext, + exec_context_id: GraphExecutionContext, index: u32, out_buffer: &GuestPtr<'_, u8>, out_buffer_max_size: u32, ) -> Result { - let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) { - execution.graph + let mut destination = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?; + if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) { + Ok(exec_context.get_output(index, &mut destination)?) } else { - return Err(UsageError::InvalidExecutionContextHandle.into()); - }; - - let output_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) { - graph.get_output_name(index as usize)? - } else { - unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.") - }; - - // Retrieve the tensor data. - let (mut blob, blob_size) = - if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) { - let mut blob = execution.request.get_blob(&output_name)?; // TODO shouldn't need to be mut - let blob_size = blob.byte_len()? as u32; - if blob_size > out_buffer_max_size { - return Err(UsageError::NotEnoughMemory(blob_size).into()); - } - (blob, blob_size) - } else { - return Err(UsageError::InvalidExecutionContextHandle.into()); - }; - - // Copy the tensor data over to the `out_buffer`. - let mut out_slice = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?; - (&mut out_slice[..blob_size as usize]).copy_from_slice(blob.buffer()?); - - Ok(blob_size) - } -} - -/// Return the execution target string expected by OpenVINO from the `ExecutionTarget` enum provided -/// by wasi-nn. -fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str { - match target { - ExecutionTarget::Cpu => "CPU", - ExecutionTarget::Gpu => "GPU", - ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"), + Err(UsageError::InvalidGraphHandle.into()) + } } } diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index d17da5fbbb..81a02c139b 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -14,8 +14,7 @@ impl<'a> types::UserErrorConversion for WasiNnCtx { fn nn_errno_from_wasi_nn_error(&mut self, e: WasiNnError) -> Result { eprintln!("Host error: {:?}", e); match e { - WasiNnError::OpenvinoSetupError(_) => unimplemented!(), - WasiNnError::OpenvinoInferenceError(_) => unimplemented!(), + WasiNnError::BackendError(_) => unimplemented!(), WasiNnError::GuestError(_) => unimplemented!(), WasiNnError::UsageError(_) => unimplemented!(), }