Make WASI-NN classes send and/or sync (#5077)

* Make send and remove wrapper around WasiNnCtx

This removes the wrapper around WasiNnCtx and no longer requires borrow_mut(). Once the send/sync
changes in the OpenVINO crate are merged in, it will allow use by frameworks that require this trait.

* Bump openvino to compatible version.

* BackendExecutionContext should be Send and Sync

* Fix rust format issues.

* Update Cargo.lock for openvino

* Audit changes to openvino crates.
This commit is contained in:
Matthew Tamayo-Rios
2022-10-28 00:52:23 +02:00
committed by GitHub
parent 2702619427
commit f082756643
7 changed files with 43 additions and 38 deletions

View File

@@ -17,7 +17,7 @@ anyhow = { workspace = true }
wiggle = { workspace = true }
# These dependencies are necessary for the wasi-nn implementation:
openvino = { version = "0.4.1", features = ["runtime-linking"] }
openvino = { version = "0.4.2", features = ["runtime-linking"] }
thiserror = "1.0"
[build-dependencies]

View File

@@ -7,7 +7,7 @@ use thiserror::Error;
use wiggle::GuestError;
/// A [Backend] contains the necessary state to load [BackendGraph]s.
pub(crate) trait Backend {
pub(crate) trait Backend: Send {
fn name(&self) -> &str;
fn load(
&mut self,
@@ -18,13 +18,13 @@ pub(crate) trait Backend {
/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for a [crate::witx::types::Graph].
pub(crate) trait BackendGraph {
pub(crate) trait BackendGraph: Send {
fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
}
/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
pub(crate) trait BackendExecutionContext {
pub(crate) trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>;
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;

View File

@@ -4,20 +4,19 @@ use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use crate::openvino::OpenvinoBackend;
use crate::r#impl::UsageError;
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
use std::cell::RefCell;
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
use wiggle::GuestError;
/// Capture the state necessary for calling into the backend ML libraries.
pub struct Ctx {
pub struct WasiNnCtx {
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
}
impl Ctx {
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new() -> WasiNnResult<Self> {
let mut backends = HashMap::new();
@@ -35,20 +34,6 @@ impl Ctx {
}
}
/// This struct solely wraps [Ctx] in a `RefCell`.
pub struct WasiNnCtx {
pub(crate) ctx: RefCell<Ctx>,
}
impl WasiNnCtx {
/// Make a new `WasiNnCtx` with the default settings.
pub fn new() -> WasiNnResult<Self> {
Ok(Self {
ctx: RefCell::new(Ctx::new()?),
})
}
}
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {

View File

@@ -32,23 +32,23 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
target: ExecutionTarget,
) -> Result<Graph> {
let encoding_id: u8 = encoding.into();
let graph = if let Some(backend) = self.ctx.borrow_mut().backends.get_mut(&encoding_id) {
let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
backend.load(builders, target)?
} else {
return Err(UsageError::InvalidEncoding(encoding).into());
};
let graph_id = self.ctx.borrow_mut().graphs.insert(graph);
let graph_id = self.graphs.insert(graph);
Ok(graph_id)
}
fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
let exec_context = if let Some(graph) = self.ctx.borrow_mut().graphs.get_mut(graph_id) {
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
graph.init_execution_context()?
} else {
return Err(UsageError::InvalidGraphHandle.into());
};
let exec_context_id = self.ctx.borrow_mut().executions.insert(exec_context);
let exec_context_id = self.executions.insert(exec_context);
Ok(exec_context_id)
}
@@ -58,7 +58,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
index: u32,
tensor: &Tensor<'b>,
) -> Result<()> {
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
Ok(exec_context.set_input(index, tensor)?)
} else {
Err(UsageError::InvalidGraphHandle.into())
@@ -66,7 +66,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
}
fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
Ok(exec_context.compute()?)
} else {
Err(UsageError::InvalidExecutionContextHandle.into())
@@ -81,7 +81,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
out_buffer_max_size: u32,
) -> Result<u32> {
let mut destination = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?;
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
Ok(exec_context.get_output(index, &mut destination)?)
} else {
Err(UsageError::InvalidGraphHandle.into())