Make WASI-NN classes send and/or sync (#5077)
* Make send and remove wrapper around WasiNnCtx· This removes the wrapper around WasiNnCtx and no longer requires borrow_mut(). Once send/sync changes in OpenVINO crate are merged in it will allow·use by frameworks that requires this trait. * Bump openvino to compatible version. * BackendExecutionContext should be Send and Sync * Fix rust format issues. * Update Cargo.lock for openvino * Audit changes to openvino crates.
This commit is contained in:
committed by
GitHub
parent
2702619427
commit
f082756643
12
Cargo.lock
generated
12
Cargo.lock
generated
@@ -1949,9 +1949,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openvino"
|
name = "openvino"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d9627908ea4af5766040aa191c8607479af7f70b45fdf6e999b450069fea851a"
|
checksum = "c7336c11cad0eb45f65436cdbf073c697397a1bfe53836cef997129d69443c77"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"openvino-sys",
|
"openvino-sys",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
@@ -1959,9 +1959,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openvino-finder"
|
name = "openvino-finder"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "213893e484dcf3db4af79d498a955f7c4c209d06e7020779cda68fca779c2578"
|
checksum = "c650edf39ea54dfbe18f0ad513858ff0bed3f6a308b677e0d5f71b330f476ccf"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"log",
|
"log",
|
||||||
@@ -1969,9 +1969,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openvino-sys"
|
name = "openvino-sys"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e2ba37c26ad2591acc48abee5350d65daa263bf0ab7a79d2ab6999d4b20130ec"
|
checksum = "6d003d61f18f7bf6dd965b4e913cbd3e7cda6a3c179115c8ee59e5c29b390f45"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libloading",
|
"libloading",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ anyhow = { workspace = true }
|
|||||||
wiggle = { workspace = true }
|
wiggle = { workspace = true }
|
||||||
|
|
||||||
# These dependencies are necessary for the wasi-nn implementation:
|
# These dependencies are necessary for the wasi-nn implementation:
|
||||||
openvino = { version = "0.4.1", features = ["runtime-linking"] }
|
openvino = { version = "0.4.2", features = ["runtime-linking"] }
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use thiserror::Error;
|
|||||||
use wiggle::GuestError;
|
use wiggle::GuestError;
|
||||||
|
|
||||||
/// A [Backend] contains the necessary state to load [BackendGraph]s.
|
/// A [Backend] contains the necessary state to load [BackendGraph]s.
|
||||||
pub(crate) trait Backend {
|
pub(crate) trait Backend: Send {
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
fn load(
|
fn load(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -18,13 +18,13 @@ pub(crate) trait Backend {
|
|||||||
|
|
||||||
/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
|
/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
|
||||||
/// implementation for a [crate::witx::types::Graph].
|
/// implementation for a [crate::witx::types::Graph].
|
||||||
pub(crate) trait BackendGraph {
|
pub(crate) trait BackendGraph: Send {
|
||||||
fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
|
fn init_execution_context(&mut self) -> Result<Box<dyn BackendExecutionContext>, BackendError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A [BackendExecutionContext] performs the actual inference; this is the
|
/// A [BackendExecutionContext] performs the actual inference; this is the
|
||||||
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
|
/// backing implementation for a [crate::witx::types::GraphExecutionContext].
|
||||||
pub(crate) trait BackendExecutionContext {
|
pub(crate) trait BackendExecutionContext: Send + Sync {
|
||||||
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
|
fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>;
|
||||||
fn compute(&mut self) -> Result<(), BackendError>;
|
fn compute(&mut self) -> Result<(), BackendError>;
|
||||||
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
|
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
|
||||||
|
|||||||
@@ -4,20 +4,19 @@ use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph};
|
|||||||
use crate::openvino::OpenvinoBackend;
|
use crate::openvino::OpenvinoBackend;
|
||||||
use crate::r#impl::UsageError;
|
use crate::r#impl::UsageError;
|
||||||
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
|
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use wiggle::GuestError;
|
use wiggle::GuestError;
|
||||||
|
|
||||||
/// Capture the state necessary for calling into the backend ML libraries.
|
/// Capture the state necessary for calling into the backend ML libraries.
|
||||||
pub struct Ctx {
|
pub struct WasiNnCtx {
|
||||||
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
|
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
|
||||||
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
|
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
|
||||||
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
|
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Ctx {
|
impl WasiNnCtx {
|
||||||
/// Make a new context from the default state.
|
/// Make a new context from the default state.
|
||||||
pub fn new() -> WasiNnResult<Self> {
|
pub fn new() -> WasiNnResult<Self> {
|
||||||
let mut backends = HashMap::new();
|
let mut backends = HashMap::new();
|
||||||
@@ -35,20 +34,6 @@ impl Ctx {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This struct solely wraps [Ctx] in a `RefCell`.
|
|
||||||
pub struct WasiNnCtx {
|
|
||||||
pub(crate) ctx: RefCell<Ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WasiNnCtx {
|
|
||||||
/// Make a new `WasiNnCtx` with the default settings.
|
|
||||||
pub fn new() -> WasiNnResult<Self> {
|
|
||||||
Ok(Self {
|
|
||||||
ctx: RefCell::new(Ctx::new()?),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Possible errors while interacting with [WasiNnCtx].
|
/// Possible errors while interacting with [WasiNnCtx].
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum WasiNnError {
|
pub enum WasiNnError {
|
||||||
|
|||||||
@@ -32,23 +32,23 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
|
|||||||
target: ExecutionTarget,
|
target: ExecutionTarget,
|
||||||
) -> Result<Graph> {
|
) -> Result<Graph> {
|
||||||
let encoding_id: u8 = encoding.into();
|
let encoding_id: u8 = encoding.into();
|
||||||
let graph = if let Some(backend) = self.ctx.borrow_mut().backends.get_mut(&encoding_id) {
|
let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
|
||||||
backend.load(builders, target)?
|
backend.load(builders, target)?
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidEncoding(encoding).into());
|
return Err(UsageError::InvalidEncoding(encoding).into());
|
||||||
};
|
};
|
||||||
let graph_id = self.ctx.borrow_mut().graphs.insert(graph);
|
let graph_id = self.graphs.insert(graph);
|
||||||
Ok(graph_id)
|
Ok(graph_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
|
fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
|
||||||
let exec_context = if let Some(graph) = self.ctx.borrow_mut().graphs.get_mut(graph_id) {
|
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
|
||||||
graph.init_execution_context()?
|
graph.init_execution_context()?
|
||||||
} else {
|
} else {
|
||||||
return Err(UsageError::InvalidGraphHandle.into());
|
return Err(UsageError::InvalidGraphHandle.into());
|
||||||
};
|
};
|
||||||
|
|
||||||
let exec_context_id = self.ctx.borrow_mut().executions.insert(exec_context);
|
let exec_context_id = self.executions.insert(exec_context);
|
||||||
Ok(exec_context_id)
|
Ok(exec_context_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
|
|||||||
index: u32,
|
index: u32,
|
||||||
tensor: &Tensor<'b>,
|
tensor: &Tensor<'b>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
|
||||||
Ok(exec_context.set_input(index, tensor)?)
|
Ok(exec_context.set_input(index, tensor)?)
|
||||||
} else {
|
} else {
|
||||||
Err(UsageError::InvalidGraphHandle.into())
|
Err(UsageError::InvalidGraphHandle.into())
|
||||||
@@ -66,7 +66,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
|
fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
|
||||||
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
|
||||||
Ok(exec_context.compute()?)
|
Ok(exec_context.compute()?)
|
||||||
} else {
|
} else {
|
||||||
Err(UsageError::InvalidExecutionContextHandle.into())
|
Err(UsageError::InvalidExecutionContextHandle.into())
|
||||||
@@ -81,7 +81,7 @@ impl<'a> WasiEphemeralNn for WasiNnCtx {
|
|||||||
out_buffer_max_size: u32,
|
out_buffer_max_size: u32,
|
||||||
) -> Result<u32> {
|
) -> Result<u32> {
|
||||||
let mut destination = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?;
|
let mut destination = out_buffer.as_array(out_buffer_max_size).as_slice_mut()?;
|
||||||
if let Some(exec_context) = self.ctx.borrow_mut().executions.get_mut(exec_context_id) {
|
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
|
||||||
Ok(exec_context.get_output(index, &mut destination)?)
|
Ok(exec_context.get_output(index, &mut destination)?)
|
||||||
} else {
|
} else {
|
||||||
Err(UsageError::InvalidGraphHandle.into())
|
Err(UsageError::InvalidGraphHandle.into())
|
||||||
|
|||||||
@@ -241,6 +241,30 @@ Contains unsafe blocks but are encapsulated and required for the operation at
|
|||||||
hand.
|
hand.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
[[audits.openvino]]
|
||||||
|
who = "Matthew Tamayo-Rios <matthew@geekbeast.com>"
|
||||||
|
criteria = "safe-to-deploy"
|
||||||
|
version = "0.4.2"
|
||||||
|
notes = """
|
||||||
|
I am the author of most of these changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
[[audits.openvino-finder]]
|
||||||
|
who = "Matthew Tamayo-Rios <matthew@geekbeast.com>"
|
||||||
|
criteria = "safe-to-deploy"
|
||||||
|
delta = "0.4.1 -> 0.4.2"
|
||||||
|
notes = """
|
||||||
|
Only updates to Cargo file for versioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
[[audits.openvino-sys]]
|
||||||
|
who = "Matthew Tamayo-Rios <matthew@geekbeast.com>"
|
||||||
|
criteria = "safe-to-deploy"
|
||||||
|
delta = "0.4.1 -> 0.4.2"
|
||||||
|
notes = """
|
||||||
|
Only updates to tests to use new rust functions for mut pointers.
|
||||||
|
"""
|
||||||
|
|
||||||
[[audits.memory_units]]
|
[[audits.memory_units]]
|
||||||
who = "Alex Crichton <alex@alexcrichton.com>"
|
who = "Alex Crichton <alex@alexcrichton.com>"
|
||||||
criteria = "safe-to-run"
|
criteria = "safe-to-run"
|
||||||
|
|||||||
@@ -550,10 +550,6 @@ criteria = "safe-to-run"
|
|||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
criteria = "safe-to-deploy"
|
criteria = "safe-to-deploy"
|
||||||
|
|
||||||
[[exemptions.openvino]]
|
|
||||||
version = "0.4.1"
|
|
||||||
criteria = "safe-to-deploy"
|
|
||||||
|
|
||||||
[[exemptions.openvino-finder]]
|
[[exemptions.openvino-finder]]
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
criteria = "safe-to-deploy"
|
criteria = "safe-to-deploy"
|
||||||
|
|||||||
Reference in New Issue
Block a user