//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the //! implementation of the wasi-nn API. use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; use crate::openvino::OpenvinoBackend; use crate::r#impl::UsageError; use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext}; use std::cell::RefCell; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; use wiggle::GuestError; /// Capture the state necessary for calling into the backend ML libraries. pub struct Ctx { pub(crate) backends: HashMap>, pub(crate) graphs: Table>, pub(crate) executions: Table>, } impl Ctx { /// Make a new context from the default state. pub fn new() -> WasiNnResult { let mut backends = HashMap::new(); backends.insert( // This is necessary because Wiggle's variant types do not derive // `Hash` and `Eq`. GraphEncoding::Openvino.into(), Box::new(OpenvinoBackend::default()) as Box, ); Ok(Self { backends, graphs: Table::default(), executions: Table::default(), }) } } /// This struct solely wraps [Ctx] in a `RefCell`. pub struct WasiNnCtx { pub(crate) ctx: RefCell, } impl WasiNnCtx { /// Make a new `WasiNnCtx` with the default settings. pub fn new() -> WasiNnResult { Ok(Self { ctx: RefCell::new(Ctx::new()?), }) } } /// Possible errors while interacting with [WasiNnCtx]. #[derive(Debug, Error)] pub enum WasiNnError { #[error("backend error")] BackendError(#[from] BackendError), #[error("guest error")] GuestError(#[from] GuestError), #[error("usage error")] UsageError(#[from] UsageError), } pub(crate) type WasiNnResult = std::result::Result; /// Record handle entries in a table. 
///
/// Keys of type `K` are built from a `u32` counter that starts at `0`; the
/// type exposes no way to remove an entry once it has been inserted.
pub struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

// Implemented by hand: `#[derive(Default)]` would add `K: Default` and
// `V: Default` bounds that an empty table does not need.
impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Self {
            entries: HashMap::new(),
            next_key: 0,
        }
    }
}

impl<K, V> Table<K, V>
where
    K: Eq + Hash + From<u32> + Copy,
{
    /// Store `value` under a freshly allocated key and return that key.
    ///
    /// # Panics
    ///
    /// Panics if every `u32` key is already in use.
    pub fn insert(&mut self, value: V) -> K {
        let key = self.use_next_key();
        self.entries.insert(key, value);
        key
    }

    /// Return a mutable reference to the entry stored under `key`, or `None`
    /// if no such entry exists.
    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }

    /// Allocate a key that no current entry uses.
    ///
    /// The counter wraps around rather than overflowing, and keys that are
    /// already occupied are skipped, so an existing entry is never replaced.
    ///
    /// # Panics
    ///
    /// Panics if every `u32` key is already in use.
    fn use_next_key(&mut self) -> K {
        // Try each of the 2^32 counter values at most once.
        for _ in 0..=u32::MAX {
            let candidate = K::from(self.next_key);
            self.next_key = self.next_key.wrapping_add(1);
            if !self.entries.contains_key(&candidate) {
                return candidate;
            }
        }
        panic!("table is full: every u32 key is already in use");
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn instantiate() {
        WasiNnCtx::new().unwrap();
    }

    #[test]
    fn table_key_wraparound_skips_occupied_keys() {
        let mut table: Table<u32, &str> = Table::default();
        assert_eq!(table.insert("first"), 0);
        table.next_key = u32::MAX;
        assert_eq!(table.insert("last"), u32::MAX);
        assert_eq!(table.insert("wrapped"), 1);
        assert_eq!(table.get_mut(0).copied(), Some("first"));
    }
}