wasi-nn: refactor wasi-nn context to use multiple backends
This commit is contained in:
@@ -33,7 +33,7 @@ pub(crate) trait BackendExecutionContext {
|
|||||||
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
|
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
|
||||||
/// for failures interacting with the ML library.
|
/// for failures interacting with the ML library.
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub(crate) enum BackendError {
|
pub enum BackendError {
|
||||||
#[error("Failed while accessing backend")]
|
#[error("Failed while accessing backend")]
|
||||||
BackendAccess(#[from] anyhow::Error),
|
BackendAccess(#[from] anyhow::Error),
|
||||||
#[error("Failed while accessing guest module")]
|
#[error("Failed while accessing guest module")]
|
||||||
|
|||||||
@@ -1,29 +1,68 @@
|
|||||||
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the implementation of the
|
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the
|
||||||
//! wasi-nn API.
|
//! implementation of the wasi-nn API.
|
||||||
|
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
|
||||||
|
use crate::openvino::OpenvinoBackend;
|
||||||
use crate::r#impl::UsageError;
|
use crate::r#impl::UsageError;
|
||||||
use crate::witx::types::{Graph, GraphExecutionContext};
|
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
|
||||||
use openvino::{InferenceError, SetupError};
|
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use wiggle::GuestError;
|
use wiggle::GuestError;
|
||||||
|
|
||||||
|
/// Capture the state necessary for calling into the backend ML libraries.
|
||||||
|
pub struct Ctx {
|
||||||
|
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
|
||||||
|
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
|
||||||
|
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ctx {
|
||||||
|
/// Make a new context from the default state.
|
||||||
|
pub fn new() -> WasiNnResult<Self> {
|
||||||
|
let mut backends = HashMap::new();
|
||||||
|
backends.insert(
|
||||||
|
// This is necessary because Wiggle's variant types do not derive
|
||||||
|
// `Hash` and `Eq`.
|
||||||
|
GraphEncoding::Openvino.into(),
|
||||||
|
Box::new(OpenvinoBackend::default()) as Box<dyn Backend>,
|
||||||
|
);
|
||||||
|
Ok(Self {
|
||||||
|
backends,
|
||||||
|
graphs: Table::default(),
|
||||||
|
executions: Table::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This struct solely wraps [Ctx] in a `RefCell`.
|
||||||
|
pub struct WasiNnCtx {
|
||||||
|
pub(crate) ctx: RefCell<Ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WasiNnCtx {
|
||||||
|
/// Make a new `WasiNnCtx` with the default settings.
|
||||||
|
pub fn new() -> WasiNnResult<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
ctx: RefCell::new(Ctx::new()?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Possible errors while interacting with [WasiNnCtx].
|
/// Possible errors while interacting with [WasiNnCtx].
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum WasiNnError {
|
pub enum WasiNnError {
|
||||||
|
#[error("backend error")]
|
||||||
|
BackendError(#[from] BackendError),
|
||||||
#[error("guest error")]
|
#[error("guest error")]
|
||||||
GuestError(#[from] GuestError),
|
GuestError(#[from] GuestError),
|
||||||
#[error("openvino inference error")]
|
|
||||||
OpenvinoInferenceError(#[from] InferenceError),
|
|
||||||
#[error("openvino setup error")]
|
|
||||||
OpenvinoSetupError(#[from] SetupError),
|
|
||||||
#[error("usage error")]
|
#[error("usage error")]
|
||||||
UsageError(#[from] UsageError),
|
UsageError(#[from] UsageError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
|
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
|
||||||
|
|
||||||
|
/// Record handle entries in a table.
|
||||||
pub struct Table<K, V> {
|
pub struct Table<K, V> {
|
||||||
entries: HashMap<K, V>,
|
entries: HashMap<K, V>,
|
||||||
next_key: u32,
|
next_key: u32,
|
||||||
@@ -48,10 +87,6 @@ where
|
|||||||
key
|
key
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, key: K) -> Option<&V> {
|
|
||||||
self.entries.get(&key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
|
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
|
||||||
self.entries.get_mut(&key)
|
self.entries.get_mut(&key)
|
||||||
}
|
}
|
||||||
@@ -63,51 +98,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ExecutionContext {
|
|
||||||
pub(crate) graph: Graph,
|
|
||||||
pub(crate) request: openvino::InferRequest,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ExecutionContext {
|
|
||||||
pub(crate) fn new(graph: Graph, request: openvino::InferRequest) -> Self {
|
|
||||||
Self { graph, request }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Capture the state necessary for calling into `openvino`.
|
|
||||||
pub struct Ctx {
|
|
||||||
pub(crate) core: Option<openvino::Core>,
|
|
||||||
pub(crate) graphs: Table<Graph, (openvino::CNNNetwork, openvino::ExecutableNetwork)>,
|
|
||||||
pub(crate) executions: Table<GraphExecutionContext, ExecutionContext>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Ctx {
|
|
||||||
/// Make a new `WasiNnCtx` with the default settings.
|
|
||||||
pub fn new() -> WasiNnResult<Self> {
|
|
||||||
Ok(Self {
|
|
||||||
core: Option::default(),
|
|
||||||
graphs: Table::default(),
|
|
||||||
executions: Table::default(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This structure provides the Rust-side context necessary for implementing the wasi-nn API. At the
|
|
||||||
/// moment, it is specialized for a single inference implementation (i.e. OpenVINO) but conceivably
|
|
||||||
/// this could support more than one backing implementation.
|
|
||||||
pub struct WasiNnCtx {
|
|
||||||
pub(crate) ctx: RefCell<Ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WasiNnCtx {
|
|
||||||
/// Make a new `WasiNnCtx` with the default settings.
|
|
||||||
pub fn new() -> WasiNnResult<Self> {
|
|
||||||
Ok(Self {
|
|
||||||
ctx: RefCell::new(Ctx::new()?),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
//! Implements the wasi-nn API.
|
//! Implements the wasi-nn API.
|
||||||
use crate::ctx::{ExecutionContext, WasiNnResult as Result};
|
use crate::ctx::WasiNnResult as Result;
|
||||||
use crate::witx::types::{
|
use crate::witx::types::{
|
||||||
ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor,
|
ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor,
|
||||||
TensorType,
|
|
||||||
};
|
};
|
||||||
use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn;
|
use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn;
|
||||||
use crate::WasiNnCtx;
|
use crate::WasiNnCtx;
|
||||||
use openvino::{Layout, Precision, TensorDesc};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use wiggle::GuestPtr;
|
use wiggle::GuestPtr;
|
||||||
|
|
||||||
@@ -33,162 +31,60 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
|
|||||||
encoding: GraphEncoding,
|
encoding: GraphEncoding,
|
||||||
target: ExecutionTarget,
|
target: ExecutionTarget,
|
||||||
) -> Result<Graph> {
|
) -> Result<Graph> {
|
||||||
if encoding != GraphEncoding::Openvino {
|
let encoding_id: u8 = encoding.into();
|
||||||
|
let graph = if let Some(backend) = self.ctx.borrow_mut().backends.get_mut(&encoding_id) {
|
||||||
|
backend.load(builders, target)?
|
||||||
|
} else {
|
||||||
return Err(UsageError::InvalidEncoding(encoding).into());
|
return Err(UsageError::InvalidEncoding(encoding).into());
|
||||||
|
};
|
||||||
|
let graph_id = self.ctx.borrow_mut().graphs.insert(graph);
|
||||||
|
Ok(graph_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if builders.len() != 2 {
|
fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
|
||||||
return Err(UsageError::InvalidNumberOfBuilders(builders.len()).into());
|
let exec_context = if let Some(graph) = self.ctx.borrow_mut().graphs.get_mut(graph_id) {
|
||||||
}
|
graph.init_execution_context()?
|
||||||
|
|
||||||
// Construct the context if none is present; this is done lazily (i.e. upon actually loading
|
|
||||||
// a model) because it may fail to find and load the OpenVINO libraries. The laziness limits
|
|
||||||
// the extent of the error only to wasi-nn users, not all WASI users.
|
|
||||||
if self.ctx.borrow().core.is_none() {
|
|
||||||
self.ctx
|
|
||||||
.borrow_mut()
|
|
||||||
.core
|
|
||||||
.replace(openvino::Core::new(None)?);
|
|
||||||
}
|
|
||||||
|
|
||||||
let builders = builders.as_ptr();
|
|
||||||
let xml = builders.read()?.as_slice()?;
|
|
||||||
let weights = builders.add(1)?.read()?.as_slice()?;
|
|
||||||
let graph = self
|
|
||||||
.ctx
|
|
||||||
.borrow_mut()
|
|
||||||
.core
|
|
||||||
.as_mut()
|
|
||||||
.ok_or(UsageError::InvalidContext)?
|
|
||||||
.read_network_from_buffer(&xml, &weights)?;
|
|
||||||
let executable_graph = self
|
|
||||||
.ctx
|
|
||||||
.borrow_mut()
|
|
||||||
.core
|
|
||||||
.as_mut()
|
|
||||||
.ok_or(UsageError::InvalidContext)?
|
|
||||||
.load_network(&graph, map_execution_target_to_string(target))?;
|
|
||||||
let id = self
|
|
||||||
.ctx
|
|
||||||
.borrow_mut()
|
|
||||||
.graphs
|
|
||||||
.insert((graph, executable_graph));
|
|
||||||
Ok(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn init_execution_context(&mut self, graph: Graph) -> Result<GraphExecutionContext> {
|
|
||||||
let request =
|
|
||||||
if let Some((_, executable_graph)) = self.ctx.borrow_mut().graphs.get_mut(graph) {
|
|
||||||
executable_graph.create_infer_request()?
|
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidGraphHandle.into());
|
return Err(UsageError::InvalidGraphHandle.into());
|
||||||
};
|
};
|
||||||
|
|
||||||
let execution_context = ExecutionContext::new(graph, request);
|
let exec_context_id = self.ctx.borrow_mut().executions.insert(exec_context);
|
||||||
let handle = self.ctx.borrow_mut().executions.insert(execution_context);
|
Ok(exec_context_id)
|
||||||
Ok(handle)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_input<'b>(
|
fn set_input<'b>(
|
||||||
&mut self,
|
&mut self,
|
||||||
context: GraphExecutionContext,
|
exec_context_id: GraphExecutionContext,
|
||||||
index: u32,
|
index: u32,
|
||||||
tensor: &Tensor<'b>,
|
tensor: &Tensor<'b>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
||||||
execution.graph
|
Ok(exec_context.set_input(index, tensor)?)
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
Err(UsageError::InvalidGraphHandle.into())
|
||||||
};
|
}
|
||||||
|
|
||||||
let input_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) {
|
|
||||||
graph.get_input_name(index as usize)?
|
|
||||||
} else {
|
|
||||||
unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.")
|
|
||||||
};
|
|
||||||
|
|
||||||
// Construct the blob structure.
|
|
||||||
let dimensions = tensor
|
|
||||||
.dimensions
|
|
||||||
.as_slice()?
|
|
||||||
.iter()
|
|
||||||
.map(|d| *d as usize)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let precision = match tensor.type_ {
|
|
||||||
TensorType::F16 => Precision::FP16,
|
|
||||||
TensorType::F32 => Precision::FP32,
|
|
||||||
TensorType::U8 => Precision::U8,
|
|
||||||
TensorType::I32 => Precision::I32,
|
|
||||||
};
|
|
||||||
// TODO There must be some good way to discover the layout here; this should not have to default to NHWC.
|
|
||||||
let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision);
|
|
||||||
let data = tensor.data.as_slice()?;
|
|
||||||
let blob = openvino::Blob::new(desc, &data)?;
|
|
||||||
|
|
||||||
// Actually assign the blob to the request (TODO avoid duplication with the borrow above).
|
|
||||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
|
||||||
execution.request.set_blob(&input_name, blob)?;
|
|
||||||
} else {
|
|
||||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
|
||||||
}
|
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
||||||
|
Ok(exec_context.compute()?)
|
||||||
fn compute(&mut self, context: GraphExecutionContext) -> Result<()> {
|
|
||||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
|
||||||
Ok(execution.request.infer()?)
|
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
Err(UsageError::InvalidExecutionContextHandle.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_output<'b>(
|
fn get_output<'b>(
|
||||||
&mut self,
|
&mut self,
|
||||||
context: GraphExecutionContext,
|
exec_context_id: GraphExecutionContext,
|
||||||
index: u32,
|
index: u32,
|
||||||
out_buffer: &GuestPtr<'_, u8>,
|
out_buffer: &GuestPtr<'_, u8>,
|
||||||
out_buffer_max_size: u32,
|
out_buffer_max_size: u32,
|
||||||
) -> Result<u32> {
|
) -> Result<u32> {
|
||||||
let graph = if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
let mut destination = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?;
|
||||||
execution.graph
|
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
||||||
|
Ok(exec_context.get_output(index, &mut destination)?)
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
Err(UsageError::InvalidGraphHandle.into())
|
||||||
};
|
|
||||||
|
|
||||||
let output_name = if let Some((graph, _)) = self.ctx.borrow().graphs.get(graph) {
|
|
||||||
graph.get_output_name(index as usize)?
|
|
||||||
} else {
|
|
||||||
unreachable!("It should be impossible to attempt to access an execution's graph and for that graph not to exist--this is a bug.")
|
|
||||||
};
|
|
||||||
|
|
||||||
// Retrieve the tensor data.
|
|
||||||
let (mut blob, blob_size) =
|
|
||||||
if let Some(execution) = self.ctx.borrow_mut().executions.get_mut(context) {
|
|
||||||
let mut blob = execution.request.get_blob(&output_name)?; // TODO shouldn't need to be mut
|
|
||||||
let blob_size = blob.byte_len()? as u32;
|
|
||||||
if blob_size > out_buffer_max_size {
|
|
||||||
return Err(UsageError::NotEnoughMemory(blob_size).into());
|
|
||||||
}
|
|
||||||
(blob, blob_size)
|
|
||||||
} else {
|
|
||||||
return Err(UsageError::InvalidExecutionContextHandle.into());
|
|
||||||
};
|
|
||||||
|
|
||||||
// Copy the tensor data over to the `out_buffer`.
|
|
||||||
let mut out_slice = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?;
|
|
||||||
(&mut out_slice[..blob_size as usize]).copy_from_slice(blob.buffer()?);
|
|
||||||
|
|
||||||
Ok(blob_size)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the execution target string expected by OpenVINO from the `ExecutionTarget` enum provided
|
|
||||||
/// by wasi-nn.
|
|
||||||
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str {
|
|
||||||
match target {
|
|
||||||
ExecutionTarget::Cpu => "CPU",
|
|
||||||
ExecutionTarget::Gpu => "GPU",
|
|
||||||
ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ impl<'a> types::UserErrorConversion for WasiNnCtx {
|
|||||||
fn nn_errno_from_wasi_nn_error(&mut self, e: WasiNnError) -> Result<NnErrno, wiggle::Trap> {
|
fn nn_errno_from_wasi_nn_error(&mut self, e: WasiNnError) -> Result<NnErrno, wiggle::Trap> {
|
||||||
eprintln!("Host error: {:?}", e);
|
eprintln!("Host error: {:?}", e);
|
||||||
match e {
|
match e {
|
||||||
WasiNnError::OpenvinoSetupError(_) => unimplemented!(),
|
WasiNnError::BackendError(_) => unimplemented!(),
|
||||||
WasiNnError::OpenvinoInferenceError(_) => unimplemented!(),
|
|
||||||
WasiNnError::GuestError(_) => unimplemented!(),
|
WasiNnError::GuestError(_) => unimplemented!(),
|
||||||
WasiNnError::UsageError(_) => unimplemented!(),
|
WasiNnError::UsageError(_) => unimplemented!(),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user