diff --git a/crates/wiggle/generate/src/codegen_settings.rs b/crates/wiggle/generate/src/codegen_settings.rs index ced7b0cbe4..616a12dd7e 100644 --- a/crates/wiggle/generate/src/codegen_settings.rs +++ b/crates/wiggle/generate/src/codegen_settings.rs @@ -1,4 +1,4 @@ -use crate::config::{AsyncConf, ErrorConf}; +use crate::config::{AsyncConf, ErrorConf, TracingConf}; use anyhow::{anyhow, Error}; use proc_macro2::TokenStream; use quote::quote; @@ -15,7 +15,7 @@ pub struct CodegenSettings { // Disabling this feature makes it possible to remove all of the tracing // code emitted in the Wiggle-generated code; this can be helpful while // inspecting the code (e.g., with `cargo expand`). - pub tracing: bool, + pub tracing: TracingConf, } impl CodegenSettings { pub fn new( @@ -23,14 +23,14 @@ impl CodegenSettings { async_: &AsyncConf, doc: &Document, wasmtime: bool, - tracing: bool, + tracing: &TracingConf, ) -> Result { let errors = ErrorTransform::new(error_conf, doc)?; Ok(Self { errors, async_: async_.clone(), wasmtime, - tracing, + tracing: tracing.clone(), }) } pub fn get_async(&self, module: &Module, func: &InterfaceFunc) -> Asyncness { diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 4c71ea7f3e..1eedc1840e 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -15,7 +15,7 @@ pub struct Config { pub errors: ErrorConf, pub async_: AsyncConf, pub wasmtime: bool, - pub tracing: bool, + pub tracing: TracingConf, } mod kw { @@ -26,6 +26,7 @@ mod kw { syn::custom_keyword!(target); syn::custom_keyword!(wasmtime); syn::custom_keyword!(tracing); + syn::custom_keyword!(disable_for); } #[derive(Debug, Clone)] @@ -34,7 +35,7 @@ pub enum ConfigField { Error(ErrorConf), Async(AsyncConf), Wasmtime(bool), - Tracing(bool), + Tracing(TracingConf), } impl Parse for ConfigField { @@ -73,7 +74,7 @@ impl Parse for ConfigField { } else if lookahead.peek(kw::tracing) { input.parse::()?; input.parse::()?; - Ok(ConfigField::Tracing(input.parse::()?.value)) + Ok(ConfigField::Tracing(input.parse()?)) } else { Err(lookahead.error()) } @@ -128,7 +129,7 @@ impl Config { errors: errors.take().unwrap_or_default(), async_: async_.take().unwrap_or_default(), wasmtime: wasmtime.unwrap_or(true), - tracing: tracing.unwrap_or(true), + tracing: tracing.unwrap_or_default(), }) } @@ -409,7 +410,7 @@ impl Parse for AsyncFunctions { let lookahead = input.lookahead1(); if lookahead.peek(syn::token::Brace) { let _ = braced!(content in input); - let items: Punctuated = + let items: Punctuated = content.parse_terminated(Parse::parse)?; let mut functions: HashMap> = HashMap::new(); use std::collections::hash_map::Entry; @@ -437,13 +438,13 @@ impl Parse for AsyncFunctions { } #[derive(Clone)] -pub struct AsyncConfField { +pub struct FunctionField { pub module_name: Ident, pub function_names: Vec, pub err_loc: Span, } -impl Parse for AsyncConfField { +impl Parse for FunctionField { fn parse(input: ParseStream) -> Result { let err_loc = input.span(); let module_name = input.parse::()?; @@ -454,14 +455,14 @@ impl Parse for AsyncConfField { let _ = braced!(content in input); let function_names: Punctuated = content.parse_terminated(Parse::parse)?; - Ok(AsyncConfField { + Ok(FunctionField { module_name, function_names: function_names.iter().cloned().collect(), err_loc, }) } else if lookahead.peek(Ident) { let name = input.parse()?; - Ok(AsyncConfField { + Ok(FunctionField { module_name, function_names: vec![name], err_loc, @@ -565,3 +566,70 @@ impl Parse for WasmtimeConfigField { } } } + +#[derive(Clone, Debug)] +pub struct TracingConf { + enabled: bool, + excluded_functions: HashMap>, +} + +impl TracingConf { + pub fn enabled_for(&self, module: &str, function: &str) -> bool { + if !self.enabled { + return false; + } + self.excluded_functions + .get(module) + .and_then(|fs| fs.iter().find(|f| *f == function)) + .is_none() + } +} + +impl Default for TracingConf { + fn default() -> Self { + Self { + enabled: true, + excluded_functions: HashMap::new(), + } + } +} + +impl Parse for TracingConf { + fn parse(input: ParseStream) -> Result { + let enabled = input.parse::()?.value; + + let lookahead = input.lookahead1(); + if lookahead.peek(kw::disable_for) { + input.parse::()?; + let content; + let _ = braced!(content in input); + let items: Punctuated = + content.parse_terminated(Parse::parse)?; + let mut functions: HashMap> = HashMap::new(); + use std::collections::hash_map::Entry; + for i in items { + let function_names = i + .function_names + .iter() + .map(|i| i.to_string()) + .collect::>(); + match functions.entry(i.module_name.to_string()) { + Entry::Occupied(o) => o.into_mut().extend(function_names), + Entry::Vacant(v) => { + v.insert(function_names); + } + } + } + + Ok(TracingConf { + enabled, + excluded_functions: functions, + }) + } else { + Ok(TracingConf { + enabled, + excluded_functions: HashMap::new(), + }) + } + } +} diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index f124d066a4..bd1f1eebdf 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -84,7 +84,7 @@ fn _define_func( ); ); if settings.get_async(&module, &func).is_sync() { - let traced_body = if settings.tracing { + let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) { quote!( #mk_span _span.in_scope(|| { @@ -109,7 +109,7 @@ fn _define_func( bounds, ) } else { - let traced_body = if settings.tracing { + let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) { quote!( use #rt::tracing::Instrument as _; #mk_span @@ -261,7 +261,12 @@ impl witx::Bindgen for Rust<'_> { args.push(quote!(#name)); } } - if self.settings.tracing && func.params.len() > 0 { + if self + .settings + .tracing + .enabled_for(self.module.name.as_str(), self.funcname) + && func.params.len() > 0 + { let args = func .params .iter() @@ -290,7 +295,11 @@ impl witx::Bindgen for Rust<'_> { let ret = #trait_name::#ident(ctx, #(#args),*).await; }) }; - if self.settings.tracing { + if self + .settings + .tracing + .enabled_for(self.module.name.as_str(), self.funcname) + { self.src.extend(quote! { #rt::tracing::event!( #rt::tracing::Level::TRACE, diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index 72370593a2..11bbbfc01d 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -153,7 +153,7 @@ pub fn from_witx(args: TokenStream) -> TokenStream { &config.async_, &doc, config.wasmtime, - config.tracing, + &config.tracing, ) .expect("validating codegen settings"); @@ -195,7 +195,7 @@ pub fn wasmtime_integration(args: TokenStream) -> TokenStream { &config.c.async_, &doc, true, - config.c.tracing, + &config.c.tracing, ) .expect("validating codegen settings"); diff --git a/crates/wiggle/test-helpers/examples/tracing.rs b/crates/wiggle/test-helpers/examples/tracing.rs index 07c361d985..8de56c8dba 100644 --- a/crates/wiggle/test-helpers/examples/tracing.rs +++ b/crates/wiggle/test-helpers/examples/tracing.rs @@ -15,6 +15,9 @@ pub enum RichError { // Define an errno with variants corresponding to RichError. Use it in a // trivial function. wiggle::from_witx!({ + tracing: true disable_for { + one_error_conversion::foo, + }, witx_literal: " (typename $errno (enum (@witx tag u8) $ok $invalid_arg $picket_line)) (typename $s (record (field $f1 (@witx usize)) (field $f2 (@witx pointer u8))))