diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index ebf6c8aa7f..73e8cd379a 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -330,6 +330,23 @@ impl Default for InstanceAllocationStrategy { } } +#[derive(Clone)] +/// Configure the strategy used for versioning in serializing and deserializing [`crate::Module`]. +pub enum ModuleVersionStrategy { + /// Use the wasmtime crate's Cargo package version. + WasmtimeVersion, + /// Use a custom version string. Must be at most 255 bytes. + Custom(String), + /// Emit no version string in serialization, and accept all version strings in deserialization. + None, +} + +impl Default for ModuleVersionStrategy { + fn default() -> Self { + ModuleVersionStrategy::WasmtimeVersion + } +} + /// Global configuration options used to create an [`Engine`](crate::Engine) /// and customize its behavior. /// @@ -350,7 +367,7 @@ pub struct Config { #[cfg(feature = "async")] pub(crate) async_stack_size: usize, pub(crate) async_support: bool, - pub(crate) deserialize_check_wasmtime_version: bool, + pub(crate) module_version: ModuleVersionStrategy, pub(crate) parallel_compilation: bool, pub(crate) paged_memory_initialization: bool, } @@ -374,7 +391,7 @@ impl Config { #[cfg(feature = "async")] async_stack_size: 2 << 20, async_support: false, - deserialize_check_wasmtime_version: true, + module_version: ModuleVersionStrategy::default(), parallel_compilation: true, // Default to paged memory initialization when using uffd on linux paged_memory_initialization: cfg!(all(target_os = "linux", feature = "uffd")), @@ -1254,18 +1271,23 @@ impl Config { self } - /// Configure whether deserialized modules should validate version - /// information. This only effects [`crate::Module::deserialize()`], which is - /// used to load compiled code from trusted sources. When true, - /// [`crate::Module::deserialize()`] verifies that the wasmtime crate's - /// `CARGO_PKG_VERSION` matches with the version in the binary, which was - /// produced by [`crate::Module::serialize`] or - /// [`crate::Engine::precompile_module`]. + /// Configure the version information used in serialized and deserialzied [`crate::Module`]s. + /// This effects the behavior of [`crate::Module::serialize()`], as well as + /// [`crate::Module::deserialize()`] and related functions. /// - /// This value defaults to true. - pub fn deserialize_check_wasmtime_version(&mut self, check: bool) -> &mut Self { - self.deserialize_check_wasmtime_version = check; - self + /// The default strategy is to use the wasmtime crate's Cargo package version. + pub fn module_version(&mut self, strategy: ModuleVersionStrategy) -> Result<&mut Self> { + match strategy { + // This case requires special precondition for assertion in SerializedModule::to_bytes + ModuleVersionStrategy::Custom(ref v) => { + if v.as_bytes().len() > 255 { + bail!("custom module version cannot be more than 255 bytes: {}", v); + } + } + _ => {} + } + self.module_version = strategy; + Ok(self) } /// Configure wether wasmtime should compile a module using multiple threads. @@ -1351,7 +1373,7 @@ impl Clone for Config { async_support: self.async_support, #[cfg(feature = "async")] async_stack_size: self.async_stack_size, - deserialize_check_wasmtime_version: self.deserialize_check_wasmtime_version, + module_version: self.module_version.clone(), parallel_compilation: self.parallel_compilation, paged_memory_initialization: self.paged_memory_initialization, } diff --git a/crates/wasmtime/src/engine.rs b/crates/wasmtime/src/engine.rs index 8eab0e28bf..9e0f70187b 100644 --- a/crates/wasmtime/src/engine.rs +++ b/crates/wasmtime/src/engine.rs @@ -148,7 +148,8 @@ impl Engine { let bytes = wat::parse_bytes(&bytes)?; let (_, artifacts, types) = crate::Module::build_artifacts(self, &bytes)?; let artifacts = artifacts.into_iter().map(|i| i.0).collect::>(); - crate::module::SerializedModule::from_artifacts(self, &artifacts, &types).to_bytes() + crate::module::SerializedModule::from_artifacts(self, &artifacts, &types) + .to_bytes(&self.config().module_version) } pub(crate) fn run_maybe_parallel< diff --git a/crates/wasmtime/src/module.rs b/crates/wasmtime/src/module.rs index 6c97223ff0..e8a0174697 100644 --- a/crates/wasmtime/src/module.rs +++ b/crates/wasmtime/src/module.rs @@ -316,12 +316,12 @@ impl Module { engine.0, artifacts.iter().map(|p| &p.0), types, - ).to_bytes().ok() + ).to_bytes(&engine.0.config().module_version).ok() }, // Cache hit, deserialize the provided artifacts |(engine, _wasm), serialized_bytes| { - let (i, m, t, upvars) = SerializedModule::from_bytes(&serialized_bytes, true) + let (i, m, t, upvars) = SerializedModule::from_bytes(&serialized_bytes, &engine.0.config().module_version) .ok()? .into_parts(engine.0) .ok()?; @@ -467,10 +467,7 @@ impl Module { /// blobs across versions of wasmtime you can be safely guaranteed that /// future versions of wasmtime will reject old cache entries). pub unsafe fn deserialize(engine: &Engine, bytes: impl AsRef<[u8]>) -> Result { - let module = SerializedModule::from_bytes( - bytes.as_ref(), - engine.config().deserialize_check_wasmtime_version, - )?; + let module = SerializedModule::from_bytes(bytes.as_ref(), &engine.config().module_version)?; module.into_module(engine) } @@ -486,10 +483,7 @@ impl Module { /// /// [`deserialize`]: Module::deserialize pub unsafe fn deserialize_file(engine: &Engine, path: impl AsRef) -> Result { - let module = SerializedModule::from_file( - path.as_ref(), - engine.config().deserialize_check_wasmtime_version, - )?; + let module = SerializedModule::from_file(path.as_ref(), &engine.config().module_version)?; module.into_module(engine) } @@ -625,7 +619,7 @@ impl Module { #[cfg(compiler)] #[cfg_attr(nightlydoc, doc(cfg(feature = "cranelift")))] // see build.rs pub fn serialize(&self) -> Result> { - SerializedModule::new(self).to_bytes() + SerializedModule::new(self).to_bytes(&self.inner.engine.config().module_version) } /// Creates a submodule `Module` value from the specified parameters. diff --git a/crates/wasmtime/src/module/serialization.rs b/crates/wasmtime/src/module/serialization.rs index 69087b4baf..71e746cb27 100644 --- a/crates/wasmtime/src/module/serialization.rs +++ b/crates/wasmtime/src/module/serialization.rs @@ -48,7 +48,7 @@ //! //! This format is implemented by the `to_bytes` and `from_mmap` function. -use crate::{Engine, Module}; +use crate::{Engine, Module, ModuleVersionStrategy}; use anyhow::{anyhow, bail, Context, Result}; use object::read::elf::FileHeader; use object::{Bytes, File, Object, ObjectSection}; @@ -325,7 +325,7 @@ impl<'a> SerializedModule<'a> { )) } - pub fn to_bytes(&self) -> Result> { + pub fn to_bytes(&self, version_strat: &ModuleVersionStrategy) -> Result> { // First up, create a linked-ish list of ELF files. For more // information on this format, see the doc comment on this module. // The only semi-tricky bit here is that we leave an @@ -352,7 +352,12 @@ impl<'a> SerializedModule<'a> { // The last part of our artifact is the bincode-encoded `Metadata` // section with a few other guards to help give better error messages. ret.extend_from_slice(HEADER); - let version = env!("CARGO_PKG_VERSION"); + let version = match version_strat { + ModuleVersionStrategy::WasmtimeVersion => env!("CARGO_PKG_VERSION"), + ModuleVersionStrategy::Custom(c) => &c, + ModuleVersionStrategy::None => "", + }; + // This precondition is checked in Config::module_version: assert!( version.len() < 256, "package version must be less than 256 bytes" @@ -364,20 +369,20 @@ impl<'a> SerializedModule<'a> { Ok(ret) } - pub fn from_bytes(bytes: &[u8], check_version: bool) -> Result { - Self::from_mmap(MmapVec::from_slice(bytes)?, check_version) + pub fn from_bytes(bytes: &[u8], version_strat: &ModuleVersionStrategy) -> Result { + Self::from_mmap(MmapVec::from_slice(bytes)?, version_strat) } - pub fn from_file(path: &Path, check_version: bool) -> Result { + pub fn from_file(path: &Path, version_strat: &ModuleVersionStrategy) -> Result { Self::from_mmap( MmapVec::from_file(path).with_context(|| { format!("failed to create file mapping for: {}", path.display()) })?, - check_version, + version_strat, ) } - pub fn from_mmap(mut mmap: MmapVec, check_version: bool) -> Result { + pub fn from_mmap(mut mmap: MmapVec, version_strat: &ModuleVersionStrategy) -> Result { // Artifacts always start with an ELF file, so read that first. // Afterwards we continually read ELF files until we see the `u64::MAX` // marker, meaning we've reached the end. @@ -419,14 +424,26 @@ impl<'a> SerializedModule<'a> { bail!("serialized data is malformed"); } - if check_version { - let version = std::str::from_utf8(&metadata[1..1 + version_len])?; - if version != env!("CARGO_PKG_VERSION") { - bail!( - "Module was compiled with incompatible Wasmtime version '{}'", - version - ); + match version_strat { + ModuleVersionStrategy::WasmtimeVersion => { + let version = std::str::from_utf8(&metadata[1..1 + version_len])?; + if version != env!("CARGO_PKG_VERSION") { + bail!( + "Module was compiled with incompatible Wasmtime version '{}'", + version + ); + } } + ModuleVersionStrategy::Custom(v) => { + let version = std::str::from_utf8(&metadata[1..1 + version_len])?; + if version != v { + bail!( + "Module was compiled with incompatible version '{}'", + version + ); + } + } + ModuleVersionStrategy::None => { /* ignore the version info, accept all */ } } let metadata = bincode::deserialize::(&metadata[1 + version_len..]) diff --git a/tests/all/module_serialize.rs b/tests/all/module_serialize.rs index 1e9aa9eff0..42aadc6638 100644 --- a/tests/all/module_serialize.rs +++ b/tests/all/module_serialize.rs @@ -15,24 +15,29 @@ unsafe fn deserialize_and_instantiate(store: &mut Store<()>, buffer: &[u8]) -> R #[test] fn test_version_mismatch() -> Result<()> { let engine = Engine::default(); - let mut buffer = serialize(&engine, "(module)")?; - const HEADER: &[u8] = b"\0wasmtime-aot"; - let pos = memchr::memmem::rfind_iter(&buffer, HEADER).next().unwrap(); - buffer[pos + HEADER.len() + 1 /* version length */] = 'x' as u8; + let buffer = serialize(&engine, "(module)")?; - match unsafe { Module::deserialize(&engine, &buffer) } { + let mut config = Config::new(); + config + .module_version(ModuleVersionStrategy::Custom("custom!".to_owned())) + .unwrap(); + let custom_version_engine = Engine::new(&config).unwrap(); + match unsafe { Module::deserialize(&custom_version_engine, &buffer) } { Ok(_) => bail!("expected deserialization to fail"), Err(e) => assert!(e .to_string() - .starts_with("Module was compiled with incompatible Wasmtime version")), + .starts_with("Module was compiled with incompatible version")), } - // Test deserialize_check_wasmtime_version, which disables the logic which rejects the above. let mut config = Config::new(); - config.deserialize_check_wasmtime_version(false); - let engine = Engine::new(&config).unwrap(); - unsafe { Module::deserialize(&engine, &buffer) } - .expect("module with corrupt version should deserialize when check is disabled"); + config.module_version(ModuleVersionStrategy::None).unwrap(); + let none_version_engine = Engine::new(&config).unwrap(); + unsafe { Module::deserialize(&none_version_engine, &buffer) } + .expect("accepts the wasmtime versioned module"); + + let buffer = serialize(&custom_version_engine, "(module)")?; + unsafe { Module::deserialize(&none_version_engine, &buffer) } + .expect("accepts the custom versioned module"); Ok(()) }