wasmtime/crates/wasi-nn/src/impl.rs
Andrew Brown 7717d8fa55 wiggle: adapt Wiggle guest slices for unsafe shared use (#5229)
* wiggle: adapt Wiggle guest slices for `unsafe` shared use

When multiple threads can concurrently modify a WebAssembly shared
memory, the underlying data for a Wiggle `GuestSlice` and
`GuestSliceMut` could change due to access from other threads. This
breaks Rust guarantees when `&[T]` and `&mut [T]` slices are handed out.
This change modifies `GuestPtr` to make `as_slice` and `as_slice_mut`
return an `Option` which is `None` when the underlying WebAssembly
memory is shared.
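The following self-contained model roughly illustrates the `Option`-returning idea. `ModelMemory`, its `shared` flag, and its methods are hypothetical stand-ins, not Wiggle's `GuestPtr` API. They only show why a plain Rust slice is withheld when another thread could mutate the bytes underneath it.

```rust
/// Hypothetical model of a guest memory (not Wiggle's real types).
struct ModelMemory {
    bytes: Vec<u8>,
    /// `true` when other threads may mutate these bytes concurrently.
    shared: bool,
}

impl ModelMemory {
    /// Hands out a `&[u8]` only when nothing else can mutate the bytes.
    fn as_slice(&self) -> Option<&[u8]> {
        if self.shared {
            None
        } else {
            Some(self.bytes.as_slice())
        }
    }

    /// Same rule for `&mut [u8]`: `None` for shared memory.
    fn as_slice_mut(&mut self) -> Option<&mut [u8]> {
        if self.shared {
            None
        } else {
            Some(self.bytes.as_mut_slice())
        }
    }
}

fn main() {
    let mut private = ModelMemory { bytes: vec![1, 2, 3], shared: false };
    private.as_slice_mut().expect("memory is not shared")[0] = 9;
    assert_eq!(private.as_slice(), Some(&[9u8, 2, 3][..]));

    let shared = ModelMemory { bytes: vec![0; 3], shared: true };
    assert!(shared.as_slice().is_none());
}
```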

But WASI implementations still need access to the underlying WebAssembly
memory, both to read from it and to write to it. This change adds new APIs:
- `GuestPtr::to_vec` copies the bytes out of WebAssembly memory into an
  owned buffer, from which we can safely take a `&[T]`
- `GuestPtr::as_unsafe_slice_mut` returns a wrapper `struct` from which
  we can `unsafe`-ly obtain a mutable slice (users must accept the
  unsafety of a `&mut [T]` whose memory may be modified concurrently)
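The two access paths described above can be sketched as follows. `UnsafeSliceModel` is a hypothetical type, loosely modelled on the commit's `UnsafeGuestSlice` but not its actual definition. It records where the bytes live without holding a `&mut [u8]`. Copying goes through `to_vec`, and a mutable slice is obtained only through an `unsafe` method that consumes the wrapper. The example is single-threaded, so it does not show how real shared memory should be handled.

```rust
use std::marker::PhantomData;

/// Hypothetical wrapper (loosely modelled on `UnsafeGuestSlice`): it knows
/// where the guest bytes are but does not itself hold a Rust reference.
struct UnsafeSliceModel<'a> {
    ptr: *mut u8,
    len: usize,
    _borrow: PhantomData<&'a mut [u8]>,
}

impl<'a> UnsafeSliceModel<'a> {
    fn new(buf: &'a mut [u8]) -> Self {
        Self { ptr: buf.as_mut_ptr(), len: buf.len(), _borrow: PhantomData }
    }

    /// Copies the bytes into an owned `Vec`, which the caller can then
    /// borrow as an ordinary `&[u8]`.
    fn to_vec(&self) -> Vec<u8> {
        // SAFETY: `ptr`/`len` come from a live `&'a mut [u8]`, and this model
        // is single-threaded; real shared memory would need more care here.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }.to_vec()
    }

    /// Consumes the wrapper and yields a mutable slice.
    ///
    /// # Safety
    /// The caller must guarantee that nothing else (including other threads)
    /// accesses the region while the returned slice is alive.
    unsafe fn into_slice_mut(self) -> &'a mut [u8] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}

fn main() {
    let mut memory = vec![10u8, 20, 30];
    let view = UnsafeSliceModel::new(&mut memory[..]);
    assert_eq!(view.to_vec(), vec![10u8, 20, 30]);
    // SAFETY: single-threaded demo; nothing else touches `memory` meanwhile.
    let slice = unsafe { view.into_slice_mut() };
    slice[1] = 99;
    assert_eq!(memory, vec![10u8, 99, 30]);
}
```

Making the conversion consume the wrapper mirrors the commit's note about consuming `UnsafeGuestSlice` during conversion to the safe versions. Once a slice has been produced, the wrapper cannot be reused to produce a second, aliasing one.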

This approach allows us to keep Wiggle's borrow-checking
infrastructure, which enforces guarantees such as Wiggle never modifying
overlapping regions. This matters because the underlying system calls
may rely on it. Other threads may still modify the same underlying
region, and that cannot be prevented here, but at least Wiggle itself
will not do so.
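The overlap guarantee can be pictured with a simplified tracker that refuses a new mutable borrow whose byte range intersects one already held. `BorrowTracker` is a hypothetical sketch that tracks only mutable borrows of `Range<usize>` regions. It is not Wiggle's actual borrow-checker, which also distinguishes shared borrows.

```rust
use std::ops::Range;

#[derive(Debug, PartialEq)]
struct BorrowError;

/// Hypothetical, simplified tracker of outstanding mutable borrows.
#[derive(Default)]
struct BorrowTracker {
    mut_borrows: Vec<Range<usize>>,
}

impl BorrowTracker {
    /// Two half-open ranges overlap if each starts before the other ends.
    fn overlaps(a: &Range<usize>, b: &Range<usize>) -> bool {
        a.start < b.end && b.start < a.end
    }

    /// Records a mutable borrow, or fails if it overlaps a held one.
    fn borrow_mut(&mut self, r: Range<usize>) -> Result<(), BorrowError> {
        if self.mut_borrows.iter().any(|held| Self::overlaps(held, &r)) {
            return Err(BorrowError);
        }
        self.mut_borrows.push(r);
        Ok(())
    }

    /// Drops a previously recorded borrow.
    fn release(&mut self, r: &Range<usize>) {
        if let Some(i) = self.mut_borrows.iter().position(|held| held == r) {
            self.mut_borrows.swap_remove(i);
        }
    }
}

fn main() {
    let mut t = BorrowTracker::default();
    assert!(t.borrow_mut(0..8).is_ok());
    assert_eq!(t.borrow_mut(4..12), Err(BorrowError)); // overlaps 0..8
    assert!(t.borrow_mut(8..16).is_ok()); // adjacent, not overlapping
    t.release(&(0..8));
    assert!(t.borrow_mut(4..12).is_err()); // still overlaps 8..16
    assert!(t.borrow_mut(0..8).is_ok());
}
```

This kind of tracker only constrains borrows made through Wiggle. It cannot stop another thread from writing to the same shared memory, which is the limitation the paragraph acknowledges.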

Finally, the changes to Wiggle's API are propagated to all WASI
implementations in Wasmtime. For now, code locations that attempt to get
a guest slice will panic if the underlying memory is shared. Note that
Wiggle is not enabled for shared memory (that will come later in
something like #5054), but when it is, these panics will be clear
indicators of locations that must be re-implemented in a thread-safe
way.
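The panicking call sites mentioned above, including `get_output` in the file below, share one shape. An outer `Result` carries ordinary guest errors such as out-of-bounds access and is propagated with `?`. An `expect` then unwraps the inner `Option`, which is `None` for shared memory. The sketch below reproduces only that shape. `slice_mut_model`, `fill_output`, and `GuestErrorModel` are hypothetical names, not Wiggle's API.

```rust
/// Hypothetical stand-in for a guest error such as an out-of-bounds access.
#[derive(Debug)]
struct GuestErrorModel;

/// Hypothetical accessor: the outer `Result` reports bounds errors, the inner
/// `Option` is `None` when the memory is shared.
fn slice_mut_model(
    buf: &mut [u8],
    len: usize,
    shared: bool,
) -> Result<Option<&mut [u8]>, GuestErrorModel> {
    let region = buf.get_mut(..len).ok_or(GuestErrorModel)?;
    Ok(if shared { None } else { Some(region) })
}

/// Call site in the style described above: `?` for errors, `expect` (panic)
/// for the not-yet-supported shared-memory case.
fn fill_output(buf: &mut [u8], len: usize, shared: bool) -> Result<usize, GuestErrorModel> {
    let dest = slice_mut_model(buf, len, shared)?
        .expect("cannot use with shared memories (placeholder message)");
    dest.fill(7);
    Ok(dest.len())
}

fn main() {
    let mut buf = [0u8; 4];
    assert_eq!(fill_output(&mut buf, 3, false).unwrap(), 3);
    assert_eq!(buf, [7u8, 7, 7, 0]);
    // An out-of-bounds length is reported as an error, not a panic.
    assert!(fill_output(&mut buf, 8, false).is_err());
    // Shared memory hits the `expect` and panics.
    let result = std::panic::catch_unwind(|| {
        let mut b = [0u8; 4];
        fill_output(&mut b, 2, true)
    });
    assert!(result.is_err());
}
```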

* review: remove double cast

* review: refactor to include more logic in 'UnsafeGuestSlice'

* review: add reference to #4203

* review: link all thread-safe WASI fixups to #5235

* fix: consume 'UnsafeGuestSlice' during conversion to safe versions

* review: remove 'as_slice' and 'as_slice_mut'

* review: use 'as_unsafe_slice_mut' in 'to_vec'

* review: add `UnsafeBorrowResult`
2022-11-10 21:54:52 +00:00


//! Implements the wasi-nn API.
use crate::ctx::WasiNnResult as Result;
use crate::witx::types::{
    ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor,
};
use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn;
use crate::WasiNnCtx;
use thiserror::Error;
use wiggle::GuestPtr;

#[derive(Debug, Error)]
pub enum UsageError {
    #[error("Invalid context; has the load function been called?")]
    InvalidContext,
    #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
    InvalidEncoding(GraphEncoding),
    #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
    InvalidNumberOfBuilders(u32),
    #[error("Invalid graph handle; has it been loaded?")]
    InvalidGraphHandle,
    #[error("Invalid execution context handle; has it been initialized?")]
    InvalidExecutionContextHandle,
    #[error("Not enough memory to copy tensor data of size: {0}")]
    NotEnoughMemory(u32),
}

impl<'a> WasiEphemeralNn for WasiNnCtx {
    fn load<'b>(
        &mut self,
        builders: &GraphBuilderArray<'_>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> Result<Graph> {
        let encoding_id: u8 = encoding.into();
        let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
            backend.load(builders, target)?
        } else {
            return Err(UsageError::InvalidEncoding(encoding).into());
        };
        let graph_id = self.graphs.insert(graph);
        Ok(graph_id)
    }

    fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
        let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
            graph.init_execution_context()?
        } else {
            return Err(UsageError::InvalidGraphHandle.into());
        };
        let exec_context_id = self.executions.insert(exec_context);
        Ok(exec_context_id)
    }

    fn set_input<'b>(
        &mut self,
        exec_context_id: GraphExecutionContext,
        index: u32,
        tensor: &Tensor<'b>,
    ) -> Result<()> {
        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
            Ok(exec_context.set_input(index, tensor)?)
        } else {
            // The lookup is by execution-context handle, not graph handle.
            Err(UsageError::InvalidExecutionContextHandle.into())
        }
    }

    fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
            Ok(exec_context.compute()?)
        } else {
            Err(UsageError::InvalidExecutionContextHandle.into())
        }
    }

    fn get_output<'b>(
        &mut self,
        exec_context_id: GraphExecutionContext,
        index: u32,
        out_buffer: &GuestPtr<'_, u8>,
        out_buffer_max_size: u32,
    ) -> Result<u32> {
        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
            // `?` propagates guest errors; `expect` panics on shared memory.
            let mut destination = out_buffer
                .as_array(out_buffer_max_size)
                .as_slice_mut()?
                .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)");
            Ok(exec_context.get_output(index, &mut destination)?)
        } else {
            // The lookup is by execution-context handle, not graph handle.
            Err(UsageError::InvalidExecutionContextHandle.into())
        }
    }
}
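
The `if let ... else` handle lookups above could also be written with `Option::ok_or` and `?`. The sketch below shows that style in a self-contained form. It replaces the witx types, the handle tables, and Wiggle's `GuestPtr` with hypothetical stand-ins: `CtxModel`, `ExecContextModel`, `UsageErrorModel`, and a plain `HashMap<u32, _>`. It also assumes, purely for illustration, that a backend reports a too-small output buffer with `NotEnoughMemory` carrying the tensor size. The real backends may behave differently.

```rust
use std::collections::HashMap;
use std::fmt;

/// Hypothetical subset of `UsageError` (the real one derives `thiserror::Error`).
#[derive(Debug, PartialEq)]
enum UsageErrorModel {
    InvalidExecutionContextHandle,
    NotEnoughMemory(u32),
}

impl fmt::Display for UsageErrorModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidExecutionContextHandle => {
                write!(f, "Invalid execution context handle; has it been initialized?")
            }
            Self::NotEnoughMemory(size) => {
                write!(f, "Not enough memory to copy tensor data of size: {size}")
            }
        }
    }
}

/// Hypothetical execution context holding one already-computed output tensor.
struct ExecContextModel {
    output: Vec<u8>,
}

impl ExecContextModel {
    /// Copies the output into `destination`, returning the number of bytes written.
    fn get_output(&self, destination: &mut [u8]) -> Result<u32, UsageErrorModel> {
        let size = self.output.len();
        if size > destination.len() {
            return Err(UsageErrorModel::NotEnoughMemory(size as u32));
        }
        destination[..size].copy_from_slice(&self.output);
        Ok(size as u32)
    }
}

/// Hypothetical context with a handle table keyed by `u32`.
struct CtxModel {
    executions: HashMap<u32, ExecContextModel>,
}

impl CtxModel {
    fn get_output(&mut self, id: u32, destination: &mut [u8]) -> Result<u32, UsageErrorModel> {
        // `ok_or` turns a missing handle into the matching error; `?` returns it.
        let exec = self
            .executions
            .get(&id)
            .ok_or(UsageErrorModel::InvalidExecutionContextHandle)?;
        exec.get_output(destination)
    }
}

fn main() {
    let mut ctx = CtxModel { executions: HashMap::new() };
    ctx.executions.insert(0, ExecContextModel { output: vec![1, 2, 3] });

    let mut big = [0u8; 8];
    assert_eq!(ctx.get_output(0, &mut big), Ok(3));
    assert_eq!(&big[..3], &[1u8, 2, 3]);

    let mut small = [0u8; 2];
    assert_eq!(ctx.get_output(0, &mut small), Err(UsageErrorModel::NotEnoughMemory(3)));
    assert_eq!(
        ctx.get_output(42, &mut big),
        Err(UsageErrorModel::InvalidExecutionContextHandle)
    );
}
```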