From 84b7612b98f53be1c0e9f42d99f07a21994107c3 Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Tue, 29 Jun 2021 17:00:43 -0700 Subject: [PATCH] Initial public commit of ISLE prototype DSL compiler. --- cranelift/isle/.gitignore | 3 + cranelift/isle/Cargo.lock | 199 ++++++ cranelift/isle/Cargo.toml | 11 + cranelift/isle/examples/test.isle | 12 + cranelift/isle/src/ast.rs | 135 ++++ cranelift/isle/src/compile.rs | 111 +++ cranelift/isle/src/error.rs | 48 ++ cranelift/isle/src/ir.rs | 1089 +++++++++++++++++++++++++++++ cranelift/isle/src/lexer.rs | 241 +++++++ cranelift/isle/src/lower.rs | 33 + cranelift/isle/src/main.rs | 28 + cranelift/isle/src/parser.rs | 429 ++++++++++++ cranelift/isle/src/sema.rs | 862 +++++++++++++++++++++++ 13 files changed, 3201 insertions(+) create mode 100644 cranelift/isle/.gitignore create mode 100644 cranelift/isle/Cargo.lock create mode 100644 cranelift/isle/Cargo.toml create mode 100644 cranelift/isle/examples/test.isle create mode 100644 cranelift/isle/src/ast.rs create mode 100644 cranelift/isle/src/compile.rs create mode 100644 cranelift/isle/src/error.rs create mode 100644 cranelift/isle/src/ir.rs create mode 100644 cranelift/isle/src/lexer.rs create mode 100644 cranelift/isle/src/lower.rs create mode 100644 cranelift/isle/src/main.rs create mode 100644 cranelift/isle/src/parser.rs create mode 100644 cranelift/isle/src/sema.rs diff --git a/cranelift/isle/.gitignore b/cranelift/isle/.gitignore new file mode 100644 index 0000000000..3110c83344 --- /dev/null +++ b/cranelift/isle/.gitignore @@ -0,0 +1,3 @@ +/target +*~ +.*.swp diff --git a/cranelift/isle/Cargo.lock b/cranelift/isle/Cargo.lock new file mode 100644 index 0000000000..b890d9e546 --- /dev/null +++ b/cranelift/isle/Cargo.lock @@ -0,0 +1,199 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "aho-corasick" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +dependencies = [ + "memchr", +] + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "isle" +version = "0.1.0" +dependencies = [ + "env_logger", + "log", + "thiserror", +] + +[[package]] +name = "libc" +version = "0.2.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12b8adadd720df158f4d70dfe7ccc6adb0472d7c55ca83445f6a5ab3e36f8fb6" + +[[package]] +name = "log" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "memchr" +version = "2.4.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" + +[[package]] +name = "proc-macro2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "quote" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "regex" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" + +[[package]] +name = "syn" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f58f7e8eaa0009c5fec437aabf511bd9933e4b2d7407bd05273c01a8906ea7" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "termcolor" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "thiserror" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-xid" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/cranelift/isle/Cargo.toml b/cranelift/isle/Cargo.toml new file mode 100644 index 0000000000..8774800064 --- /dev/null +++ b/cranelift/isle/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "isle" +version = "0.1.0" +authors = ["Chris Fallin "] +edition = "2018" +license = "Apache-2.0 WITH LLVM-exception" + +[dependencies] +log = "0.4" +env_logger = "0.8" +thiserror = "1.0" diff --git a/cranelift/isle/examples/test.isle b/cranelift/isle/examples/test.isle new file mode 100644 index 0000000000..1ea1c3ce98 --- /dev/null +++ b/cranelift/isle/examples/test.isle @@ -0,0 +1,12 @@ +(type u32 (primitive u32)) +(type A (enum (A1 (x u32)) (A2 (x u32)))) +(type B (enum (B1 (x u32)) (B2 (x u32)))) + +(decl Input (A) 
u32) +(extractor Input get_input) ;; fn get_input(ctx: &mut C, ret: u32) -> Option<(A,)> + +(decl Lower (A) B) + +(rule + (Lower (A.A1 sub @ (Input (A.A2 42)))) + (B.B2 sub)) diff --git a/cranelift/isle/src/ast.rs b/cranelift/isle/src/ast.rs new file mode 100644 index 0000000000..97b7facffa --- /dev/null +++ b/cranelift/isle/src/ast.rs @@ -0,0 +1,135 @@ +use crate::lexer::Pos; + +/// The parsed form of an ISLE file. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Defs { + pub defs: Vec, + pub filename: String, +} + +/// One toplevel form in an ISLE file. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum Def { + Type(Type), + Rule(Rule), + Decl(Decl), + Extern(Extern), +} + +/// An identifier -- a variable, term symbol, or type. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Ident(pub String); + +/// A declaration of a type. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Type { + pub name: Ident, + pub is_extern: bool, + pub ty: TypeValue, + pub pos: Pos, +} + +/// The actual type-value: a primitive or an enum with variants. +/// +/// TODO: add structs as well? +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum TypeValue { + Primitive(Ident), + Enum(Vec), +} + +/// One variant of an enum type. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Variant { + pub name: Ident, + pub fields: Vec, +} + +/// One field of an enum variant. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Field { + pub name: Ident, + pub ty: Ident, +} + +/// A declaration of a term with its argument and return types. +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Decl { + pub term: Ident, + pub arg_tys: Vec, + pub ret_ty: Ident, + pub pos: Pos, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Rule { + pub pattern: Pattern, + pub expr: Expr, + pub pos: Pos, + pub prio: Option, +} + +/// A pattern: the left-hand side of a rule. 
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum Pattern {
+    /// An operator that binds a variable to a subterm and matches the
+    /// subpattern.
+    BindPattern { var: Ident, subpat: Box<Pattern> },
+    /// A variable that has already been bound (`=x` syntax).
+    Var { var: Ident },
+    /// An operator that matches a constant integer value.
+    ConstInt { val: i64 },
+    /// An application of a type variant or term.
+    Term { sym: Ident, args: Vec<Pattern> },
+    /// An operator that matches anything.
+    Wildcard,
+}
+
+/// An expression: the right-hand side of a rule.
+///
+/// Note that this *almost* looks like a core Lisp or lambda calculus,
+/// except that there is no abstraction (lambda). This first-order
+/// limit is what makes it analyzable.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum Expr {
+    /// A term: `(sym args...)`.
+    Term { sym: Ident, args: Vec<Expr> },
+    /// A variable use.
+    Var { name: Ident },
+    /// A constant integer.
+    ConstInt { val: i64 },
+    /// The `(let ((var ty val)*) body)` form.
+    Let { defs: Vec<LetDef>, body: Box<Expr> },
+}
+
+/// One variable locally bound in a `(let ...)` expression.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct LetDef {
+    pub var: Ident,
+    pub ty: Ident,
+    pub val: Box<Expr>,
+}
+
+/// An external binding: an extractor or constructor function attached
+/// to a term.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum Extern {
+    /// An external extractor: `(extractor Term rustfunc)` form.
+    Extractor {
+        /// The term to which this external extractor is attached.
+        term: Ident,
+        /// The Rust function name.
+        func: Ident,
+        /// The position of this decl.
+        pos: Pos,
+    },
+    /// An external constructor: `(constructor Term rustfunc)` form.
+    Constructor {
+        /// The term to which this external constructor is attached.
+        term: Ident,
+        /// The Rust function name.
+        func: Ident,
+        /// The position of this decl.
+        pos: Pos,
+    },
+}
diff --git a/cranelift/isle/src/compile.rs b/cranelift/isle/src/compile.rs
new file mode 100644
index 0000000000..0aeea048b9
--- /dev/null
+++ b/cranelift/isle/src/compile.rs
@@ -0,0 +1,123 @@
+//! Compilation process, from AST to Sema to Sequences of Insts.
+
+use crate::error::*;
+use crate::{ast, ir, sema};
+use std::collections::HashMap;
+
+/// A Compiler manages the compilation pipeline from AST to Sequences.
+pub struct Compiler<'a> {
+    ast: &'a ast::Defs,
+    type_env: sema::TypeEnv,
+    term_env: sema::TermEnv,
+    seqs: Vec<ir::Sequence>,
+    // TODO: if this becomes a perf issue, then build a better data
+    // structure. For now we index on root term/variant.
+    //
+    // TODO: index at callsites (extractors/constructors) too. We'll
+    // need tree-summaries of arg and expected return value at each
+    // callsite.
+    term_db: HashMap<ir::TermOrVariant, TermData>,
+}
+
+/// Per-term index: the rules whose output (`producers`) or input
+/// (`consumers`) tree summary is rooted at the term, plus binding flags.
+#[derive(Clone, Debug, Default)]
+struct TermData {
+    producers: Vec<(ir::TreeSummary, sema::RuleId)>,
+    consumers: Vec<(ir::TreeSummary, sema::RuleId)>,
+    has_constructor: bool,
+    has_extractor: bool,
+}
+
+/// Result type returned by the compilation pipeline.
+pub type CompileResult<T> = Result<T, Error>;
+
+impl<'a> Compiler<'a> {
+    /// Build the type and term environments from `ast`. No sequences
+    /// are generated and the term database starts out empty.
+    pub fn new(ast: &'a ast::Defs) -> CompileResult<Compiler<'a>> {
+        let mut type_env = sema::TypeEnv::from_ast(ast)?;
+        let term_env = sema::TermEnv::from_ast(&mut type_env, ast)?;
+        Ok(Compiler {
+            ast,
+            type_env,
+            term_env,
+            seqs: vec![],
+            term_db: HashMap::new(),
+        })
+    }
+
+    /// Lower every rule in the term environment to a `Sequence`,
+    /// appending them to `seqs` in rule-ID order.
+    pub fn build_sequences(&mut self) -> CompileResult<()> {
+        for rid in 0..self.term_env.rules.len() {
+            let rid = sema::RuleId(rid);
+            let seq = ir::Sequence::from_rule(&self.type_env, &self.term_env, rid);
+            self.seqs.push(seq);
+        }
+        Ok(())
+    }
+
+    /// Populate `term_db` from the sequences built so far.
+    pub fn collect_tree_summaries(&mut self) -> CompileResult<()> {
+        // For each rule, compute summaries of its LHS and RHS, then
+        // index it in the appropriate TermData.
+        for (i, seq) in self.seqs.iter().enumerate() {
+            let rule_id = sema::RuleId(i);
+            let consumer_summary = seq.input_tree_summary();
+            let producer_summary = seq.output_tree_summary();
+            if let Some(consumer_root_term) = consumer_summary.root() {
+                let consumer_termdb = self
+                    .term_db
+                    .entry(consumer_root_term.clone())
+                    .or_insert_with(|| Default::default());
+                consumer_termdb.consumers.push((consumer_summary, rule_id));
+            }
+            if let Some(producer_root_term) = producer_summary.root() {
+                let producer_termdb = self
+                    .term_db
+                    .entry(producer_root_term.clone())
+                    .or_insert_with(|| Default::default());
+                // The output summary describes what the rule produces, so
+                // it is indexed under `producers`, not `consumers`.
+                producer_termdb.producers.push((producer_summary, rule_id));
+            }
+        }
+
+        // For each term, if a constructor and/or extractor is
+        // present, note that.
+        for term in &self.term_env.terms {
+            if let sema::TermKind::Regular {
+                extractor,
+                constructor,
+            } = term.kind
+            {
+                if extractor.is_none() && constructor.is_none() {
+                    continue;
+                }
+                let entry = self
+                    .term_db
+                    .entry(ir::TermOrVariant::Term(term.id))
+                    .or_insert_with(|| Default::default());
+                if extractor.is_some() {
+                    entry.has_extractor = true;
+                }
+                if constructor.is_some() {
+                    entry.has_constructor = true;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Not yet implemented; currently panics via `unimplemented!()`.
+    pub fn inline_internal_terms(&mut self) -> CompileResult<()> {
+        unimplemented!()
+    }
+
+    /// Consume the compiler and return the sequences built so far.
+    pub fn to_sequences(self) -> Vec<ir::Sequence> {
+        self.seqs
+    }
+}
diff --git a/cranelift/isle/src/error.rs b/cranelift/isle/src/error.rs
new file mode 100644
index 0000000000..2399fb1231
--- /dev/null
+++ b/cranelift/isle/src/error.rs
@@ -0,0 +1,48 @@
+//! Error types.
+
+use crate::lexer::Pos;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Parse error")]
+    ParseError(#[from] ParseError),
+    #[error("Semantic error")]
+    SemaError(#[from] SemaError),
+    #[error("IO error")]
+    IoError(#[from] std::io::Error),
+}
+
+#[derive(Clone, Debug, Error)]
+pub struct ParseError {
+    pub msg: String,
+    pub filename: String,
+    pub pos: Pos,
+}
+
+#[derive(Clone, Debug, Error)]
+pub struct SemaError {
+    pub msg: String,
+    pub filename: String,
+    pub pos: Pos,
+}
+
+impl std::fmt::Display for ParseError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "{}:{}:{}: {}",
+            self.filename, self.pos.line, self.pos.col, self.msg
+        )
+    }
+}
+
+impl std::fmt::Display for SemaError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "{}:{}:{}: {}",
+            self.filename, self.pos.line, self.pos.col, self.msg
+        )
+    }
+}
diff --git a/cranelift/isle/src/ir.rs b/cranelift/isle/src/ir.rs
new file mode 100644
index 0000000000..3bf6c59a4c
--- /dev/null
+++ b/cranelift/isle/src/ir.rs
@@ -0,0 +1,1089 @@
+//! Lowered matching IR.
+
+use crate::declare_id;
+use crate::sema::*;
+use std::collections::hash_map::Entry as HashEntry;
+use std::collections::HashMap;
+
+declare_id!(InstId);
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Value(InstId, usize);
+
+/// A single node in the sea-of-nodes. Each node produces zero or more values.
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub enum Inst {
+    /// Get the input root-term value.
+    Arg { ty: TypeId },
+
+    /// Set the return value. Produces no values.
+    Return { ty: TypeId, value: Value },
+
+    /// Match a value as equal to another value. Produces no values.
+    MatchEqual { a: Value, b: Value, ty: TypeId },
+
+    /// Try matching the given value as the given integer. Produces no values.
+    MatchInt {
+        input: Value,
+        ty: TypeId,
+        int_val: i64,
+    },
+
+    /// Try matching the given value as the given variant, producing
+    /// `|arg_tys|` values as output.
+    MatchVariant {
+        input: Value,
+        input_ty: TypeId,
+        arg_tys: Vec<TypeId>,
+        variant: VariantId,
+    },
+
+    /// Invoke an extractor, taking the given value as input and
+    /// producing `|arg_tys|` values as output.
+    Extract {
+        input: Value,
+        input_ty: TypeId,
+        arg_tys: Vec<TypeId>,
+        term: TermId,
+    },
+
+    /// Produce a constant integer.
+    ConstInt { ty: TypeId, val: i64 },
+
+    /// Create a variant.
+    CreateVariant {
+        inputs: Vec<(Value, TypeId)>,
+        ty: TypeId,
+        variant: VariantId,
+    },
+
+    /// Invoke a constructor.
+    Construct {
+        inputs: Vec<(Value, TypeId)>,
+        ty: TypeId,
+        term: TermId,
+    },
+
+    /// Copy a value. Used mainly when rewriting/inlining.
+    Copy { ty: TypeId, val: Value },
+
+    /// A non-operation (nop). Used to "nop out" unused instructions
+    /// without renumbering all values.
+    Nop,
+}
+
+impl Inst {
+    fn map_values<F: Fn(Value) -> Value>(&self, f: F) -> Self {
+        match self {
+            &Inst::Arg { ty } => Inst::Arg { ty },
+            &Inst::Return { ty, value } => Inst::Return {
+                ty,
+                value: f(value),
+            },
+            &Inst::MatchEqual { a, b, ty } => Inst::MatchEqual {
+                a: f(a),
+                b: f(b),
+                ty,
+            },
+            &Inst::MatchInt { input, ty, int_val } => Inst::MatchInt {
+                input: f(input),
+                ty,
+                int_val,
+            },
+            &Inst::MatchVariant {
+                input,
+                input_ty,
+                ref arg_tys,
+                variant,
+            } => Inst::MatchVariant {
+                input: f(input),
+                input_ty,
+                arg_tys: arg_tys.clone(),
+                variant,
+            },
+            &Inst::Extract {
+                input,
+                input_ty,
+                ref arg_tys,
+                term,
+            } => Inst::Extract {
+                input: f(input),
+                input_ty,
+                arg_tys: arg_tys.clone(),
+                term,
+            },
+            &Inst::ConstInt { ty, val } => Inst::ConstInt { ty, val },
+            &Inst::CreateVariant {
+                ref inputs,
+                ty,
+                variant,
+            } => Inst::CreateVariant {
+                inputs: inputs
+                    .iter()
+                    .map(|(i, ty)| (f(*i), *ty))
+                    .collect::<Vec<_>>(),
+                ty,
+                variant,
+            },
+            &Inst::Construct {
+                ref inputs,
+                ty,
+                term,
+            } => Inst::Construct {
+                inputs: inputs
+                    .iter()
+                    .map(|(i, ty)| (f(*i), *ty))
+                    .collect::<Vec<_>>(),
+                ty,
+                term,
+            },
+            &Inst::Copy { ty, val } => Inst::Copy { ty, val: f(val) },
+            &Inst::Nop => Inst::Nop,
+        }
+    }
+
+    fn map_insts<F: Fn(InstId) -> InstId>(&self, f: F) -> Self {
+        self.map_values(|val| Value(f(val.0), val.1))
+    }
+
+    fn num_results(&self) -> usize {
+        match self {
+            &Inst::Arg { .. }
+            | &Inst::ConstInt { .. }
+            | &Inst::Construct { .. }
+            | &Inst::CreateVariant { .. }
+            | &Inst::Copy { .. } => 1,
+            &Inst::Return { .. } | &Inst::MatchEqual { .. } | &Inst::MatchInt { .. } => 0,
+            &Inst::Extract { ref arg_tys, .. } | &Inst::MatchVariant { ref arg_tys, .. } => {
+                arg_tys.len()
+            }
+            &Inst::Nop => 0,
+        }
+    }
+}
+
+impl Value {
+    fn map_inst<F: Fn(InstId) -> InstId>(&self, f: F) -> Self {
+        Value(f(self.0), self.1)
+    }
+}
+
+/// A linear sequence of instructions that either convert an input
+/// value to an output value, or fail.
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
+pub struct Sequence {
+    /// Instruction sequence. InstId indexes into this sequence.
+    insts: Vec<Inst>,
+}
+
+impl Sequence {
+    fn add_inst(&mut self, inst: Inst) -> InstId {
+        let id = InstId(self.insts.len());
+        self.insts.push(inst);
+        id
+    }
+
+    fn add_arg(&mut self, ty: TypeId) -> Value {
+        let inst = InstId(self.insts.len());
+        self.add_inst(Inst::Arg { ty });
+        Value(inst, 0)
+    }
+
+    fn add_return(&mut self, ty: TypeId, value: Value) {
+        self.add_inst(Inst::Return { ty, value });
+    }
+
+    fn add_match_equal(&mut self, a: Value, b: Value, ty: TypeId) {
+        self.add_inst(Inst::MatchEqual { a, b, ty });
+    }
+
+    fn add_match_int(&mut self, input: Value, ty: TypeId, int_val: i64) {
+        self.add_inst(Inst::MatchInt { input, ty, int_val });
+    }
+
+    fn add_match_variant(
+        &mut self,
+        input: Value,
+        input_ty: TypeId,
+        arg_tys: &[TypeId],
+        variant: VariantId,
+    ) -> Vec<Value> {
+        let inst = InstId(self.insts.len());
+        let mut outs = vec![];
+        for (i, _arg_ty) in arg_tys.iter().enumerate() {
+            let val = Value(inst, i);
+            outs.push(val);
+        }
+        let arg_tys = arg_tys.iter().cloned().collect();
+        self.add_inst(Inst::MatchVariant {
+            input,
+            input_ty,
+            arg_tys,
+            variant,
+        });
+        outs
+    }
+
+    fn add_extract(
+        &mut self,
+        input: Value,
+        input_ty: TypeId,
+        arg_tys: &[TypeId],
+        term: TermId,
+    ) -> Vec<Value> {
+        let inst = InstId(self.insts.len());
+        let mut outs = vec![];
+        for (i, _arg_ty) in arg_tys.iter().enumerate() {
+            let val = Value(inst, i);
+            outs.push(val);
+        }
+        let arg_tys = arg_tys.iter().cloned().collect();
+        self.add_inst(Inst::Extract {
+            input,
+            input_ty,
+            arg_tys,
+            term,
+        });
+        outs
+    }
+
+    fn add_const_int(&mut self, ty: TypeId, val: i64) -> Value {
+        let inst = InstId(self.insts.len());
+        self.add_inst(Inst::ConstInt { ty, val });
+        Value(inst, 0)
+    }
+
+    fn add_create_variant(
+        &mut self,
+        inputs: &[(Value, TypeId)],
+        ty: TypeId,
+        variant: VariantId,
+    ) -> Value {
+        let inst = InstId(self.insts.len());
+        let inputs = inputs.iter().cloned().collect();
+        self.add_inst(Inst::CreateVariant {
+            inputs,
+            ty,
+            variant,
+        });
+        Value(inst, 0)
+    }
+
+    fn add_construct(&mut self, inputs: &[(Value, TypeId)], ty: TypeId, term: TermId) -> Value {
+        let inst = InstId(self.insts.len());
+        let inputs = inputs.iter().cloned().collect();
+        self.add_inst(Inst::Construct { inputs, ty, term });
+        Value(inst, 0)
+    }
+
+    fn gen_pattern(
+        &mut self,
+        input: Value,
+        typeenv: &TypeEnv,
+        termenv: &TermEnv,
+        pat: &Pattern,
+        vars: &mut HashMap<VarId, Value>,
+    ) {
+        match pat {
+            &Pattern::BindPattern(_ty, var, ref subpat) => {
+                // Bind the appropriate variable and recurse.
+                assert!(!vars.contains_key(&var));
+                vars.insert(var, input);
+                self.gen_pattern(input, typeenv, termenv, &*subpat, vars);
+            }
+            &Pattern::Var(ty, var) => {
+                // Assert that the value matches the existing bound var.
+                let var_val = vars
+                    .get(&var)
+                    .cloned()
+                    .expect("Variable should already be bound");
+                self.add_match_equal(input, var_val, ty);
+            }
+            &Pattern::ConstInt(ty, value) => {
+                // Assert that the value matches the constant integer.
+                self.add_match_int(input, ty, value);
+            }
+            &Pattern::Term(ty, term, ref args) => {
+                // Determine whether the term has an external extractor or not.
+                let termdata = &termenv.terms[term.index()];
+                let arg_tys = &termdata.arg_tys[..];
+                match &termdata.kind {
+                    &TermKind::EnumVariant { variant } => {
+                        let arg_values = self.add_match_variant(input, ty, arg_tys, variant);
+                        for (subpat, value) in args.iter().zip(arg_values.into_iter()) {
+                            self.gen_pattern(value, typeenv, termenv, subpat, vars);
+                        }
+                    }
+                    &TermKind::Regular { .. } => {
+                        let arg_values = self.add_extract(input, ty, arg_tys, term);
+                        for (subpat, value) in args.iter().zip(arg_values.into_iter()) {
+                            self.gen_pattern(value, typeenv, termenv, subpat, vars);
+                        }
+                    }
+                }
+            }
+            &Pattern::Wildcard(_ty) => {
+                // Nothing!
+            }
+        }
+    }
+
+    fn gen_expr(
+        &mut self,
+        typeenv: &TypeEnv,
+        termenv: &TermEnv,
+        expr: &Expr,
+        vars: &HashMap<VarId, Value>,
+    ) -> Value {
+        match expr {
+            &Expr::ConstInt(ty, val) => self.add_const_int(ty, val),
+            &Expr::Let(_ty, ref bindings, ref subexpr) => {
+                let mut vars = vars.clone();
+                for &(var, _var_ty, ref var_expr) in bindings {
+                    let var_value = self.gen_expr(typeenv, termenv, &*var_expr, &vars);
+                    vars.insert(var, var_value);
+                }
+                self.gen_expr(typeenv, termenv, &*subexpr, &vars)
+            }
+            &Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(),
+            &Expr::Term(ty, term, ref arg_exprs) => {
+                let termdata = &termenv.terms[term.index()];
+                let mut arg_values_tys = vec![];
+                for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) {
+                    arg_values_tys
+                        .push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars), arg_ty));
+                }
+                match &termdata.kind {
+                    &TermKind::EnumVariant { variant } => {
+                        self.add_create_variant(&arg_values_tys[..], ty, variant)
+                    }
+                    &TermKind::Regular { .. } => self.add_construct(&arg_values_tys[..], ty, term),
+                }
+            }
+        }
+    }
+}
+
+impl Sequence {
+    /// Build a sequence from a rule.
+    pub fn from_rule(tyenv: &TypeEnv, termenv: &TermEnv, rule: RuleId) -> Sequence {
+        let mut seq: Sequence = Default::default();
+
+        // Lower the pattern, starting from the root input value.
+        let ruledata = &termenv.rules[rule.index()];
+        let input_ty = ruledata.lhs.ty();
+        let input = seq.add_arg(input_ty);
+        let mut vars = HashMap::new();
+        seq.gen_pattern(input, tyenv, termenv, &ruledata.lhs, &mut vars);
+
+        // Lower the expression, making use of the bound variables
+        // from the pattern.
+        let rhs_root = seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars);
+        // Return the root RHS value.
+        let output_ty = ruledata.rhs.ty();
+        seq.add_return(output_ty, rhs_root);
+
+        seq
+    }
+
+    /// Inline sequence(s) in place of given instructions.
+    pub fn inline(&self, inlines: Vec<(InstId, &'_ Sequence)>) -> Sequence {
+        let mut seq: Sequence = Default::default();
+        // Map from inst ID in this seq to inst ID in final seq.
+        let mut inst_map: HashMap<InstId, InstId> = HashMap::new();
+
+        let mut next_inline = 0;
+        for (id, inst) in self.insts.iter().enumerate() {
+            let orig_inst_id = InstId(id);
+
+            // If this is an inlining point, do the inlining. The
+            // inlining point must be at a Construct or Extract call.
+            //
+            // - For a Construct inlining, we emit the Construct
+            //   *first*, taking its output value as the arg for the
+            //   invoked sequence. The returned value will in turn be
+            //   substituted for that value at the end of inlining.
+            //
+            // - For an Extract inlining, we emit the sequence first,
+            //   taking the input of the Extract as the arg for the
+            //   invoked sequence. The returned value will then be the
+            //   new input to the Extract.
+
+            if next_inline < inlines.len() && inlines[next_inline].0 == orig_inst_id {
+                let inlined_seq = &inlines[next_inline].1;
+                next_inline += 1;
+
+                let (arg, arg_ty) = match inst {
+                    &Inst::Construct { ty, .. } => {
+                        // Emit the Construct, mapping its input
+                        // values across the mapping, and saving its
+                        // output as the arg for the inlined sequence.
+                        let inst = inst.map_insts(|id| {
+                            inst_map
+                                .get(&id)
+                                .cloned()
+                                .expect("Should have inst mapping")
+                        });
+                        let new_inst_id = seq.add_inst(inst);
+                        (Value(new_inst_id, 0), ty)
+                    }
+                    &Inst::Extract {
+                        input, input_ty, ..
+                    } => {
+                        // Map the input and save it as the arg, but
+                        // don't emit the Extract yet.
+                        (
+                            input.map_inst(|id| {
+                                inst_map
+                                    .get(&id)
+                                    .cloned()
+                                    .expect("Should have inst mapping")
+                            }),
+                            input_ty,
+                        )
+                    }
+                    _ => panic!("Unexpected instruction {:?} at inlining point", inst),
+                };
+
+                // Copy the inlined insts into the output sequence. We
+                // map `Arg` to the input, and save the `Return`, which
+                // must come last.
+                let mut inlined_inst_map: HashMap<InstId, InstId> = HashMap::new();
+                let mut ret: Option<(InstId, TypeId)> = None;
+                for (i, inst) in inlined_seq.insts.iter().enumerate() {
+                    let inlined_orig_inst_id = InstId(i);
+                    let new_inst_id = InstId(seq.insts.len());
+                    let inst = match inst {
+                        &Inst::Return { ty, value } => {
+                            let value =
+                                value.map_inst(|id| inlined_inst_map.get(&id).cloned().unwrap());
+                            ret = Some((new_inst_id, ty));
+                            Inst::Copy { ty, val: value }
+                        }
+                        &Inst::Arg { ty } => {
+                            assert_eq!(ty, arg_ty);
+                            Inst::Copy { ty, val: arg }
+                        }
+                        _ => inst.map_insts(|id| inlined_inst_map.get(&id).cloned().unwrap()),
+                    };
+                    let new_id = seq.add_inst(inst);
+                    inlined_inst_map.insert(inlined_orig_inst_id, new_id);
+                }
+
+                // Now, emit the Extract if appropriate (it comes
+                // after the inlined sequence, while Construct goes
+                // before), and map the old inst ID to the resulting
+                // output of either the Extract or the return above.
+                let final_inst_id = match inst {
+                    &Inst::Extract {
+                        input_ty,
+                        ref arg_tys,
+                        term,
+                        ..
+                    } => {
+                        let input = Value(ret.unwrap().0, 0);
+                        seq.add_inst(Inst::Extract {
+                            input,
+                            input_ty,
+                            arg_tys: arg_tys.clone(),
+                            term,
+                        })
+                    }
+                    &Inst::Construct { .. } => ret.unwrap().0,
+                    _ => unreachable!(),
+                };
+
+                inst_map.insert(orig_inst_id, final_inst_id);
+            } else {
+                // Non-inlining-point instruction. Just copy over,
+                // mapping values as appropriate.
+                let inst = inst.map_insts(|id| {
+                    inst_map
+                        .get(&id)
+                        .cloned()
+                        .expect("inst ID should be present")
+                });
+                let new_id = seq.add_inst(inst);
+                inst_map.insert(orig_inst_id, new_id);
+            }
+        }
+
+        seq
+    }
+
+    /// Perform constant-propagation / simplification across
+    /// construct/extract pairs, variants and integer values, and
+    /// copies.
+    pub fn simplify(&self) -> Option<Sequence> {
+        #[derive(Clone, Debug)]
+        enum SymbolicValue {
+            Value(Value),
+            ConstInt(Value, i64),
+            Variant(Value, VariantId, Vec<SymbolicValue>),
+            Term(Value, TermId, Vec<SymbolicValue>),
+        }
+        impl SymbolicValue {
+            fn to_value(&self) -> Value {
+                match self {
+                    &SymbolicValue::Value(v) => v,
+                    &SymbolicValue::ConstInt(v, ..) => v,
+                    &SymbolicValue::Variant(v, ..) => v,
+                    &SymbolicValue::Term(v, ..) => v,
+                }
+            }
+        }
+        let mut value_map: HashMap<Value, SymbolicValue> = HashMap::new();
+        let mut seq: Sequence = Default::default();
+
+        for (i, inst) in self.insts.iter().enumerate() {
+            let orig_inst_id = InstId(i);
+            match inst {
+                &Inst::Arg { .. } => {
+                    let new_inst = seq.add_inst(inst.clone());
+                    value_map.insert(
+                        Value(orig_inst_id, 0),
+                        SymbolicValue::Value(Value(new_inst, 0)),
+                    );
+                }
+                &Inst::Return { ty, value } => {
+                    let inst = Inst::Return {
+                        ty,
+                        value: value_map.get(&value).unwrap().to_value(),
+                    };
+                    seq.add_inst(inst);
+                }
+                &Inst::MatchEqual { a, b, ty } => {
+                    let sym_a = value_map.get(&a).unwrap();
+                    let sym_b = value_map.get(&b).unwrap();
+                    match (sym_a, sym_b) {
+                        (
+                            &SymbolicValue::ConstInt(_, int_a),
+                            &SymbolicValue::ConstInt(_, int_b),
+                        ) => {
+                            if int_a == int_b {
+                                // No-op -- we can skip it.
+                                continue;
+                            } else {
+                                // We can't possibly match!
+                                return None;
+                            }
+                        }
+                        (
+                            &SymbolicValue::Term(_, term_a, _),
+                            &SymbolicValue::Term(_, term_b, _),
+                        ) => {
+                            if term_a != term_b {
+                                return None;
+                            }
+                        }
+                        (
+                            &SymbolicValue::Variant(_, var_a, _),
+                            &SymbolicValue::Variant(_, var_b, _),
+                        ) => {
+                            if var_a != var_b {
+                                return None;
+                            }
+                        }
+                        _ => {}
+                    }
+                    let val_a = sym_a.to_value();
+                    let val_b = sym_b.to_value();
+                    seq.add_inst(Inst::MatchEqual {
+                        a: val_a,
+                        b: val_b,
+                        ty,
+                    });
+                }
+                &Inst::MatchInt { input, int_val, ty } => {
+                    let sym_input = value_map.get(&input).unwrap();
+                    match sym_input {
+                        &SymbolicValue::ConstInt(_, const_val) => {
+                            if int_val == const_val {
+                                // No runtime check needed -- we can continue.
+                                continue;
+                            } else {
+                                // Static mismatch, so we can remove this
+                                // whole Sequence.
+                                return None;
+                            }
+                        }
+                        _ => {}
+                    }
+                    let val_input = sym_input.to_value();
+                    seq.add_inst(Inst::MatchInt {
+                        input: val_input,
+                        int_val,
+                        ty,
+                    });
+                }
+                &Inst::MatchVariant {
+                    input,
+                    input_ty,
+                    variant,
+                    ref arg_tys,
+                } => {
+                    let sym_input = value_map.get(&input).unwrap();
+                    match sym_input {
+                        &SymbolicValue::Variant(_, val_variant, ref args) => {
+                            if val_variant != variant {
+                                return None;
+                            }
+                            // Variant matches: unpack args' symbolic values into results.
+                            let args = args.clone();
+                            for (i, arg) in args.iter().enumerate() {
+                                let val = Value(orig_inst_id, i);
+                                value_map.insert(val, arg.clone());
+                            }
+                        }
+                        _ => {
+                            let val_input = sym_input.to_value();
+                            let new_inst = seq.add_inst(Inst::MatchVariant {
+                                input: val_input,
+                                input_ty,
+                                variant,
+                                arg_tys: arg_tys.clone(),
+                            });
+                            for i in 0..arg_tys.len() {
+                                let val = Value(orig_inst_id, i);
+                                let sym = SymbolicValue::Value(Value(new_inst, i));
+                                value_map.insert(val, sym);
+                            }
+                        }
+                    }
+                }
+                &Inst::Extract {
+                    input,
+                    input_ty,
+                    term,
+                    ref arg_tys,
+                } => {
+                    let sym_input = value_map.get(&input).unwrap();
+                    match sym_input {
+                        &SymbolicValue::Term(_, val_term, ref args) => {
+                            if val_term != term {
+                                return None;
+                            }
+                            // Term matches: unpack args' symbolic values into results.
+                            let args = args.clone();
+                            for (i, arg) in args.iter().enumerate() {
+                                let val = Value(orig_inst_id, i);
+                                value_map.insert(val, arg.clone());
+                            }
+                        }
+                        _ => {
+                            let val_input = sym_input.to_value();
+                            let new_inst = seq.add_inst(Inst::Extract {
+                                input: val_input,
+                                input_ty,
+                                term,
+                                arg_tys: arg_tys.clone(),
+                            });
+                            for i in 0..arg_tys.len() {
+                                let val = Value(orig_inst_id, i);
+                                let sym = SymbolicValue::Value(Value(new_inst, i));
+                                value_map.insert(val, sym);
+                            }
+                        }
+                    }
+                }
+                &Inst::ConstInt { ty, val } => {
+                    let new_inst = seq.add_inst(Inst::ConstInt { ty, val });
+                    value_map.insert(
+                        Value(orig_inst_id, 0),
+                        SymbolicValue::ConstInt(Value(new_inst, 0), val),
+                    );
+                }
+                &Inst::CreateVariant {
+                    ref inputs,
+                    variant,
+                    ty,
+                } => {
+                    let sym_inputs = inputs
+                        .iter()
+                        .map(|input| value_map.get(&input.0).cloned().unwrap())
+                        .collect::<Vec<_>>();
+                    let inputs = sym_inputs
+                        .iter()
+                        .zip(inputs.iter())
+                        .map(|(si, (_, ty))| (si.to_value(), *ty))
+                        .collect::<Vec<_>>();
+                    let new_inst = seq.add_inst(Inst::CreateVariant {
+                        inputs,
+                        variant,
+                        ty,
+                    });
+                    value_map.insert(
+                        Value(orig_inst_id, 0),
+                        SymbolicValue::Variant(Value(new_inst, 0), variant, sym_inputs),
+                    );
+                }
+                &Inst::Construct {
+                    ref inputs,
+                    term,
+                    ty,
+                } => {
+                    let sym_inputs = inputs
+                        .iter()
+                        .map(|input| value_map.get(&input.0).cloned().unwrap())
+                        .collect::<Vec<_>>();
+                    let inputs = sym_inputs
+                        .iter()
+                        .zip(inputs.iter())
+                        .map(|(si, (_, ty))| (si.to_value(), *ty))
+                        .collect::<Vec<_>>();
+                    let new_inst = seq.add_inst(Inst::Construct { inputs, term, ty });
+                    value_map.insert(
+                        Value(orig_inst_id, 0),
+                        SymbolicValue::Term(Value(new_inst, 0), term, sym_inputs),
+                    );
+                }
+                &Inst::Copy { val, .. } => {
+                    let sym_value = value_map.get(&val).cloned().unwrap();
+                    value_map.insert(Value(orig_inst_id, 0), sym_value);
+                }
+                &Inst::Nop => {}
+            };
+        }
+
+        // Now do a pass backward to track which instructions are used.
+        let mut used = vec![false; seq.insts.len()];
+        for (id, inst) in seq.insts.iter().enumerate().rev() {
+            // Roots (always used): Return and every fallible match, incl. Extract.
+            match inst {
+                &Inst::Return { .. }
+                | &Inst::MatchEqual { .. }
+                | &Inst::MatchInt { .. }
+                | &Inst::MatchVariant { .. }
+                | &Inst::Extract { .. } => used[id] = true,
+                _ => {}
+            }
+            // If this instruction is not used, continue.
+            if !used[id] {
+                continue;
+            }
+            // Otherwise, mark all inputs as used as well.
+            match inst {
+                &Inst::Return { value, .. } => used[value.0.index()] = true,
+                &Inst::MatchEqual { a, b, .. } => {
+                    used[a.0.index()] = true;
+                    used[b.0.index()] = true;
+                }
+                &Inst::MatchInt { input, .. }
+                | &Inst::MatchVariant { input, .. }
+                | &Inst::Extract { input, .. } => {
+                    used[input.0.index()] = true;
+                }
+                &Inst::CreateVariant { ref inputs, .. } | &Inst::Construct { ref inputs, .. } => {
+                    for input in inputs {
+                        used[input.0 .0.index()] = true;
+                    }
+                }
+                &Inst::Copy { val, .. } => {
+                    used[val.0.index()] = true;
+                }
+                &Inst::Arg { .. } | &Inst::ConstInt { .. } => {}
+                &Inst::Nop => {}
+            }
+        }
+
+        // Now, remove any non-used instructions.
+        for id in 0..seq.insts.len() {
+            if !used[id] {
+                seq.insts[id] = Inst::Nop;
+            }
+        }
+
+        Some(seq)
+    }
+
+    /// Build a tree summary of the output produced by a sequence.
+    pub fn output_tree_summary(&self) -> TreeSummary {
+        // Scan forward, building a TreeSummary for what is known
+        // about each value (a "lower bound" on its shape).
+        let mut value_summaries: HashMap<Value, TreeSummary> = HashMap::new();
+        for (id, inst) in self.insts.iter().enumerate() {
+            let inst_id = InstId(id);
+            match inst {
+                &Inst::Arg { .. } => {
+                    value_summaries.insert(Value(inst_id, 0), TreeSummary::Other);
+                }
+                &Inst::Return { value, .. } => {
+                    return value_summaries
+                        .get(&value)
+                        .cloned()
+                        .unwrap_or(TreeSummary::Other);
+                }
+                &Inst::MatchEqual { .. }
+                | &Inst::MatchInt { .. }
+                | &Inst::MatchVariant { .. }
+                | &Inst::Extract { .. } => {}
+                &Inst::ConstInt { val, ..
} => { + value_summaries.insert(Value(inst_id, 0), TreeSummary::ConstInt(val)); + } + &Inst::CreateVariant { + ref inputs, + variant, + .. + } => { + let args = inputs + .iter() + .map(|(val, _)| { + value_summaries + .get(&val) + .cloned() + .unwrap_or(TreeSummary::Other) + }) + .collect::>(); + value_summaries.insert(Value(inst_id, 0), TreeSummary::Variant(variant, args)); + } + &Inst::Construct { + ref inputs, term, .. + } => { + let args = inputs + .iter() + .map(|(val, _)| { + value_summaries + .get(&val) + .cloned() + .unwrap_or(TreeSummary::Other) + }) + .collect::>(); + value_summaries.insert(Value(inst_id, 0), TreeSummary::Term(term, args)); + } + &Inst::Copy { val, .. } => { + // Copy summary from input to output. + let input_value = value_summaries + .get(&val) + .cloned() + .unwrap_or(TreeSummary::Other); + value_summaries.insert(Value(inst_id, 0), input_value); + } + &Inst::Nop => {} + } + } + + panic!("Sequence did not end in Return") + } + + /// Build a tree summary of the input expected by a sequence. + pub fn input_tree_summary(&self) -> TreeSummary { + // Scan backward, building a TreeSummary for each value (a + // "lower bound" on what it must be to satisfy the sequence's + // conditions). + let mut value_summaries: HashMap = HashMap::new(); + for (id, inst) in self.insts.iter().enumerate().rev() { + let inst_id = InstId(id); + match inst { + &Inst::Arg { .. } => { + // Must *start* with Arg; otherwise we might have missed some condition. + assert_eq!(id, 0); + return value_summaries + .get(&Value(inst_id, 0)) + .cloned() + .unwrap_or(TreeSummary::Other); + } + &Inst::Return { .. } => {} + + &Inst::MatchEqual { a, b, .. 
} => { + if value_summaries.contains_key(&a) && !value_summaries.contains_key(&b) { + let val = value_summaries.get(&a).cloned().unwrap(); + value_summaries.insert(b, val); + } else if value_summaries.contains_key(&b) && !value_summaries.contains_key(&a) + { + let val = value_summaries.get(&b).cloned().unwrap(); + value_summaries.insert(a, val); + } else if value_summaries.contains_key(&a) && value_summaries.contains_key(&b) { + let val_a = value_summaries.get(&a).cloned().unwrap(); + let val_b = value_summaries.get(&b).cloned().unwrap(); + let combined = TreeSummary::Conjunction(vec![val_a, val_b]); + value_summaries.insert(a, combined.clone()); + value_summaries.insert(b, combined); + } + } + &Inst::MatchInt { input, int_val, .. } => { + value_summaries.insert(input, TreeSummary::ConstInt(int_val)); + } + &Inst::MatchVariant { + input, + variant, + ref arg_tys, + .. + } => { + let args = (0..arg_tys.len()) + .map(|i| Value(inst_id, i)) + .map(|val| { + value_summaries + .get(&val) + .cloned() + .unwrap_or(TreeSummary::Other) + }) + .collect::>(); + let summary = TreeSummary::Variant(variant, args); + match value_summaries.entry(input) { + HashEntry::Vacant(v) => { + v.insert(summary); + } + HashEntry::Occupied(mut o) => { + let combined = TreeSummary::Conjunction(vec![ + summary, + std::mem::replace(o.get_mut(), TreeSummary::Other), + ]); + *o.get_mut() = combined; + } + } + } + + &Inst::Extract { + input, + term, + ref arg_tys, + .. + } => { + let args = (0..arg_tys.len()) + .map(|i| Value(inst_id, i)) + .map(|val| { + value_summaries + .get(&val) + .cloned() + .unwrap_or(TreeSummary::Other) + }) + .collect::>(); + let summary = TreeSummary::Term(term, args); + match value_summaries.entry(input) { + HashEntry::Vacant(v) => { + v.insert(summary); + } + HashEntry::Occupied(mut o) => { + let combined = TreeSummary::Conjunction(vec![ + summary, + std::mem::replace(o.get_mut(), TreeSummary::Other), + ]); + *o.get_mut() = combined; + } + } + } + + &Inst::ConstInt { .. 
} | &Inst::CreateVariant { .. } | &Inst::Construct { .. } => {} + + &Inst::Copy { val, .. } => { + // Copy summary from output to input. + let output_value = value_summaries + .get(&Value(inst_id, 0)) + .cloned() + .unwrap_or(TreeSummary::Other); + value_summaries.insert(val, output_value); + } + + &Inst::Nop => {} + } + } + + panic!("Sequence did not start with Arg") + } +} + +/// A "summary" of a tree shape -- a template that describes a tree of +/// terms and constant integer values. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TreeSummary { + /// A known term, with given subtrees. + Term(TermId, Vec), + /// A known enum variant, with given subtrees. + Variant(VariantId, Vec), + /// A known constant integer value. + ConstInt(i64), + /// All of a list of summaries: represents a combined list of + /// requirements. The "provides" relation is satisfied if the + /// provider provides *all* of the providee's summaries in the + /// conjunction. A conjunction on the provider side (i.e., as an + /// "output summary") is illegal. + Conjunction(Vec), + /// Something else. + Other, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum TreeSummaryOverlap { + Never, + Sometimes, + Always, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum TermOrVariant { + Term(TermId), + Variant(VariantId), +} + +impl TreeSummary { + /// Does a term tree matching this summary "provide" the shape + /// described/expected by another summary? Answer can be "always", + /// "possibly", or "no". 
+ pub fn provides(&self, other: &TreeSummary) -> TreeSummaryOverlap { + match (self, other) { + (_, &TreeSummary::Other) => TreeSummaryOverlap::Always, + (&TreeSummary::Other, _) => TreeSummaryOverlap::Sometimes, + + (&TreeSummary::Conjunction(..), _) => { + panic!("Conjunction on LHS of `provides` relation") + } + (this, &TreeSummary::Conjunction(ref args)) => args + .iter() + .map(|arg| this.provides(arg)) + .min() + .unwrap_or(TreeSummaryOverlap::Always), + + ( + &TreeSummary::Term(self_term, ref self_args), + &TreeSummary::Term(other_term, ref other_args), + ) => { + if self_term != other_term { + TreeSummaryOverlap::Never + } else { + assert_eq!(self_args.len(), other_args.len()); + self_args + .iter() + .zip(other_args.iter()) + .map(|(self_arg, other_arg)| self_arg.provides(other_arg)) + .min() + .unwrap_or(TreeSummaryOverlap::Always) + } + } + + ( + &TreeSummary::Variant(self_var, ref self_args), + &TreeSummary::Variant(other_var, ref other_args), + ) => { + if self_var != other_var { + TreeSummaryOverlap::Never + } else { + assert_eq!(self_args.len(), other_args.len()); + self_args + .iter() + .zip(other_args.iter()) + .map(|(self_arg, other_arg)| self_arg.provides(other_arg)) + .min() + .unwrap_or(TreeSummaryOverlap::Always) + } + } + + (&TreeSummary::ConstInt(i1), &TreeSummary::ConstInt(i2)) => { + if i1 != i2 { + TreeSummaryOverlap::Never + } else { + TreeSummaryOverlap::Always + } + } + + _ => TreeSummaryOverlap::Never, + } + } + + pub fn root(&self) -> Option { + match self { + &TreeSummary::Term(term, ..) => Some(TermOrVariant::Term(term)), + &TreeSummary::Variant(variant, ..) => Some(TermOrVariant::Variant(variant)), + _ => None, + } + } +} diff --git a/cranelift/isle/src/lexer.rs b/cranelift/isle/src/lexer.rs new file mode 100644 index 0000000000..eafb72a462 --- /dev/null +++ b/cranelift/isle/src/lexer.rs @@ -0,0 +1,241 @@ +//! Lexer for the ISLE language. 
/// Tokenizer for ISLE source text with one token of lookahead.
///
/// Iterating a `Lexer` yields `(Pos, Token)` pairs, where the position is
/// that of the token's first byte.
#[derive(Clone, Debug)]
pub struct Lexer<'a> {
    buf: &'a [u8],
    pos: Pos,
    lookahead: Option<(Pos, Token<'a>)>,
}

/// A source position: byte `offset`, 1-based `line`, and 0-based `col`
/// counted in bytes since the start of the line.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Pos {
    pub offset: usize,
    pub line: usize,
    pub col: usize,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Token<'a> {
    LParen,
    RParen,
    Symbol(&'a str),
    Int(i64),
}

impl<'a> Lexer<'a> {
    /// Creates a lexer over `s` and primes the lookahead token.
    pub fn new(s: &'a str) -> Lexer<'a> {
        let mut l = Lexer {
            buf: s.as_bytes(),
            pos: Pos {
                offset: 0,
                line: 1,
                col: 0,
            },
            lookahead: None,
        };
        l.reload();
        l
    }

    /// Byte offset of the scan cursor (just past the lookahead token).
    pub fn offset(&self) -> usize {
        self.pos.offset
    }

    /// Position of the scan cursor (just past the lookahead token).
    pub fn pos(&self) -> Pos {
        self.pos
    }

    /// Consumes the byte under the cursor and keeps `line`/`col` in sync.
    /// This is the only place that moves the cursor, so every byte --
    /// including those inside comments -- is accounted for exactly once.
    fn advance(&mut self) {
        if self.buf[self.pos.offset] == b'\n' {
            self.pos.line += 1;
            self.pos.col = 0;
        } else {
            self.pos.col += 1;
        }
        self.pos.offset += 1;
    }

    /// Scans decimal digits at the cursor into an `i64`. The value is
    /// accumulated with its final sign so that `i64::MIN` is representable.
    ///
    /// Panics if the literal does not fit in an `i64`.
    fn lex_digits(&mut self, negative: bool, start: Pos) -> i64 {
        let mut num: i64 = 0;
        while let Some(&b) = self.buf.get(self.pos.offset) {
            if !b.is_ascii_digit() {
                break;
            }
            let digit = i64::from(b - b'0');
            num = num
                .checked_mul(10)
                .and_then(|n| {
                    if negative {
                        n.checked_sub(digit)
                    } else {
                        n.checked_add(digit)
                    }
                })
                .unwrap_or_else(|| {
                    panic!("Integer literal out of range at offset {}", start.offset)
                });
            self.advance();
        }
        num
    }

    /// Scans the next token, or returns `None` at end of input.
    fn next_token(&mut self) -> Option<(Pos, Token<'a>)> {
        fn is_sym_other_char(c: u8) -> bool {
            !matches!(c, b'(' | b')' | b';') && !c.is_ascii_whitespace()
        }

        // Skip any whitespace and any comments.
        while let Some(&c) = self.buf.get(self.pos.offset) {
            if c == b';' {
                // A comment runs up to, but not including, the newline; the
                // newline itself is then consumed as ordinary whitespace so
                // the line is counted exactly once (and not at all if the
                // comment ends at EOF).
                while self.buf.get(self.pos.offset).map_or(false, |&b| b != b'\n') {
                    self.advance();
                }
            } else if c.is_ascii_whitespace() {
                self.advance();
            } else {
                break;
            }
        }

        let buf = self.buf;
        let first = *buf.get(self.pos.offset)?;
        let start_pos = self.pos;
        let digit_follows = buf
            .get(self.pos.offset + 1)
            .map_or(false, u8::is_ascii_digit);

        let tok = match first {
            b'(' => {
                self.advance();
                Token::LParen
            }
            b')' => {
                self.advance();
                Token::RParen
            }
            b'0'..=b'9' => Token::Int(self.lex_digits(false, start_pos)),
            b'-' if digit_follows => {
                self.advance();
                Token::Int(self.lex_digits(true, start_pos))
            }
            // Anything else starts a symbol. That includes a `-` that does
            // not introduce a number, which would otherwise silently become
            // the integer 0.
            _ => {
                let begin = self.pos.offset;
                self.advance();
                while buf
                    .get(self.pos.offset)
                    .map_or(false, |&b| is_sym_other_char(b))
                {
                    self.advance();
                }
                // The input was a `&str` and symbols are delimited by ASCII
                // bytes, so the slice falls on character boundaries.
                let s = std::str::from_utf8(&buf[begin..self.pos.offset])
                    .expect("Only ASCII characters, should be UTF-8");
                Token::Symbol(s)
            }
        };
        Some((start_pos, tok))
    }

    /// Refills the lookahead slot if it is empty and input remains.
    fn reload(&mut self) {
        if self.lookahead.is_none() && self.pos.offset < self.buf.len() {
            self.lookahead = self.next_token();
        }
    }

    /// Returns the next token without consuming it.
    pub fn peek(&self) -> Option<(Pos, Token<'a>)> {
        self.lookahead
    }

    /// True once every token has been consumed.
    pub fn eof(&self) -> bool {
        self.lookahead.is_none()
    }
}

impl<'a> std::iter::Iterator for Lexer<'a> {
    type Item = (Pos, Token<'a>);

    fn next(&mut self) -> Option<(Pos, Token<'a>)> {
        let tok = self.lookahead.take();
        self.reload();
        tok
    }
}

impl<'a> Token<'a> {
    pub fn is_int(&self) -> bool {
        matches!(self, Token::Int(_))
    }

    pub fn is_sym(&self) -> bool {
        matches!(self, Token::Symbol(_))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn lexer_basic() {
        assert_eq!(
            Lexer::new(";; comment\n; another\r\n \t(one two three 23 -568 )\n")
                .map(|(_, tok)| tok)
                .collect::<Vec<_>>(),
            vec![
                Token::LParen,
                Token::Symbol("one"),
                Token::Symbol("two"),
                Token::Symbol("three"),
                Token::Int(23),
                Token::Int(-568),
                Token::RParen
            ]
        );
    }

    #[test]
    fn ends_with_sym() {
        assert_eq!(
            Lexer::new("asdf").map(|(_, tok)| tok).collect::<Vec<_>>(),
            vec![Token::Symbol("asdf"),]
        );
    }

    #[test]
    fn ends_with_num() {
        assert_eq!(
            Lexer::new("23").map(|(_, tok)| tok).collect::<Vec<_>>(),
            vec![Token::Int(23)],
        );
    }

    #[test]
    fn weird_syms() {
        assert_eq!(
            Lexer::new("(+ [] => !! _test!;comment\n)")
                .map(|(_, tok)| tok)
                .collect::<Vec<_>>(),
            vec![
                Token::LParen,
                Token::Symbol("+"),
                Token::Symbol("[]"),
                Token::Symbol("=>"),
                Token::Symbol("!!"),
                Token::Symbol("_test!"),
                Token::RParen,
            ]
        );
    }
}
_test!;comment\n)") + .map(|(_, tok)| tok) + .collect::>(), + vec![ + Token::LParen, + Token::Symbol("+"), + Token::Symbol("[]"), + Token::Symbol("=>"), + Token::Symbol("!!"), + Token::Symbol("_test!"), + Token::RParen, + ] + ); + } +} diff --git a/cranelift/isle/src/lower.rs b/cranelift/isle/src/lower.rs new file mode 100644 index 0000000000..52736d4f34 --- /dev/null +++ b/cranelift/isle/src/lower.rs @@ -0,0 +1,33 @@ +use crate::ir::*; +use crate::sema; + +struct LowerState<'a> { + tyenv: &'a sema::TypeEnv, + func: &'a sema::Func, + builder: FuncBuilder, + control_flow: ControlInput, +} + +pub fn lower(tyenv: &sema::TypeEnv, func: &sema::Func) -> Func { + let mut builder = FuncBuilder::default(); + let entry = builder.intern(Node::Entry); + + let mut state = LowerState { + tyenv, + func, + builder, + control_flow: ControlInput(entry, 0), + }; + + if !func.is_extern && !func.is_inline { + for case in &func.cases { + state.lower_case(case); + } + } + + state.builder.build() +} + +impl<'a> LowerState<'a> { + fn lower_case(&mut self) {} +} diff --git a/cranelift/isle/src/main.rs b/cranelift/isle/src/main.rs new file mode 100644 index 0000000000..b9364551e8 --- /dev/null +++ b/cranelift/isle/src/main.rs @@ -0,0 +1,28 @@ +#![allow(dead_code)] + +use std::io::stdin; +use std::io::Read; + +mod ast; +mod compile; +mod error; +mod ir; +mod lexer; +mod parser; +mod sema; + +fn main() -> Result<(), error::Error> { + let _ = env_logger::try_init(); + let mut input = String::new(); + stdin().read_to_string(&mut input)?; + let mut parser = parser::Parser::new("", &input[..]); + let defs = parser.parse_defs()?; + let mut compiler = compile::Compiler::new(&defs)?; + compiler.build_sequences()?; + compiler.collect_tree_summaries()?; + + for seq in compiler.to_sequences() { + println!("---\nsequence\n---\n{:?}\n", seq); + } + Ok(()) +} diff --git a/cranelift/isle/src/parser.rs b/cranelift/isle/src/parser.rs new file mode 100644 index 0000000000..a7759cc306 --- /dev/null +++ 
b/cranelift/isle/src/parser.rs @@ -0,0 +1,429 @@ +//! Parser for ISLE language. + +use crate::ast::*; +use crate::error::*; +use crate::lexer::{Lexer, Pos, Token}; + +#[derive(Clone, Debug)] +pub struct Parser<'a> { + filename: &'a str, + lexer: Lexer<'a>, +} + +pub type ParseResult = std::result::Result; + +impl<'a> Parser<'a> { + pub fn new(filename: &'a str, s: &'a str) -> Parser<'a> { + Parser { + filename, + lexer: Lexer::new(s), + } + } + + pub fn error(&self, pos: Pos, msg: String) -> ParseError { + ParseError { + filename: self.filename.to_string(), + pos, + msg, + } + } + + fn take bool>(&mut self, f: F) -> ParseResult> { + if let Some((pos, peek)) = self.lexer.peek() { + if !f(peek) { + return Err(self.error(pos, format!("Unexpected token {:?}", peek))); + } + self.lexer.next(); + Ok(peek) + } else { + Err(self.error(self.lexer.pos(), "Unexpected EOF".to_string())) + } + } + + fn is bool>(&self, f: F) -> bool { + if let Some((_, peek)) = self.lexer.peek() { + f(peek) + } else { + false + } + } + + fn pos(&self) -> Option { + self.lexer.peek().map(|(pos, _)| pos) + } + + fn is_lparen(&self) -> bool { + self.is(|tok| tok == Token::LParen) + } + fn is_rparen(&self) -> bool { + self.is(|tok| tok == Token::RParen) + } + fn is_sym(&self) -> bool { + self.is(|tok| tok.is_sym()) + } + fn is_int(&self) -> bool { + self.is(|tok| tok.is_int()) + } + fn is_sym_str(&self, s: &str) -> bool { + self.is(|tok| tok == Token::Symbol(s)) + } + + fn lparen(&mut self) -> ParseResult<()> { + self.take(|tok| tok == Token::LParen).map(|_| ()) + } + fn rparen(&mut self) -> ParseResult<()> { + self.take(|tok| tok == Token::RParen).map(|_| ()) + } + + fn symbol(&mut self) -> ParseResult<&'a str> { + match self.take(|tok| tok.is_sym())? { + Token::Symbol(s) => Ok(s), + _ => unreachable!(), + } + } + + fn int(&mut self) -> ParseResult { + match self.take(|tok| tok.is_int())? 
{ + Token::Int(i) => Ok(i), + _ => unreachable!(), + } + } + + pub fn parse_defs(&mut self) -> ParseResult { + let mut defs = vec![]; + while !self.lexer.eof() { + defs.push(self.parse_def()?); + } + Ok(Defs { + defs, + filename: self.filename.to_string(), + }) + } + + fn parse_def(&mut self) -> ParseResult { + self.lparen()?; + let pos = self.pos(); + let def = match self.symbol()? { + "type" => Def::Type(self.parse_type()?), + "rule" => Def::Rule(self.parse_rule()?), + "decl" => Def::Decl(self.parse_decl()?), + "constructor" => Def::Extern(self.parse_ctor()?), + "extractor" => Def::Extern(self.parse_etor()?), + s => { + return Err(self.error(pos.unwrap(), format!("Unexpected identifier: {}", s))); + } + }; + self.rparen()?; + Ok(def) + } + + fn str_to_ident(&self, pos: Pos, s: &str) -> ParseResult { + let first = s.chars().next().unwrap(); + if !first.is_alphabetic() && first != '_' { + return Err(self.error( + pos, + format!("Identifier '{}' does not start with letter or _", s), + )); + } + if s.chars() + .skip(1) + .any(|c| !c.is_alphanumeric() && c != '_' && c != '.') + { + return Err(self.error( + pos, + format!( + "Identifier '{}' contains invalid character (not a-z, A-Z, 0-9, _, .)", + s + ), + )); + } + Ok(Ident(s.to_string())) + } + + fn parse_ident(&mut self) -> ParseResult { + let pos = self.pos(); + let s = self.symbol()?; + self.str_to_ident(pos.unwrap(), s) + } + + fn parse_type(&mut self) -> ParseResult { + let pos = self.pos(); + let name = self.parse_ident()?; + let mut is_extern = false; + if self.is_sym_str("extern") { + self.symbol()?; + is_extern = true; + } + let ty = self.parse_typevalue()?; + Ok(Type { + name, + is_extern, + ty, + pos: pos.unwrap(), + }) + } + + fn parse_typevalue(&mut self) -> ParseResult { + let pos = self.pos(); + self.lparen()?; + if self.is_sym_str("primitive") { + self.symbol()?; + let primitive_ident = self.parse_ident()?; + self.rparen()?; + Ok(TypeValue::Primitive(primitive_ident)) + } else if 
self.is_sym_str("enum") { + self.symbol()?; + let mut variants = vec![]; + while !self.is_rparen() { + let variant = self.parse_type_variant()?; + variants.push(variant); + } + self.rparen()?; + Ok(TypeValue::Enum(variants)) + } else { + Err(self.error(pos.unwrap(), "Unknown type definition".to_string())) + } + } + + fn parse_type_variant(&mut self) -> ParseResult { + self.lparen()?; + let name = self.parse_ident()?; + let mut fields = vec![]; + while !self.is_rparen() { + fields.push(self.parse_type_field()?); + } + self.rparen()?; + Ok(Variant { name, fields }) + } + + fn parse_type_field(&mut self) -> ParseResult { + self.lparen()?; + let name = self.parse_ident()?; + let ty = self.parse_ident()?; + self.rparen()?; + Ok(Field { name, ty }) + } + + fn parse_decl(&mut self) -> ParseResult { + let pos = self.pos(); + let term = self.parse_ident()?; + + self.lparen()?; + let mut arg_tys = vec![]; + while !self.is_rparen() { + arg_tys.push(self.parse_ident()?); + } + self.rparen()?; + + let ret_ty = self.parse_ident()?; + + Ok(Decl { + term, + arg_tys, + ret_ty, + pos: pos.unwrap(), + }) + } + + fn parse_ctor(&mut self) -> ParseResult { + let pos = self.pos(); + let term = self.parse_ident()?; + let func = self.parse_ident()?; + Ok(Extern::Constructor { + term, + func, + pos: pos.unwrap(), + }) + } + + fn parse_etor(&mut self) -> ParseResult { + let pos = self.pos(); + let term = self.parse_ident()?; + let func = self.parse_ident()?; + Ok(Extern::Extractor { + term, + func, + pos: pos.unwrap(), + }) + } + + fn parse_rule(&mut self) -> ParseResult { + let pos = self.pos(); + let prio = if self.is_int() { + Some(self.int()?) + } else { + None + }; + let pattern = self.parse_pattern()?; + let expr = self.parse_expr()?; + Ok(Rule { + pattern, + expr, + pos: pos.unwrap(), + prio, + }) + } + + fn parse_pattern(&mut self) -> ParseResult { + let pos = self.pos(); + if self.is_int() { + Ok(Pattern::ConstInt { val: self.int()? 
}) + } else if self.is_sym_str("_") { + self.symbol()?; + Ok(Pattern::Wildcard) + } else if self.is_sym() { + let s = self.symbol()?; + if s.starts_with("=") { + let s = &s[1..]; + let var = self.str_to_ident(pos.unwrap(), s)?; + Ok(Pattern::Var { var }) + } else { + let var = self.str_to_ident(pos.unwrap(), s)?; + if self.is_sym_str("@") { + self.symbol()?; + let subpat = Box::new(self.parse_pattern()?); + Ok(Pattern::BindPattern { var, subpat }) + } else { + Ok(Pattern::BindPattern { + var, + subpat: Box::new(Pattern::Wildcard), + }) + } + } + } else if self.is_lparen() { + self.lparen()?; + let sym = self.parse_ident()?; + let mut args = vec![]; + while !self.is_rparen() { + args.push(self.parse_pattern()?); + } + self.rparen()?; + Ok(Pattern::Term { sym, args }) + } else { + Err(self.error(pos.unwrap(), "Unexpected pattern".into())) + } + } + + fn parse_expr(&mut self) -> ParseResult { + let pos = self.pos(); + if self.is_lparen() { + self.lparen()?; + if self.is_sym_str("let") { + self.symbol()?; + self.lparen()?; + let mut defs = vec![]; + while !self.is_rparen() { + let def = self.parse_letdef()?; + defs.push(def); + } + self.rparen()?; + let body = Box::new(self.parse_expr()?); + self.rparen()?; + Ok(Expr::Let { defs, body }) + } else { + let sym = self.parse_ident()?; + let mut args = vec![]; + while !self.is_rparen() { + args.push(self.parse_expr()?); + } + self.rparen()?; + Ok(Expr::Term { sym, args }) + } + } else if self.is_sym() { + let name = self.parse_ident()?; + Ok(Expr::Var { name }) + } else if self.is_int() { + let val = self.int()?; + Ok(Expr::ConstInt { val }) + } else { + Err(self.error(pos.unwrap(), "Invalid expression".into())) + } + } + + fn parse_letdef(&mut self) -> ParseResult { + self.lparen()?; + let var = self.parse_ident()?; + let ty = self.parse_ident()?; + let val = Box::new(self.parse_expr()?); + self.rparen()?; + Ok(LetDef { var, ty, val }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn parse_type() { + let 
text = r" + ;; comment + (type Inst extern (enum + (Alu (a Reg) (b Reg) (dest Reg)) + (Load (a Reg) (dest Reg)))) + (type u32 (primitive u32)) + "; + let defs = Parser::new("(none)", text) + .parse_defs() + .expect("should parse"); + assert_eq!( + defs, + Defs { + filename: "(none)".to_string(), + defs: vec![ + Def::Type(Type { + name: Ident("Inst".to_string()), + is_extern: true, + ty: TypeValue::Enum(vec![ + Variant { + name: Ident("Alu".to_string()), + fields: vec![ + Field { + name: Ident("a".to_string()), + ty: Ident("Reg".to_string()), + }, + Field { + name: Ident("b".to_string()), + ty: Ident("Reg".to_string()), + }, + Field { + name: Ident("dest".to_string()), + ty: Ident("Reg".to_string()), + }, + ], + }, + Variant { + name: Ident("Load".to_string()), + fields: vec![ + Field { + name: Ident("a".to_string()), + ty: Ident("Reg".to_string()), + }, + Field { + name: Ident("dest".to_string()), + ty: Ident("Reg".to_string()), + }, + ], + } + ]), + pos: Pos { + offset: 42, + line: 4, + col: 18, + }, + }), + Def::Type(Type { + name: Ident("u32".to_string()), + is_extern: false, + ty: TypeValue::Primitive(Ident("u32".to_string())), + pos: Pos { + offset: 167, + line: 7, + col: 18, + }, + }), + ] + } + ); + } +} diff --git a/cranelift/isle/src/sema.rs b/cranelift/isle/src/sema.rs new file mode 100644 index 0000000000..a11faccc49 --- /dev/null +++ b/cranelift/isle/src/sema.rs @@ -0,0 +1,862 @@ +//! Semantic analysis. + +use crate::ast; +use crate::error::*; +use crate::lexer::Pos; +use std::collections::HashMap; + +pub type SemaResult = std::result::Result; + +#[macro_export] +macro_rules! 
declare_id { + ($name:ident) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct $name(pub usize); + impl $name { + pub fn index(self) -> usize { + self.0 + } + } + }; +} + +declare_id!(Sym); +declare_id!(TypeId); +declare_id!(VariantId); +declare_id!(FieldId); +declare_id!(TermId); +declare_id!(RuleId); +declare_id!(VarId); + +#[derive(Clone, Debug)] +pub struct TypeEnv { + pub filename: String, + pub syms: Vec, + pub sym_map: HashMap, + pub types: Vec, + pub type_map: HashMap, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Type { + Primitive(TypeId, Sym), + Enum { + name: Sym, + id: TypeId, + is_extern: bool, + variants: Vec, + pos: Pos, + }, +} + +impl Type { + fn name<'a>(&self, tyenv: &'a TypeEnv) -> &'a str { + match self { + Self::Primitive(_, name) | Self::Enum { name, .. } => &tyenv.syms[name.index()], + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Variant { + pub name: Sym, + pub id: VariantId, + pub fields: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Field { + pub name: Sym, + pub id: FieldId, + pub ty: TypeId, +} + +#[derive(Clone, Debug)] +pub struct TermEnv { + pub terms: Vec, + pub term_map: HashMap, + pub rules: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Term { + pub id: TermId, + pub name: Sym, + pub arg_tys: Vec, + pub ret_ty: TypeId, + pub kind: TermKind, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TermKind { + EnumVariant { + variant: VariantId, + }, + Regular { + // Producer and consumer rules are catalogued separately after + // building Sequences. Here we just record whether an + // extractor and/or constructor is known. 
+        extractor: Option<Sym>,
+        constructor: Option<Sym>,
+    },
+}
+
+/// A type-checked rule: a left-hand-side pattern, a right-hand-side
+/// expression, and the optional priority carried over from the AST.
+#[derive(Clone, Debug)]
+pub struct Rule {
+    pub id: RuleId,
+    pub lhs: Pattern,
+    pub rhs: Expr,
+    pub prio: Option<i64>,
+}
+
+/// A type-checked left-hand-side pattern. The first field of every
+/// variant is the type of the value the pattern matches.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum Pattern {
+    /// Bind a fresh variable to the value matched by the subpattern.
+    BindPattern(TypeId, VarId, Box<Pattern>),
+    /// Refer to a variable that was bound earlier in the same pattern.
+    Var(TypeId, VarId),
+    /// An integer constant.
+    ConstInt(TypeId, i64),
+    /// A term applied to one subpattern per argument.
+    Term(TypeId, TermId, Vec<Pattern>),
+    /// A wildcard (`_`).
+    Wildcard(TypeId),
+}
+
+/// A type-checked right-hand-side expression. The first field of every
+/// variant is the type of the expression.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum Expr {
+    /// A term applied to one subexpression per argument.
+    Term(TypeId, TermId, Vec<Expr>),
+    /// A reference to a bound variable.
+    Var(TypeId, VarId),
+    /// An integer constant.
+    ConstInt(TypeId, i64),
+    /// A `let`: a list of `(variable, declared type, value)` bindings
+    /// followed by the body expression.
+    Let(TypeId, Vec<(VarId, TypeId, Box<Expr>)>, Box<Expr>),
+}
+
+impl Pattern {
+    /// Returns the type of the value this pattern matches.
+    pub fn ty(&self) -> TypeId {
+        match self {
+            &Self::BindPattern(t, ..) => t,
+            &Self::Var(t, ..) => t,
+            &Self::ConstInt(t, ..) => t,
+            &Self::Term(t, ..) => t,
+            &Self::Wildcard(t, ..) => t,
+        }
+    }
+}
+
+impl Expr {
+    /// Returns the type of this expression.
+    pub fn ty(&self) -> TypeId {
+        match self {
+            &Self::Term(t, ..) => t,
+            &Self::Var(t, ..) => t,
+            &Self::ConstInt(t, ..) => t,
+            &Self::Let(t, ..) => t,
+        }
+    }
+}
+
+impl TypeEnv {
+    /// Builds a type environment from the type definitions in `defs`.
+    ///
+    /// Two passes are made so that fields may refer to types defined
+    /// later in the file: the first assigns a `TypeId` to every type
+    /// name, the second lowers each definition.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a type name is defined more than once, or if lowering
+    /// any definition fails (see `type_from_ast`).
+    pub fn from_ast(defs: &ast::Defs) -> SemaResult<TypeEnv> {
+        let mut tyenv = TypeEnv {
+            filename: defs.filename.clone(),
+            syms: vec![],
+            sym_map: HashMap::new(),
+            types: vec![],
+            type_map: HashMap::new(),
+        };
+
+        // Traverse defs, assigning type IDs to type names. We'll fill
+        // in types on a second pass.
+        for def in &defs.defs {
+            match def {
+                &ast::Def::Type(ref td) => {
+                    let tid = TypeId(tyenv.type_map.len());
+                    let name = tyenv.intern_mut(&td.name);
+                    if tyenv.type_map.contains_key(&name) {
+                        return Err(tyenv.error(
+                            td.pos,
+                            format!("Type name defined more than once: '{}'", td.name.0),
+                        ));
+                    }
+                    tyenv.type_map.insert(name, tid);
+                }
+                _ => {}
+            }
+        }
+
+        // Now lower AST nodes to type definitions, raising errors
+        // where typenames of fields are undefined or field names are
+        // duplicated.
+        //
+        // `tid` counts type definitions in the same order as the first
+        // pass, so it reproduces the IDs stored in `type_map`.
+        let mut tid = 0;
+        for def in &defs.defs {
+            match def {
+                &ast::Def::Type(ref td) => {
+                    let ty = tyenv.type_from_ast(TypeId(tid), td)?;
+                    tyenv.types.push(ty);
+                    tid += 1;
+                }
+                _ => {}
+            }
+        }
+
+        Ok(tyenv)
+    }
+
+    /// Lowers one AST type definition to a `Type` with ID `tid`.
+    ///
+    /// Enum variant names are interned as `"<type>.<variant>"`.
+    ///
+    /// # Errors
+    ///
+    /// Fails on a duplicate variant name, a duplicate field name within
+    /// a variant, or a field whose type name is not in `type_map`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the type's own name has not been interned yet;
+    /// `from_ast` interns every type name in its first pass.
+    fn type_from_ast(&mut self, tid: TypeId, ty: &ast::Type) -> SemaResult<Type> {
+        let name = self
+            .intern(&ty.name)
+            .expect("type name must be interned before its definition is lowered");
+        match &ty.ty {
+            &ast::TypeValue::Primitive(ref id) => Ok(Type::Primitive(tid, self.intern_mut(id))),
+            &ast::TypeValue::Enum(ref ty_variants) => {
+                let mut variants = vec![];
+                for variant in ty_variants {
+                    let combined_ident = ast::Ident(format!("{}.{}", ty.name.0, variant.name.0));
+                    let var_name = self.intern_mut(&combined_ident);
+                    let id = VariantId(variants.len());
+                    if variants.iter().any(|v: &Variant| v.name == var_name) {
+                        return Err(self.error(
+                            ty.pos,
+                            format!("Duplicate variant name in type: '{}'", variant.name.0),
+                        ));
+                    }
+                    let mut fields = vec![];
+                    for field in &variant.fields {
+                        let field_name = self.intern_mut(&field.name);
+                        if fields.iter().any(|f: &Field| f.name == field_name) {
+                            return Err(self.error(
+                                ty.pos,
+                                format!(
+                                    "Duplicate field name '{}' in variant '{}' of type",
+                                    field.name.0, variant.name.0
+                                ),
+                            ));
+                        }
+                        let field_ty = self.intern_mut(&field.ty);
+                        let field_tid = match self.type_map.get(&field_ty) {
+                            Some(tid) => *tid,
+                            None => {
+                                return Err(self.error(
+                                    ty.pos,
+                                    format!(
+                                        "Unknown type '{}' for field '{}' in variant '{}'",
+                                        field.ty.0, field.name.0, variant.name.0
+                                    ),
+                                ));
+                            }
+                        };
+                        fields.push(Field {
+                            name: field_name,
+                            id: FieldId(fields.len()),
+                            ty: field_tid,
+                        });
+                    }
+                    variants.push(Variant {
+                        name: var_name,
+                        id,
+                        fields,
+                    });
+                }
+                Ok(Type::Enum {
+                    name,
+                    id: tid,
+                    is_extern: ty.is_extern,
+                    variants,
+                    pos: ty.pos,
+                })
+            }
+        }
+    }
+
+    /// Builds a `SemaError` at `pos`, tagged with this environment's filename.
+    fn error(&self, pos: Pos, msg: String) -> SemaError {
+        SemaError {
+            filename: self.filename.clone(),
+            pos,
+            msg,
+        }
+    }
+
+    /// Returns the symbol for `ident`, creating a new one if the name
+    /// has not been seen before.
+    pub fn intern_mut(&mut self, ident: &ast::Ident) -> Sym {
+        if let Some(s) = self.sym_map.get(&ident.0).cloned() {
+            s
+        } else {
+            let s = Sym(self.syms.len());
+            self.syms.push(ident.0.clone());
+            self.sym_map.insert(ident.0.clone(), s);
+            s
+        }
+    }
+
+    /// Returns the symbol for `ident` if it has already been interned.
+    pub fn intern(&self, ident: &ast::Ident) -> Option<Sym> {
+        self.sym_map.get(&ident.0).cloned()
+    }
+}
+
+/// The variables in scope while one rule is being translated.
+struct Bindings {
+    /// Index to use for the next `VarId` handed out.
+    next_var: usize,
+    /// Bound variables; lookups search from the end, so later entries
+    /// are found first.
+    vars: Vec<BoundVar>,
+}
+
+/// A single bound variable: its name, ID and type.
+struct BoundVar {
+    name: Sym,
+    id: VarId,
+    ty: TypeId,
+}
+
+impl TermEnv {
+    /// Builds a term environment from `defs`: declared term signatures
+    /// first, then one term per enum variant, then rules and external
+    /// constructor/extractor declarations.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error reported by any of the three passes.
+    pub fn from_ast(tyenv: &mut TypeEnv, defs: &ast::Defs) -> SemaResult<TermEnv> {
+        let mut env = TermEnv {
+            terms: vec![],
+            term_map: HashMap::new(),
+            rules: vec![],
+        };
+
+        env.collect_term_sigs(tyenv, defs)?;
+        env.collect_enum_variant_terms(tyenv)?;
+        env.collect_rules(tyenv, defs)?;
+
+        Ok(env)
+    }
+
+    /// Registers a `TermKind::Regular` term for every `decl` in `defs`.
+    ///
+    /// # Errors
+    ///
+    /// Fails on a duplicate declaration or on an argument or return
+    /// type name that is not a known type.
+    fn collect_term_sigs(&mut self, tyenv: &mut TypeEnv, defs: &ast::Defs) -> SemaResult<()> {
+        for def in &defs.defs {
+            match def {
+                &ast::Def::Decl(ref decl) => {
+                    let tid = TermId(self.terms.len());
+                    let name = tyenv.intern_mut(&decl.term);
+                    if self.term_map.contains_key(&name) {
+                        return Err(
+                            tyenv.error(decl.pos, format!("Duplicate decl for '{}'", decl.term.0))
+                        );
+                    }
+                    self.term_map.insert(name, tid);
+
+                    let arg_tys = decl
+                        .arg_tys
+                        .iter()
+                        .map(|id| {
+                            let sym = tyenv.intern_mut(id);
+                            tyenv.type_map.get(&sym).cloned().ok_or_else(|| {
+                                tyenv.error(decl.pos, format!("Unknown arg type: '{}'", id.0))
+                            })
+                        })
+                        .collect::<SemaResult<Vec<TypeId>>>()?;
+                    let ret_ty = {
+                        let sym = tyenv.intern_mut(&decl.ret_ty);
+                        tyenv.type_map.get(&sym).cloned().ok_or_else(|| {
+                            tyenv.error(
+                                decl.pos,
+                                format!("Unknown return type: '{}'", decl.ret_ty.0),
+                            )
+                        })?
+                    };
+
+                    self.terms.push(Term {
+                        id: tid,
+                        name,
+                        arg_tys,
+                        ret_ty,
+                        kind: TermKind::Regular {
+                            extractor: None,
+                            constructor: None,
+                        },
+                    });
+                }
+                _ => {}
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Registers a `TermKind::EnumVariant` term for every variant of
+    /// every enum type. The term takes the variant's field types as
+    /// arguments and returns the enum type.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a term with the variant's name already exists.
+    fn collect_enum_variant_terms(&mut self, tyenv: &mut TypeEnv) -> SemaResult<()> {
+        for ty in &tyenv.types {
+            match ty {
+                &Type::Enum {
+                    pos,
+                    id,
+                    ref variants,
+                    ..
+                } => {
+                    for variant in variants {
+                        if self.term_map.contains_key(&variant.name) {
+                            return Err(tyenv.error(
+                                pos,
+                                format!(
+                                    "Duplicate enum variant constructor: '{}'",
+                                    tyenv.syms[variant.name.index()]
+                                ),
+                            ));
+                        }
+                        let tid = TermId(self.terms.len());
+                        let arg_tys = variant.fields.iter().map(|fld| fld.ty).collect::<Vec<_>>();
+                        let ret_ty = id;
+                        self.terms.push(Term {
+                            id: tid,
+                            name: variant.name,
+                            arg_tys,
+                            ret_ty,
+                            kind: TermKind::EnumVariant {
+                                variant: variant.id,
+                            },
+                        });
+                        self.term_map.insert(variant.name, tid);
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Type-checks every rule in `defs` and attaches external
+    /// constructor/extractor function names to their terms.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a rule does not type-check, or if an external
+    /// constructor/extractor names an undefined term, an enum-variant
+    /// term, or a term that already has one.
+    fn collect_rules(&mut self, tyenv: &mut TypeEnv, defs: &ast::Defs) -> SemaResult<()> {
+        for def in &defs.defs {
+            match def {
+                &ast::Def::Rule(ref rule) => {
+                    // Each rule gets its own variable scope.
+                    let mut bindings = Bindings {
+                        next_var: 0,
+                        vars: vec![],
+                    };
+
+                    // The LHS has no expected type; the type it resolves
+                    // to is what the RHS must produce.
+                    let (lhs, ty) = self.translate_pattern(
+                        tyenv,
+                        rule.pos,
+                        &rule.pattern,
+                        None,
+                        &mut bindings,
+                    )?;
+                    let rhs =
+                        self.translate_expr(tyenv, rule.pos, &rule.expr, ty, &mut bindings)?;
+                    let rid = RuleId(self.rules.len());
+                    self.rules.push(Rule {
+                        id: rid,
+                        lhs,
+                        rhs,
+                        prio: rule.prio,
+                    });
+                }
+                &ast::Def::Extern(ast::Extern::Constructor {
+                    ref term,
+                    ref func,
+                    pos,
+                }) => {
+                    let term_sym = tyenv.intern_mut(term);
+                    let func_sym = tyenv.intern_mut(func);
+                    let term_id = match self.term_map.get(&term_sym) {
+                        Some(term) => term,
+                        None => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Constructor declared on undefined term '{}'", term.0),
+                            ))
+                        }
+                    };
+                    match &mut self.terms[term_id.index()].kind {
+                        &mut TermKind::EnumVariant { .. } => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Constructor defined on enum type '{}'", term.0),
+                            ));
+                        }
+                        &mut TermKind::Regular {
+                            ref mut constructor,
+                            ..
+                        } => {
+                            if constructor.is_some() {
+                                return Err(tyenv.error(
+                                    pos,
+                                    format!(
+                                        "Constructor defined more than once on term '{}'",
+                                        term.0
+                                    ),
+                                ));
+                            }
+                            *constructor = Some(func_sym);
+                        }
+                    }
+                }
+                &ast::Def::Extern(ast::Extern::Extractor {
+                    ref term,
+                    ref func,
+                    pos,
+                }) => {
+                    let term_sym = tyenv.intern_mut(term);
+                    let func_sym = tyenv.intern_mut(func);
+                    let term_id = match self.term_map.get(&term_sym) {
+                        Some(term) => term,
+                        None => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Extractor declared on undefined term '{}'", term.0),
+                            ))
+                        }
+                    };
+                    match &mut self.terms[term_id.index()].kind {
+                        &mut TermKind::EnumVariant { .. } => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Extractor defined on enum type '{}'", term.0),
+                            ));
+                        }
+                        &mut TermKind::Regular {
+                            ref mut extractor, ..
+                        } => {
+                            if extractor.is_some() {
+                                return Err(tyenv.error(
+                                    pos,
+                                    format!(
+                                        "Extractor defined more than once on term '{}'",
+                                        term.0
+                                    ),
+                                ));
+                            }
+                            *extractor = Some(func_sym);
+                        }
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Type-checks an AST pattern and returns it together with its type.
+    ///
+    /// `expected_ty` is the type required by the context, if any.
+    /// Variables bound by the pattern are appended to `bindings`.
+    ///
+    /// # Errors
+    ///
+    /// Fails when a constant or wildcard has no expected type, a
+    /// variable is rebound or unknown, a term is unknown or has the
+    /// wrong argument count, or a type does not match `expected_ty`.
+    fn translate_pattern(
+        &self,
+        tyenv: &mut TypeEnv,
+        pos: Pos,
+        pat: &ast::Pattern,
+        expected_ty: Option<TypeId>,
+        bindings: &mut Bindings,
+    ) -> SemaResult<(Pattern, TypeId)> {
+        match pat {
+            // TODO: flag on primitive type decl indicating it's an integer type?
+            &ast::Pattern::ConstInt { val } => {
+                let ty = expected_ty.ok_or_else(|| {
+                    tyenv.error(pos, "Need an implied type for an integer constant".into())
+                })?;
+                Ok((Pattern::ConstInt(ty, val), ty))
+            }
+            &ast::Pattern::Wildcard => {
+                let ty = expected_ty.ok_or_else(|| {
+                    tyenv.error(pos, "Need an implied type for a wildcard".into())
+                })?;
+                Ok((Pattern::Wildcard(ty), ty))
+            }
+            &ast::Pattern::BindPattern {
+                ref var,
+                ref subpat,
+            } => {
+                // Do the subpattern first so we can resolve the type for sure.
+                let (subpat, ty) =
+                    self.translate_pattern(tyenv, pos, &*subpat, expected_ty, bindings)?;
+
+                let name = tyenv.intern_mut(var);
+                if bindings.vars.iter().any(|bv| bv.name == name) {
+                    return Err(tyenv.error(
+                        pos,
+                        format!("Rebound variable name in LHS pattern: '{}'", var.0),
+                    ));
+                }
+                let id = VarId(bindings.next_var);
+                bindings.next_var += 1;
+                bindings.vars.push(BoundVar { name, id, ty });
+
+                Ok((Pattern::BindPattern(ty, id, Box::new(subpat)), ty))
+            }
+            &ast::Pattern::Var { ref var } => {
+                // Look up the variable; it must already have been bound.
+                let name = tyenv.intern_mut(var);
+                let bv = match bindings.vars.iter().rev().find(|bv| bv.name == name) {
+                    None => {
+                        return Err(tyenv.error(
+                            pos,
+                            format!(
+                                "Unknown variable '{}' in bound-var pattern '={}'",
+                                var.0, var.0
+                            ),
+                        ))
+                    }
+                    Some(bv) => bv,
+                };
+                let ty = match expected_ty {
+                    None => bv.ty,
+                    Some(expected_ty) if expected_ty == bv.ty => bv.ty,
+                    Some(expected_ty) => {
+                        return Err(tyenv.error(pos, format!("Mismatched types: pattern expects type '{}' but already-bound var '{}' has type '{}'", tyenv.types[expected_ty.index()].name(tyenv), var.0, tyenv.types[bv.ty.index()].name(tyenv))));
+                    }
+                };
+                Ok((Pattern::Var(ty, bv.id), ty))
+            }
+            &ast::Pattern::Term { ref sym, ref args } => {
+                let name = tyenv.intern_mut(&sym);
+                // Look up the term.
+                let tid = self.term_map.get(&name).ok_or_else(|| {
+                    tyenv.error(pos, format!("Unknown term in pattern: '{}'", sym.0))
+                })?;
+
+                // Get the return type and arg types. Verify the
+                // expected type of this pattern, if any, against the
+                // return type of the term.
+                let ret_ty = self.terms[tid.index()].ret_ty;
+                let ty = match expected_ty {
+                    None => ret_ty,
+                    Some(expected_ty) if expected_ty == ret_ty => ret_ty,
+                    Some(expected_ty) => {
+                        return Err(tyenv.error(pos, format!("Mismatched types: pattern expects type '{}' but term has return type '{}'", tyenv.types[expected_ty.index()].name(tyenv), tyenv.types[ret_ty.index()].name(tyenv))));
+                    }
+                };
+
+                // Check that we have the correct argument count.
+                if self.terms[tid.index()].arg_tys.len() != args.len() {
+                    return Err(tyenv.error(
+                        pos,
+                        format!(
+                            "Incorrect argument count for term '{}': got {}, expect {}",
+                            sym.0,
+                            args.len(),
+                            self.terms[tid.index()].arg_tys.len()
+                        ),
+                    ));
+                }
+
+                // Resolve subpatterns.
+                let mut subpats = vec![];
+                for (i, arg) in args.iter().enumerate() {
+                    let arg_ty = self.terms[tid.index()].arg_tys[i];
+                    let (subpat, _) =
+                        self.translate_pattern(tyenv, pos, arg, Some(arg_ty), bindings)?;
+                    subpats.push(subpat);
+                }
+
+                Ok((Pattern::Term(ty, *tid, subpats), ty))
+            }
+        }
+    }
+
+    /// Type-checks an AST expression that must have type `ty`.
+    ///
+    /// Variables are resolved against `bindings`. A `let` adds its
+    /// variables for the duration of its body and removes them again
+    /// before returning successfully.
+    ///
+    /// # Errors
+    ///
+    /// Fails when a term or variable is unknown, a term has the wrong
+    /// argument count, a type does not match the expected one, or a
+    /// `let` rebinds an existing name or names an unknown type.
+    fn translate_expr(
+        &self,
+        tyenv: &mut TypeEnv,
+        pos: Pos,
+        expr: &ast::Expr,
+        ty: TypeId,
+        bindings: &mut Bindings,
+    ) -> SemaResult<Expr> {
+        match expr {
+            &ast::Expr::Term { ref sym, ref args } => {
+                // Look up the term.
+                let name = tyenv.intern_mut(&sym);
+                let tid = self.term_map.get(&name).ok_or_else(|| {
+                    tyenv.error(pos, format!("Unknown term in expression: '{}'", sym.0))
+                })?;
+
+                // Get the return type and arg types. Verify the
+                // expected type of this expression against the return
+                // type of the term.
+                let ret_ty = self.terms[tid.index()].ret_ty;
+                if ret_ty != ty {
+                    return Err(tyenv.error(pos, format!("Mismatched types: expression expects type '{}' but term has return type '{}'", tyenv.types[ty.index()].name(tyenv), tyenv.types[ret_ty.index()].name(tyenv))));
+                }
+
+                // Check that we have the correct argument count.
+                if self.terms[tid.index()].arg_tys.len() != args.len() {
+                    return Err(tyenv.error(
+                        pos,
+                        format!(
+                            "Incorrect argument count for term '{}': got {}, expect {}",
+                            sym.0,
+                            args.len(),
+                            self.terms[tid.index()].arg_tys.len()
+                        ),
+                    ));
+                }
+
+                // Resolve subexpressions.
+                let mut subexprs = vec![];
+                for (i, arg) in args.iter().enumerate() {
+                    let arg_ty = self.terms[tid.index()].arg_tys[i];
+                    let subexpr = self.translate_expr(tyenv, pos, arg, arg_ty, bindings)?;
+                    subexprs.push(subexpr);
+                }
+
+                Ok(Expr::Term(ty, *tid, subexprs))
+            }
+            &ast::Expr::Var { ref name } => {
+                let sym = tyenv.intern_mut(name);
+                // Look through bindings, innermost (most recent) first.
+                let bv = match bindings.vars.iter().rev().find(|b| b.name == sym) {
+                    None => {
+                        return Err(tyenv.error(pos, format!("Unknown variable '{}'", name.0)));
+                    }
+                    Some(bv) => bv,
+                };
+
+                // Verify type.
+                if bv.ty != ty {
+                    return Err(tyenv.error(
+                        pos,
+                        format!(
+                            "Variable '{}' has type {} but we need {} in context",
+                            name.0,
+                            tyenv.types[bv.ty.index()].name(tyenv),
+                            tyenv.types[ty.index()].name(tyenv)
+                        ),
+                    ));
+                }
+
+                Ok(Expr::Var(bv.ty, bv.id))
+            }
+            &ast::Expr::ConstInt { val } => Ok(Expr::ConstInt(ty, val)),
+            &ast::Expr::Let { ref defs, ref body } => {
+                let orig_binding_len = bindings.vars.len();
+
+                // For each new binding...
+                let mut let_defs = vec![];
+                for def in defs {
+                    // Check that the given variable name does not already exist.
+                    let name = tyenv.intern_mut(&def.var);
+                    if bindings.vars.iter().any(|bv| bv.name == name) {
+                        return Err(
+                            tyenv.error(pos, format!("Variable '{}' already bound", def.var.0))
+                        );
+                    }
+
+                    // Look up the type.
+                    let tysym = match tyenv.intern(&def.ty) {
+                        Some(ty) => ty,
+                        None => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Unknown type {} for variable '{}'", def.ty.0, def.var.0),
+                            ))
+                        }
+                    };
+                    let tid = match tyenv.type_map.get(&tysym) {
+                        Some(tid) => *tid,
+                        None => {
+                            return Err(tyenv.error(
+                                pos,
+                                format!("Unknown type {} for variable '{}'", def.ty.0, def.var.0),
+                            ))
+                        }
+                    };
+
+                    // Evaluate the variable's value. It must have the
+                    // variable's declared type (`tid`), not the type
+                    // expected of the enclosing let-expr (`ty`).
+                    let val = Box::new(self.translate_expr(tyenv, pos, &def.val, tid, bindings)?);
+
+                    // Bind the var with the given type.
+                    let id = VarId(bindings.next_var);
+                    bindings.next_var += 1;
+                    bindings.vars.push(BoundVar { name, id, ty: tid });
+
+                    // Record the declared type alongside the binding.
+                    let_defs.push((id, tid, val));
+                }
+
+                // Evaluate the body, expecting the type of the overall let-expr.
+                let body = Box::new(self.translate_expr(tyenv, pos, body, ty, bindings)?);
+                let body_ty = body.ty();
+
+                // Pop the bindings.
+                bindings.vars.truncate(orig_binding_len);
+
+                Ok(Expr::Let(body_ty, let_defs, body))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::ast::Ident;
+    use crate::parser::Parser;
+
+    // NOTE: the expected `Pos` below depends on the exact whitespace
+    // inside the raw string; keep its indentation unchanged.
+    #[test]
+    fn build_type_env() {
+        let text = r"
+            (type u32 (primitive u32))
+            (type A extern (enum (B (f1 u32) (f2 u32)) (C (f1 u32))))
+        ";
+        let ast = Parser::new("file.isle", text)
+            .parse_defs()
+            .expect("should parse");
+        let tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors");
+
+        let sym_a = tyenv.intern(&Ident("A".to_string())).unwrap();
+        let sym_b = tyenv.intern(&Ident("A.B".to_string())).unwrap();
+        let sym_c = tyenv.intern(&Ident("A.C".to_string())).unwrap();
+        let sym_u32 = tyenv.intern(&Ident("u32".to_string())).unwrap();
+        let sym_f1 = tyenv.intern(&Ident("f1".to_string())).unwrap();
+        let sym_f2 = tyenv.intern(&Ident("f2".to_string())).unwrap();
+
+        assert_eq!(tyenv.type_map.get(&sym_u32).unwrap(), &TypeId(0));
+        assert_eq!(tyenv.type_map.get(&sym_a).unwrap(), &TypeId(1));
+
+        assert_eq!(
+            tyenv.types,
+            vec![
+                Type::Primitive(TypeId(0), sym_u32),
+                Type::Enum {
+                    name: sym_a,
+                    id: TypeId(1),
+                    is_extern: true,
+                    variants: vec![
+                        Variant {
+                            name: sym_b,
+                            id: VariantId(0),
+                            fields: vec![
+                                Field {
+                                    name: sym_f1,
+                                    id: FieldId(0),
+                                    ty: TypeId(0),
+                                },
+                                Field {
+                                    name: sym_f2,
+                                    id: FieldId(1),
+                                    ty: TypeId(0),
+                                },
+                            ],
+                        },
+                        Variant {
+                            name: sym_c,
+                            id: VariantId(1),
+                            fields: vec![Field {
+                                name: sym_f1,
+                                id: FieldId(0),
+                                ty: TypeId(0),
+                            },],
+                        },
+                    ],
+                    pos: Pos {
+                        offset: 58,
+                        line: 3,
+                        col: 18,
+                    },
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn build_rules() {
+        let text = r"
+            (type u32 (primitive u32))
+            (type A extern (enum (B (f1 u32) (f2 u32)) (C (f1 u32))))
+
+            (decl T1 (A) u32)
+            (decl T2 (A A) A)
+            (decl T3 (u32) A)
+
+            (constructor T1 t1_ctor)
+            (extractor T2 t2_etor)
+
+            (rule
+                (T1 _) 1)
+            (rule
+                (T2 x =x) (T3 42))
+            (rule
+                (T3 1) (A.C 2))
+            (rule -1
+                (T3 _) (A.C 3))
+        ";
+        let ast = Parser::new("file.isle", text)
+            .parse_defs()
+            .expect("should parse");
+        let mut tyenv = TypeEnv::from_ast(&ast).expect("should not have type-definition errors");
+        let _ = TermEnv::from_ast(&mut tyenv, &ast).expect("could not typecheck rules");
+    }
+}