From 20bc5ca7a83841b314f3c94b6c703432ba51241f Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Wed, 15 Sep 2021 00:01:51 -0700 Subject: [PATCH] Support extern constants of any primitive type. --- cranelift/isle/TODO | 2 - cranelift/isle/isle/src/ast.rs | 22 +++- cranelift/isle/isle/src/codegen.rs | 23 ++++ cranelift/isle/isle/src/ir.rs | 41 ++++--- cranelift/isle/isle/src/parser.rs | 47 +++++++- cranelift/isle/isle/src/sema.rs | 140 ++++++++++++++++++------ cranelift/isle/isle_examples/test4.isle | 8 ++ 7 files changed, 222 insertions(+), 61 deletions(-) diff --git a/cranelift/isle/TODO b/cranelift/isle/TODO index 6a94f23175..495ded7b90 100644 --- a/cranelift/isle/TODO +++ b/cranelift/isle/TODO @@ -6,8 +6,6 @@ parse instead where we know the polarity of pattern-term args and parse in-args as exprs. -- Support extern constants. - - Look into whether optimizations are possible: - More in-depth fallibility analysis (avoid failure edges where possible) diff --git a/cranelift/isle/isle/src/ast.rs b/cranelift/isle/isle/src/ast.rs index 9e6e06f6e7..97584b5627 100644 --- a/cranelift/isle/isle/src/ast.rs +++ b/cranelift/isle/isle/src/ast.rs @@ -96,6 +96,8 @@ pub enum Pattern { Var { var: Ident, pos: Pos }, /// An operator that matches a constant integer value. ConstInt { val: i64, pos: Pos }, + /// An operator that matches an external constant value. + ConstPrim { val: Ident, pos: Pos }, /// An application of a type variant or term. Term { sym: Ident, @@ -166,9 +168,10 @@ impl Pattern { } } - &Pattern::Var { .. } | &Pattern::Wildcard { .. } | &Pattern::ConstInt { .. } => { - self.clone() - } + &Pattern::Var { .. } + | &Pattern::Wildcard { .. } + | &Pattern::ConstInt { .. } + | &Pattern::ConstPrim { .. } => self.clone(), &Pattern::MacroArg { .. } => unreachable!(), } } @@ -208,9 +211,10 @@ impl Pattern { } } - &Pattern::Var { .. } | &Pattern::Wildcard { .. } | &Pattern::ConstInt { .. } => { - self.clone() - } + &Pattern::Var { .. } + | &Pattern::Wildcard { .. } + | &Pattern::ConstInt { .. } + | &Pattern::ConstPrim { .. } => self.clone(), &Pattern::MacroArg { index, .. } => macro_args[index].clone(), } } @@ -218,6 +222,7 @@ impl Pattern { pub fn pos(&self) -> Pos { match self { &Pattern::ConstInt { pos, .. } + | &Pattern::ConstPrim { pos, .. } | &Pattern::And { pos, .. } | &Pattern::Term { pos, .. } | &Pattern::BindPattern { pos, .. } @@ -280,6 +285,8 @@ pub enum Expr { Var { name: Ident, pos: Pos }, /// A constant integer. ConstInt { val: i64, pos: Pos }, + /// A constant of some other primitive type. + ConstPrim { val: Ident, pos: Pos }, /// The `(let ((var ty val)*) body)` form. Let { defs: Vec, @@ -294,6 +301,7 @@ impl Expr { &Expr::Term { pos, .. } | &Expr::Var { pos, .. } | &Expr::ConstInt { pos, .. } + | &Expr::ConstPrim { pos, .. } | &Expr::Let { pos, .. } => pos, } } @@ -344,6 +352,8 @@ pub enum Extern { /// The position of this decl. pos: Pos, }, + /// An external constant: `(const $IDENT type)` form. + Const { name: Ident, ty: Ident, pos: Pos }, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/cranelift/isle/isle/src/codegen.rs b/cranelift/isle/isle/src/codegen.rs index 5eed4b7cb4..96d0b4c029 100644 --- a/cranelift/isle/isle/src/codegen.rs +++ b/cranelift/isle/isle/src/codegen.rs @@ -797,6 +797,23 @@ impl<'a> Codegen<'a> { self.const_int(val, ty) )?; } + &ExprInst::ConstPrim { ty, val } => { + let value = Value::Expr { + inst: id, + output: 0, + }; + self.define_val(&value, ctx, /* is_ref = */ false, ty); + let name = self.value_name(&value); + let ty_name = self.type_name(ty, /* by_ref = */ false); + writeln!( + code, + "{}let {}: {} = {};", + indent, + name, + ty_name, + self.typeenv.syms[val.index()], + )?; + } &ExprInst::CreateVariant { ref inputs, ty, @@ -955,6 +972,12 @@ impl<'a> Codegen<'a> { writeln!(code, "{}if {} == {} {{", indent, input, int_val)?; Ok(false) } + &PatternInst::MatchPrim { ref input, val, .. } => { + let input = self.value_by_val(input, ctx); + let sym = &self.typeenv.syms[val.index()]; + writeln!(code, "{}if {} == {} {{", indent, input, sym)?; + Ok(false) + } &PatternInst::MatchVariant { ref input, input_ty, diff --git a/cranelift/isle/isle/src/ir.rs b/cranelift/isle/isle/src/ir.rs index 82cc027d35..9c3171607e 100644 --- a/cranelift/isle/isle/src/ir.rs +++ b/cranelift/isle/isle/src/ir.rs @@ -32,6 +32,9 @@ pub enum PatternInst { int_val: i64, }, + /// Try matching the given value as the given constant. Produces no values. + MatchPrim { input: Value, ty: TypeId, val: Sym }, + /// Try matching the given value as the given variant, producing /// `|arg_tys|` values as output. MatchVariant { @@ -69,6 +72,9 @@ pub enum ExprInst { /// Produce a constant integer. ConstInt { ty: TypeId, val: i64 }, + /// Produce a constant extern value. + ConstPrim { ty: TypeId, val: Sym }, + /// Create a variant. CreateVariant { inputs: Vec<(Value, TypeId)>, @@ -96,6 +102,7 @@ impl ExprInst { pub fn visit_values(&self, mut f: F) { match self { &ExprInst::ConstInt { .. } => {} + &ExprInst::ConstPrim { .. } => {} &ExprInst::Construct { ref inputs, .. } | &ExprInst::CreateVariant { ref inputs, .. } => { for (input, _ty) in inputs { @@ -143,21 +150,6 @@ impl ExprSequence { None } } - - pub fn is_const_variant(&self) -> Option<(TypeId, VariantId)> { - if self.insts.len() == 2 && matches!(&self.insts[1], &ExprInst::Return { .. }) { - match &self.insts[0] { - &ExprInst::CreateVariant { - ref inputs, - ty, - variant, - } if inputs.len() == 0 => Some((ty, variant)), - _ => None, - } - } else { - None - } - } } #[derive(Clone, Copy, Debug)] @@ -196,6 +188,10 @@ impl PatternSequence { self.add_inst(PatternInst::MatchInt { input, ty, int_val }); } + fn add_match_prim(&mut self, input: Value, ty: TypeId, val: Sym) { + self.add_inst(PatternInst::MatchPrim { input, ty, val }); + } + fn add_match_variant( &mut self, input: Value, @@ -290,9 +286,15 @@ impl PatternSequence { // Assert that the value matches the constant integer. let input_val = input .to_value() - .expect("Cannot match an =var pattern against root term"); + .expect("Cannot match an integer pattern against root term"); self.add_match_int(input_val, ty, value); } + &Pattern::ConstPrim(ty, value) => { + let input_val = input + .to_value() + .expect("Cannot match a constant-primitive pattern against root term"); + self.add_match_prim(input_val, ty, value); + } &Pattern::Term(ty, term, ref args) => { match input { ValueOrArgs::ImplicitTermFromArgs(termid) => { @@ -435,6 +437,12 @@ impl ExprSequence { Value::Expr { inst, output: 0 } } + fn add_const_prim(&mut self, ty: TypeId, val: Sym) -> Value { + let inst = InstId(self.insts.len()); + self.add_inst(ExprInst::ConstPrim { ty, val }); + Value::Expr { inst, output: 0 } + } + fn add_create_variant( &mut self, inputs: &[(Value, TypeId)], @@ -490,6 +498,7 @@ impl ExprSequence { log::trace!("gen_expr: expr {:?}", expr); match expr { &Expr::ConstInt(ty, val) => self.add_const_int(ty, val), + &Expr::ConstPrim(ty, val) => self.add_const_prim(ty, val), &Expr::Let(_ty, ref bindings, ref subexpr) => { let mut vars = vars.clone(); for &(var, _var_ty, ref var_expr) in bindings { diff --git a/cranelift/isle/isle/src/parser.rs b/cranelift/isle/isle/src/parser.rs index 77ca738f7d..81cf81c2cf 100644 --- a/cranelift/isle/isle/src/parser.rs +++ b/cranelift/isle/isle/src/parser.rs @@ -72,6 +72,13 @@ impl<'a> Parser<'a> { }) } + fn is_const(&self) -> bool { + self.is(|tok| match tok { + &Token::Symbol(ref tok_s) if tok_s.starts_with("$") => true, + _ => false, + }) + } + fn lparen(&mut self) -> ParseResult<()> { self.take(|tok| *tok == Token::LParen).map(|_| ()) } @@ -129,20 +136,20 @@ impl<'a> Parser<'a> { fn str_to_ident(&self, pos: Pos, s: &str) -> ParseResult { let first = s.chars().next().unwrap(); - if !first.is_alphabetic() && first != '_' { + if !first.is_alphabetic() && first != '_' && first != '$' { return Err(self.error( pos, - format!("Identifier '{}' does not start with letter or _", s), + format!("Identifier '{}' does not start with letter or _ or $", s), )); } if s.chars() .skip(1) - .any(|c| !c.is_alphanumeric() && c != '_' && c != '.') + .any(|c| !c.is_alphanumeric() && c != '_' && c != '.' && c != '$') { return Err(self.error( pos, format!( - "Identifier '{}' contains invalid character (not a-z, A-Z, 0-9, _, .)", + "Identifier '{}' contains invalid character (not a-z, A-Z, 0-9, _, ., $)", s ), )); @@ -156,6 +163,20 @@ impl<'a> Parser<'a> { self.str_to_ident(pos.unwrap(), &s) } + fn parse_const(&mut self) -> ParseResult { + let pos = self.pos(); + let ident = self.parse_ident()?; + if ident.0.starts_with("$") { + let s = &ident.0[1..]; + Ok(Ident(s.to_string(), ident.1)) + } else { + Err(self.error( + pos.unwrap(), + "Not a constant identifier; must start with a '$'".to_string(), + )) + } + } + fn parse_type(&mut self) -> ParseResult { let pos = self.pos(); let name = self.parse_ident()?; @@ -303,6 +324,16 @@ impl<'a> Parser<'a> { arg_polarity, infallible, }) + } else if self.is_sym_str("const") { + self.symbol()?; + let pos = self.pos(); + let name = self.parse_const()?; + let ty = self.parse_ident()?; + Ok(Extern::Const { + name, + ty, + pos: pos.unwrap(), + }) } else { Err(self.error( pos.unwrap(), @@ -355,6 +386,10 @@ impl<'a> Parser<'a> { val: self.int()?, pos, }) + } else if self.is_const() { + let pos = pos.unwrap(); + let val = self.parse_const()?; + Ok(Pattern::ConstPrim { val, pos }) } else if self.is_sym_str("_") { let pos = pos.unwrap(); self.symbol()?; @@ -448,6 +483,10 @@ impl<'a> Parser<'a> { let pos = pos.unwrap(); self.symbol()?; Ok(Expr::ConstInt { val: 0, pos }) + } else if self.is_const() { + let pos = pos.unwrap(); + let val = self.parse_const()?; + Ok(Expr::ConstPrim { val, pos }) } else if self.is_sym() { let pos = pos.unwrap(); let name = self.parse_ident()?; diff --git a/cranelift/isle/isle/src/sema.rs b/cranelift/isle/isle/src/sema.rs index b7ee0dace1..61f19cd9d9 100644 --- a/cranelift/isle/isle/src/sema.rs +++ b/cranelift/isle/isle/src/sema.rs @@ -35,6 +35,7 @@ pub struct TypeEnv { pub sym_map: HashMap, pub types: Vec, pub type_map: HashMap, + pub const_types: HashMap, pub errors: Vec, } @@ -238,6 +239,7 @@ pub enum Pattern { BindPattern(TypeId, VarId, Box), Var(TypeId, VarId), ConstInt(TypeId, i64), + ConstPrim(TypeId, Sym), Term(TypeId, TermId, Vec), Wildcard(TypeId), And(TypeId, Vec), @@ -254,6 +256,7 @@ pub enum Expr { Term(TypeId, TermId, Vec), Var(TypeId, VarId), ConstInt(TypeId, i64), + ConstPrim(TypeId, Sym), Let(TypeId, Vec<(VarId, TypeId, Box)>, Box), } @@ -263,6 +266,7 @@ impl Pattern { &Self::BindPattern(t, ..) => t, &Self::Var(t, ..) => t, &Self::ConstInt(t, ..) => t, + &Self::ConstPrim(t, ..) => t, &Self::Term(t, ..) => t, &Self::Wildcard(t, ..) => t, &Self::And(t, ..) => t, @@ -284,6 +288,7 @@ impl Expr { &Self::Term(t, ..) => t, &Self::Var(t, ..) => t, &Self::ConstInt(t, ..) => t, + &Self::ConstPrim(t, ..) => t, &Self::Let(t, ..) => t, } } @@ -297,6 +302,7 @@ impl TypeEnv { sym_map: HashMap::new(), types: vec![], type_map: HashMap::new(), + const_types: HashMap::new(), errors: vec![], }; @@ -340,6 +346,29 @@ impl TypeEnv { } } + // Now collect types for extern constants. + for def in &defs.defs { + match def { + &ast::Def::Extern(ast::Extern::Const { + ref name, + ref ty, + pos, + }) => { + let ty = tyenv.intern_mut(ty); + let ty = match tyenv.type_map.get(&ty) { + Some(ty) => *ty, + None => { + tyenv.report_error(pos, "Unknown type for constant".into()); + continue; + } + }; + let name = tyenv.intern_mut(name); + tyenv.const_types.insert(name, ty); + } + _ => {} + } + } + tyenv.return_errors()?; Ok(tyenv) @@ -674,25 +703,20 @@ impl TermEnv { vars: vec![], }; - let (lhs, ty) = match self.translate_pattern( - tyenv, - &rule.pattern, - None, - &mut bindings, - ) { - Some(x) => x, - None => { - // Keep going to collect more errors. - continue; - } - }; - let rhs = - match self.translate_expr(tyenv, &rule.expr, ty, &mut bindings) { + let (lhs, ty) = + match self.translate_pattern(tyenv, &rule.pattern, None, &mut bindings) { Some(x) => x, None => { + // Keep going to collect more errors. continue; } }; + let rhs = match self.translate_expr(tyenv, &rule.expr, ty, &mut bindings) { + Some(x) => x, + None => { + continue; + } + }; let rid = RuleId(self.rules.len()); self.rules.push(Rule { @@ -814,6 +838,20 @@ impl TermEnv { }; Some((Pattern::ConstInt(ty, val), ty)) } + &ast::Pattern::ConstPrim { ref val, pos } => { + let val = tyenv.intern_mut(val); + let const_ty = match tyenv.const_types.get(&val) { + Some(ty) => *ty, + None => { + tyenv.report_error(pos, "Unknown constant".into()); + return None; + } + }; + if expected_ty.is_some() && expected_ty != Some(const_ty) { + tyenv.report_error(pos, "Type mismatch for constant".into()); + } + Some((Pattern::ConstPrim(const_ty, val), const_ty)) + } &ast::Pattern::Wildcard { pos } => { let ty = match expected_ty { Some(t) => t, @@ -1045,8 +1083,7 @@ impl TermEnv { ) -> Option<(TermArgPattern, TypeId)> { match pat { &ast::TermArgPattern::Pattern(ref pat) => { - let (subpat, ty) = - self.translate_pattern(tyenv, pat, expected_ty, bindings)?; + let (subpat, ty) = self.translate_pattern(tyenv, pat, expected_ty, bindings)?; Some((TermArgPattern::Pattern(subpat), ty)) } &ast::TermArgPattern::Expr(ref expr) => { @@ -1152,6 +1189,29 @@ impl TermEnv { Some(Expr::Var(bv.ty, bv.id)) } &ast::Expr::ConstInt { val, .. } => Some(Expr::ConstInt(ty, val)), + &ast::Expr::ConstPrim { ref val, pos } => { + let val = tyenv.intern_mut(val); + let const_ty = match tyenv.const_types.get(&val) { + Some(ty) => *ty, + None => { + tyenv.report_error(pos, "Unknown constant".into()); + return None; + } + }; + if const_ty != ty { + tyenv.report_error( + pos, + format!( + "Constant '{}' has wrong type: expected {}, but is actually {}", + tyenv.syms[val.index()], + tyenv.types[ty.index()].name(tyenv), + tyenv.types[const_ty.index()].name(tyenv) + ), + ); + return None; + } + Some(Expr::ConstPrim(ty, val)) + } &ast::Expr::Let { ref defs, ref body, @@ -1191,15 +1251,13 @@ impl TermEnv { }; // Evaluate the variable's value. - let val = Box::new( - match self.translate_expr(tyenv, &def.val, ty, bindings) { - Some(e) => e, - None => { - // Keep going for more errors. - continue; - } - }, - ); + let val = Box::new(match self.translate_expr(tyenv, &def.val, ty, bindings) { + Some(e) => e, + None => { + // Keep going for more errors. + continue; + } + }); // Bind the var with the given type. let id = VarId(bindings.next_var); @@ -1240,14 +1298,30 @@ mod test { .expect("should parse"); let tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors"); - let sym_a = tyenv.intern(&Ident("A".to_string())).unwrap(); - let sym_b = tyenv.intern(&Ident("B".to_string())).unwrap(); - let sym_c = tyenv.intern(&Ident("C".to_string())).unwrap(); - let sym_a_b = tyenv.intern(&Ident("A.B".to_string())).unwrap(); - let sym_a_c = tyenv.intern(&Ident("A.C".to_string())).unwrap(); - let sym_u32 = tyenv.intern(&Ident("u32".to_string())).unwrap(); - let sym_f1 = tyenv.intern(&Ident("f1".to_string())).unwrap(); - let sym_f2 = tyenv.intern(&Ident("f2".to_string())).unwrap(); + let sym_a = tyenv + .intern(&Ident("A".to_string(), Default::default())) + .unwrap(); + let sym_b = tyenv + .intern(&Ident("B".to_string(), Default::default())) + .unwrap(); + let sym_c = tyenv + .intern(&Ident("C".to_string(), Default::default())) + .unwrap(); + let sym_a_b = tyenv + .intern(&Ident("A.B".to_string(), Default::default())) + .unwrap(); + let sym_a_c = tyenv + .intern(&Ident("A.C".to_string(), Default::default())) + .unwrap(); + let sym_u32 = tyenv + .intern(&Ident("u32".to_string(), Default::default())) + .unwrap(); + let sym_f1 = tyenv + .intern(&Ident("f1".to_string(), Default::default())) + .unwrap(); + let sym_f2 = tyenv + .intern(&Ident("f2".to_string(), Default::default())) + .unwrap(); assert_eq!(tyenv.type_map.get(&sym_u32).unwrap(), &TypeId(0)); assert_eq!(tyenv.type_map.get(&sym_a).unwrap(), &TypeId(1)); diff --git a/cranelift/isle/isle_examples/test4.isle b/cranelift/isle/isle_examples/test4.isle index 0e1f45901f..2035167239 100644 --- a/cranelift/isle/isle_examples/test4.isle +++ b/cranelift/isle/isle_examples/test4.isle @@ -7,6 +7,9 @@ (extern extractor Ext1 ext1) (extern extractor Ext2 ext2) +(extern const $A u32) +(extern const $B u32) + (decl C (bool) A) (extern constructor C c) @@ -32,3 +35,8 @@ (rule (Lower2 (Opcode.C)) (MachInst.F)) + +(decl F (Opcode) u32) +(rule + (F _) + $B) \ No newline at end of file