wasi-nn: refactor wasi-nn context to use multiple backends
This commit is contained in:
@@ -1,29 +1,68 @@
|
||||
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the implementation of the
|
||||
//! wasi-nn API.
|
||||
//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the
|
||||
//! implementation of the wasi-nn API.
|
||||
use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
|
||||
use crate::openvino::OpenvinoBackend;
|
||||
use crate::r#impl::UsageError;
|
||||
use crate::witx::types::{Graph, GraphExecutionContext};
|
||||
use openvino::{InferenceError, SetupError};
|
||||
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use thiserror::Error;
|
||||
use wiggle::GuestError;
|
||||
|
||||
/// Capture the state necessary for calling into the backend ML libraries.
|
||||
pub struct Ctx {
|
||||
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
|
||||
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
|
||||
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
|
||||
}
|
||||
|
||||
impl Ctx {
|
||||
/// Make a new context from the default state.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
let mut backends = HashMap::new();
|
||||
backends.insert(
|
||||
// This is necessary because Wiggle's variant types do not derive
|
||||
// `Hash` and `Eq`.
|
||||
GraphEncoding::Openvino.into(),
|
||||
Box::new(OpenvinoBackend::default()) as Box<dyn Backend>,
|
||||
);
|
||||
Ok(Self {
|
||||
backends,
|
||||
graphs: Table::default(),
|
||||
executions: Table::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct solely wraps [Ctx] in a `RefCell`.
|
||||
pub struct WasiNnCtx {
|
||||
pub(crate) ctx: RefCell<Ctx>,
|
||||
}
|
||||
|
||||
impl WasiNnCtx {
|
||||
/// Make a new `WasiNnCtx` with the default settings.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
Ok(Self {
|
||||
ctx: RefCell::new(Ctx::new()?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Possible errors while interacting with [WasiNnCtx].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WasiNnError {
|
||||
#[error("backend error")]
|
||||
BackendError(#[from] BackendError),
|
||||
#[error("guest error")]
|
||||
GuestError(#[from] GuestError),
|
||||
#[error("openvino inference error")]
|
||||
OpenvinoInferenceError(#[from] InferenceError),
|
||||
#[error("openvino setup error")]
|
||||
OpenvinoSetupError(#[from] SetupError),
|
||||
#[error("usage error")]
|
||||
UsageError(#[from] UsageError),
|
||||
}
|
||||
|
||||
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
|
||||
|
||||
/// Record handle entries in a table.
|
||||
pub struct Table<K, V> {
|
||||
entries: HashMap<K, V>,
|
||||
next_key: u32,
|
||||
@@ -48,10 +87,6 @@ where
|
||||
key
|
||||
}
|
||||
|
||||
pub fn get(&self, key: K) -> Option<&V> {
|
||||
self.entries.get(&key)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
|
||||
self.entries.get_mut(&key)
|
||||
}
|
||||
@@ -63,51 +98,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ExecutionContext {
|
||||
pub(crate) graph: Graph,
|
||||
pub(crate) request: openvino::InferRequest,
|
||||
}
|
||||
|
||||
impl ExecutionContext {
|
||||
pub(crate) fn new(graph: Graph, request: openvino::InferRequest) -> Self {
|
||||
Self { graph, request }
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture the state necessary for calling into `openvino`.
|
||||
pub struct Ctx {
|
||||
pub(crate) core: Option<openvino::Core>,
|
||||
pub(crate) graphs: Table<Graph, (openvino::CNNNetwork, openvino::ExecutableNetwork)>,
|
||||
pub(crate) executions: Table<GraphExecutionContext, ExecutionContext>,
|
||||
}
|
||||
|
||||
impl Ctx {
|
||||
/// Make a new `WasiNnCtx` with the default settings.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
Ok(Self {
|
||||
core: Option::default(),
|
||||
graphs: Table::default(),
|
||||
executions: Table::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// This structure provides the Rust-side context necessary for implementing the wasi-nn API. At the
|
||||
/// moment, it is specialized for a single inference implementation (i.e. OpenVINO) but conceivably
|
||||
/// this could support more than one backing implementation.
|
||||
pub struct WasiNnCtx {
|
||||
pub(crate) ctx: RefCell<Ctx>,
|
||||
}
|
||||
|
||||
impl WasiNnCtx {
|
||||
/// Make a new `WasiNnCtx` with the default settings.
|
||||
pub fn new() -> WasiNnResult<Self> {
|
||||
Ok(Self {
|
||||
ctx: RefCell::new(Ctx::new()?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user