Merge pull request #3271 from bytecodealliance/pch/flexible_ser_module_versioning

More flexible versioning for module serialization
This commit is contained in:
Pat Hickey
2021-09-02 12:51:03 -07:00
committed by GitHub
5 changed files with 91 additions and 52 deletions

View File

@@ -330,6 +330,23 @@ impl Default for InstanceAllocationStrategy {
} }
} }
/// Strategy for the version information embedded in a serialized
/// [`crate::Module`] and checked when one is deserialized.
///
/// The default is [`ModuleVersionStrategy::WasmtimeVersion`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModuleVersionStrategy {
    /// Use the wasmtime crate's Cargo package version.
    WasmtimeVersion,
    /// Use a custom version string. Must be at most 255 bytes.
    Custom(String),
    /// Emit no version string in serialization, and accept all version strings in deserialization.
    None,
}

impl Default for ModuleVersionStrategy {
    /// Returns [`ModuleVersionStrategy::WasmtimeVersion`].
    fn default() -> Self {
        ModuleVersionStrategy::WasmtimeVersion
    }
}
/// Global configuration options used to create an [`Engine`](crate::Engine) /// Global configuration options used to create an [`Engine`](crate::Engine)
/// and customize its behavior. /// and customize its behavior.
/// ///
@@ -350,7 +367,7 @@ pub struct Config {
#[cfg(feature = "async")] #[cfg(feature = "async")]
pub(crate) async_stack_size: usize, pub(crate) async_stack_size: usize,
pub(crate) async_support: bool, pub(crate) async_support: bool,
pub(crate) deserialize_check_wasmtime_version: bool, pub(crate) module_version: ModuleVersionStrategy,
pub(crate) parallel_compilation: bool, pub(crate) parallel_compilation: bool,
pub(crate) paged_memory_initialization: bool, pub(crate) paged_memory_initialization: bool,
} }
@@ -374,7 +391,7 @@ impl Config {
#[cfg(feature = "async")] #[cfg(feature = "async")]
async_stack_size: 2 << 20, async_stack_size: 2 << 20,
async_support: false, async_support: false,
deserialize_check_wasmtime_version: true, module_version: ModuleVersionStrategy::default(),
parallel_compilation: true, parallel_compilation: true,
// Default to paged memory initialization when using uffd on linux // Default to paged memory initialization when using uffd on linux
paged_memory_initialization: cfg!(all(target_os = "linux", feature = "uffd")), paged_memory_initialization: cfg!(all(target_os = "linux", feature = "uffd")),
@@ -1254,18 +1271,23 @@ impl Config {
self self
} }
/// Configure whether deserialized modules should validate version /// Configure the version information used in serialized and deserialized [`crate::Module`]s.
/// information. This only affects [`crate::Module::deserialize()`], which is /// This affects the behavior of [`crate::Module::serialize()`], as well as
/// used to load compiled code from trusted sources. When true, /// [`crate::Module::deserialize()`] and related functions.
/// [`crate::Module::deserialize()`] verifies that the wasmtime crate's
/// `CARGO_PKG_VERSION` matches with the version in the binary, which was
/// produced by [`crate::Module::serialize`] or
/// [`crate::Engine::precompile_module`].
/// ///
/// This value defaults to true. /// The default strategy is to use the wasmtime crate's Cargo package version.
pub fn deserialize_check_wasmtime_version(&mut self, check: bool) -> &mut Self { pub fn module_version(&mut self, strategy: ModuleVersionStrategy) -> Result<&mut Self> {
self.deserialize_check_wasmtime_version = check; match strategy {
self // This case requires special precondition for assertion in SerializedModule::to_bytes
ModuleVersionStrategy::Custom(ref v) => {
if v.as_bytes().len() > 255 {
bail!("custom module version cannot be more than 255 bytes: {}", v);
}
}
_ => {}
}
self.module_version = strategy;
Ok(self)
} }
/// Configure whether wasmtime should compile a module using multiple threads. /// Configure whether wasmtime should compile a module using multiple threads.
@@ -1351,7 +1373,7 @@ impl Clone for Config {
async_support: self.async_support, async_support: self.async_support,
#[cfg(feature = "async")] #[cfg(feature = "async")]
async_stack_size: self.async_stack_size, async_stack_size: self.async_stack_size,
deserialize_check_wasmtime_version: self.deserialize_check_wasmtime_version, module_version: self.module_version.clone(),
parallel_compilation: self.parallel_compilation, parallel_compilation: self.parallel_compilation,
paged_memory_initialization: self.paged_memory_initialization, paged_memory_initialization: self.paged_memory_initialization,
} }

View File

@@ -148,7 +148,8 @@ impl Engine {
let bytes = wat::parse_bytes(&bytes)?; let bytes = wat::parse_bytes(&bytes)?;
let (_, artifacts, types) = crate::Module::build_artifacts(self, &bytes)?; let (_, artifacts, types) = crate::Module::build_artifacts(self, &bytes)?;
let artifacts = artifacts.into_iter().map(|i| i.0).collect::<Vec<_>>(); let artifacts = artifacts.into_iter().map(|i| i.0).collect::<Vec<_>>();
crate::module::SerializedModule::from_artifacts(self, &artifacts, &types).to_bytes() crate::module::SerializedModule::from_artifacts(self, &artifacts, &types)
.to_bytes(&self.config().module_version)
} }
pub(crate) fn run_maybe_parallel< pub(crate) fn run_maybe_parallel<

View File

@@ -316,12 +316,12 @@ impl Module {
engine.0, engine.0,
artifacts.iter().map(|p| &p.0), artifacts.iter().map(|p| &p.0),
types, types,
).to_bytes().ok() ).to_bytes(&engine.0.config().module_version).ok()
}, },
// Cache hit, deserialize the provided artifacts // Cache hit, deserialize the provided artifacts
|(engine, _wasm), serialized_bytes| { |(engine, _wasm), serialized_bytes| {
let (i, m, t, upvars) = SerializedModule::from_bytes(&serialized_bytes, true) let (i, m, t, upvars) = SerializedModule::from_bytes(&serialized_bytes, &engine.0.config().module_version)
.ok()? .ok()?
.into_parts(engine.0) .into_parts(engine.0)
.ok()?; .ok()?;
@@ -467,10 +467,7 @@ impl Module {
/// blobs across versions of wasmtime you can be safely guaranteed that /// blobs across versions of wasmtime you can be safely guaranteed that
/// future versions of wasmtime will reject old cache entries). /// future versions of wasmtime will reject old cache entries).
pub unsafe fn deserialize(engine: &Engine, bytes: impl AsRef<[u8]>) -> Result<Module> { pub unsafe fn deserialize(engine: &Engine, bytes: impl AsRef<[u8]>) -> Result<Module> {
let module = SerializedModule::from_bytes( let module = SerializedModule::from_bytes(bytes.as_ref(), &engine.config().module_version)?;
bytes.as_ref(),
engine.config().deserialize_check_wasmtime_version,
)?;
module.into_module(engine) module.into_module(engine)
} }
@@ -486,10 +483,7 @@ impl Module {
/// ///
/// [`deserialize`]: Module::deserialize /// [`deserialize`]: Module::deserialize
pub unsafe fn deserialize_file(engine: &Engine, path: impl AsRef<Path>) -> Result<Module> { pub unsafe fn deserialize_file(engine: &Engine, path: impl AsRef<Path>) -> Result<Module> {
let module = SerializedModule::from_file( let module = SerializedModule::from_file(path.as_ref(), &engine.config().module_version)?;
path.as_ref(),
engine.config().deserialize_check_wasmtime_version,
)?;
module.into_module(engine) module.into_module(engine)
} }
@@ -625,7 +619,7 @@ impl Module {
#[cfg(compiler)] #[cfg(compiler)]
#[cfg_attr(nightlydoc, doc(cfg(feature = "cranelift")))] // see build.rs #[cfg_attr(nightlydoc, doc(cfg(feature = "cranelift")))] // see build.rs
pub fn serialize(&self) -> Result<Vec<u8>> { pub fn serialize(&self) -> Result<Vec<u8>> {
SerializedModule::new(self).to_bytes() SerializedModule::new(self).to_bytes(&self.inner.engine.config().module_version)
} }
/// Creates a submodule `Module` value from the specified parameters. /// Creates a submodule `Module` value from the specified parameters.

View File

@@ -48,7 +48,7 @@
//! //!
//! This format is implemented by the `to_bytes` and `from_mmap` functions. //! This format is implemented by the `to_bytes` and `from_mmap` functions.
use crate::{Engine, Module}; use crate::{Engine, Module, ModuleVersionStrategy};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use object::read::elf::FileHeader; use object::read::elf::FileHeader;
use object::{Bytes, File, Object, ObjectSection}; use object::{Bytes, File, Object, ObjectSection};
@@ -325,7 +325,7 @@ impl<'a> SerializedModule<'a> {
)) ))
} }
pub fn to_bytes(&self) -> Result<Vec<u8>> { pub fn to_bytes(&self, version_strat: &ModuleVersionStrategy) -> Result<Vec<u8>> {
// First up, create a linked-ish list of ELF files. For more // First up, create a linked-ish list of ELF files. For more
// information on this format, see the doc comment on this module. // information on this format, see the doc comment on this module.
// The only semi-tricky bit here is that we leave an // The only semi-tricky bit here is that we leave an
@@ -352,7 +352,12 @@ impl<'a> SerializedModule<'a> {
// The last part of our artifact is the bincode-encoded `Metadata` // The last part of our artifact is the bincode-encoded `Metadata`
// section with a few other guards to help give better error messages. // section with a few other guards to help give better error messages.
ret.extend_from_slice(HEADER); ret.extend_from_slice(HEADER);
let version = env!("CARGO_PKG_VERSION"); let version = match version_strat {
ModuleVersionStrategy::WasmtimeVersion => env!("CARGO_PKG_VERSION"),
ModuleVersionStrategy::Custom(c) => &c,
ModuleVersionStrategy::None => "",
};
// This precondition is checked in Config::module_version:
assert!( assert!(
version.len() < 256, version.len() < 256,
"package version must be less than 256 bytes" "package version must be less than 256 bytes"
@@ -364,20 +369,20 @@ impl<'a> SerializedModule<'a> {
Ok(ret) Ok(ret)
} }
pub fn from_bytes(bytes: &[u8], check_version: bool) -> Result<Self> { pub fn from_bytes(bytes: &[u8], version_strat: &ModuleVersionStrategy) -> Result<Self> {
Self::from_mmap(MmapVec::from_slice(bytes)?, check_version) Self::from_mmap(MmapVec::from_slice(bytes)?, version_strat)
} }
pub fn from_file(path: &Path, check_version: bool) -> Result<Self> { pub fn from_file(path: &Path, version_strat: &ModuleVersionStrategy) -> Result<Self> {
Self::from_mmap( Self::from_mmap(
MmapVec::from_file(path).with_context(|| { MmapVec::from_file(path).with_context(|| {
format!("failed to create file mapping for: {}", path.display()) format!("failed to create file mapping for: {}", path.display())
})?, })?,
check_version, version_strat,
) )
} }
pub fn from_mmap(mut mmap: MmapVec, check_version: bool) -> Result<Self> { pub fn from_mmap(mut mmap: MmapVec, version_strat: &ModuleVersionStrategy) -> Result<Self> {
// Artifacts always start with an ELF file, so read that first. // Artifacts always start with an ELF file, so read that first.
// Afterwards we continually read ELF files until we see the `u64::MAX` // Afterwards we continually read ELF files until we see the `u64::MAX`
// marker, meaning we've reached the end. // marker, meaning we've reached the end.
@@ -419,7 +424,8 @@ impl<'a> SerializedModule<'a> {
bail!("serialized data is malformed"); bail!("serialized data is malformed");
} }
if check_version { match version_strat {
ModuleVersionStrategy::WasmtimeVersion => {
let version = std::str::from_utf8(&metadata[1..1 + version_len])?; let version = std::str::from_utf8(&metadata[1..1 + version_len])?;
if version != env!("CARGO_PKG_VERSION") { if version != env!("CARGO_PKG_VERSION") {
bail!( bail!(
@@ -428,6 +434,17 @@ impl<'a> SerializedModule<'a> {
); );
} }
} }
ModuleVersionStrategy::Custom(v) => {
let version = std::str::from_utf8(&metadata[1..1 + version_len])?;
if version != v {
bail!(
"Module was compiled with incompatible version '{}'",
version
);
}
}
ModuleVersionStrategy::None => { /* ignore the version info, accept all */ }
}
let metadata = bincode::deserialize::<Metadata>(&metadata[1 + version_len..]) let metadata = bincode::deserialize::<Metadata>(&metadata[1 + version_len..])
.context("deserialize compilation artifacts")?; .context("deserialize compilation artifacts")?;

View File

@@ -15,24 +15,29 @@ unsafe fn deserialize_and_instantiate(store: &mut Store<()>, buffer: &[u8]) -> R
#[test] #[test]
fn test_version_mismatch() -> Result<()> { fn test_version_mismatch() -> Result<()> {
let engine = Engine::default(); let engine = Engine::default();
let mut buffer = serialize(&engine, "(module)")?; let buffer = serialize(&engine, "(module)")?;
const HEADER: &[u8] = b"\0wasmtime-aot";
let pos = memchr::memmem::rfind_iter(&buffer, HEADER).next().unwrap();
buffer[pos + HEADER.len() + 1 /* version length */] = 'x' as u8;
match unsafe { Module::deserialize(&engine, &buffer) } { let mut config = Config::new();
config
.module_version(ModuleVersionStrategy::Custom("custom!".to_owned()))
.unwrap();
let custom_version_engine = Engine::new(&config).unwrap();
match unsafe { Module::deserialize(&custom_version_engine, &buffer) } {
Ok(_) => bail!("expected deserialization to fail"), Ok(_) => bail!("expected deserialization to fail"),
Err(e) => assert!(e Err(e) => assert!(e
.to_string() .to_string()
.starts_with("Module was compiled with incompatible Wasmtime version")), .starts_with("Module was compiled with incompatible version")),
} }
// Test deserialize_check_wasmtime_version, which disables the logic which rejects the above.
let mut config = Config::new(); let mut config = Config::new();
config.deserialize_check_wasmtime_version(false); config.module_version(ModuleVersionStrategy::None).unwrap();
let engine = Engine::new(&config).unwrap(); let none_version_engine = Engine::new(&config).unwrap();
unsafe { Module::deserialize(&engine, &buffer) } unsafe { Module::deserialize(&none_version_engine, &buffer) }
.expect("module with corrupt version should deserialize when check is disabled"); .expect("accepts the wasmtime versioned module");
let buffer = serialize(&custom_version_engine, "(module)")?;
unsafe { Module::deserialize(&none_version_engine, &buffer) }
.expect("accepts the custom versioned module");
Ok(()) Ok(())
} }