diff --git a/cranelift/peepmatic/Cargo.toml b/cranelift/peepmatic/Cargo.toml new file mode 100644 index 0000000000..843b06a1a2 --- /dev/null +++ b/cranelift/peepmatic/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "peepmatic" +version = "0.1.0" +authors = ["Nick Fitzgerald "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.27" +peepmatic-automata = { version = "0.1.0", path = "crates/automata", features = ["dot"] } +peepmatic-macro = { version = "0.1.0", path = "crates/macro" } +peepmatic-runtime = { version = "0.1.0", path = "crates/runtime", features = ["construct"] } +wast = "13.0.0" +z3 = { version = "0.5.0", features = ["static-link-z3"] } diff --git a/cranelift/peepmatic/README.md b/cranelift/peepmatic/README.md new file mode 100644 index 0000000000..550f6d310d --- /dev/null +++ b/cranelift/peepmatic/README.md @@ -0,0 +1,423 @@ +
+# peepmatic
+
+**peepmatic is a DSL and compiler for peephole optimizers for Cranelift.**
+
+ + + + + +- [About](#about) +- [Example](#example) +- [A DSL for Optimizations](#a-dsl-for-optimizations) + - [Variables](#variables) + - [Constants](#constants) + - [Nested Patterns](#nested-patterns) + - [Preconditions and Unquoting](#preconditions-and-unquoting) + - [Bit Widths](#bit-widths) +- [Implementation](#implementation) + - [Parsing](#parsing) + - [Type Checking](#type-checking) + - [Linearization](#linearization) + - [Automatization](#automatization) +- [References](#references) +- [Acknowledgments](#acknowledgments) + + + +## About + +Peepmatic is a DSL for peephole optimizations and compiler for generating +peephole optimizers from them. The user writes a set of optimizations in the +DSL, and then `peepmatic` compiles the set of optimizations into an efficient +peephole optimizer: + +``` +DSL ----peepmatic----> Peephole Optimizer +``` + +The generated peephole optimizer has all of its optimizations' left-hand sides +collapsed into a compact automata that makes matching candidate instruction +sequences fast. + +The DSL's optimizations may be written by hand or discovered mechanically with a +superoptimizer like [Souper][]. Eventually, `peepmatic` should have a verifier +that ensures that the DSL's optimizations are sound, similar to what [Alive][] +does for LLVM optimizations. + +Currently, `peepmatic` is targeting peephole optimizers that operate on +Cranelift's clif intermediate representation. The intended next target is +Cranelift's new backend's "vcode" intermediate representation. Supporting +non-Cranelift targets is not a goal. + +[Cranelift]: https://github.com/bytecodealliance/wasmtime/tree/master/cranelift#readme +[Souper]: https://github.com/google/souper +[Alive]: https://github.com/AliveToolkit/alive2 + +## Example + +This snippet of our DSL describes optimizations for removing redundant +bitwise-or instructions that are no-ops: + +```lisp +(=> (bor $x (bor $x $y)) + (bor $x $y)) + +(=> (bor $y (bor $x $y)) + (bor $x $y)) + +(=> (bor (bor $x $y) $x) + (bor $x $y)) + +(=> (bor (bor $x $y) $y) + (bor $x $y)) +``` + +When compiled into a peephole optimizer automaton, they look like this: + +![](examples/redundant-bor.png) + +## A DSL for Optimizations + +A single peephole optimization has two parts: + +1. A **left-hand side** that describes candidate instruction sequences that the + optimization applies to. +2. A **right-hand side** that contains the new instruction sequence that + replaces old instruction sequences that the left-hand side matched. + +A left-hand side may bind sub-expressions to variables and the right-hand side +may contain those bound variables to reuse the sub-expressions. The operations +inside the left-hand and right-hand sides are a subset of clif operations. + +Let's take a look at an example: + +```lisp +(=> (imul $x 2) + (ishl $x 1)) +``` + +As you can see, the DSL uses S-expressions. (S-expressions are easy to parse and +we also have a bunch of nice parsing infrastructure for S-expressions already +for our [`wat`][wat] and [`wast`][wast] crates.) + +[wat]: https://crates.io/crates/wat +[wast]: https://crates.io/crates/wast + +The left-hand side of this optimization is `(imul $x 2)`. It matches integer +multiplication operations where a value is multiplied by the constant two. The +value multiplied by two is bound to the variable `$x`. + +The right-hand side of this optimization is `(ishl $x 1)`. It reuses the `$x` +variable that was bound in the left-hand side. + +This optimization replaces expressions of the form `x * 2` with `x << 1`. 
This
+is sound because multiplication by two is the same as shifting left by one for
+binary integers, and it is desirable because a shift-left instruction executes
+in fewer cycles than a multiplication.
+
+The general form of an optimization is:
+
+```lisp
+(=> <left-hand side> <right-hand side>)
+```
+
+### Variables
+
+Variables begin with a dollar sign and are followed by lowercase letters,
+numbers, hyphens, and underscores: `$x`, `$y`, `$my-var`, `$operand2`.
+
+Left-hand side patterns may contain variables that match any kind of
+sub-expression and give it a name so that it may be reused in the right-hand
+side.
+
+```lisp
+;; Replace `x + 0` with simply `x`.
+(=> (iadd $x 0)
+    $x)
+```
+
+Within a pattern, every occurrence of a variable with the same name must match
+the same value. That is, `(iadd $x $x)` matches `(iadd 1 1)` but does not match
+`(iadd 1 2)`. This lets us write optimizations such as this:
+
+```lisp
+;; Xor'ing a value with itself is always zero.
+(=> (bxor $x $x)
+    (iconst 0))
+```
+
+### Constants
+
+We've already seen specific integer literals and wildcard variables in patterns,
+but we can also match any constant. These are written similarly to variables,
+but use uppercase letters rather than lowercase: `$C`, `$MY-CONST`, and
+`$OPERAND2`.
+
+For example, we can use constant patterns to combine an `iconst` and `iadd` into
+a single `iadd_imm` instruction:
+
+```lisp
+(=> (iadd (iconst $C) $x)
+    (iadd_imm $C $x))
+```
+
+### Nested Patterns
+
+Patterns can also match nested operations with their own nesting:
+
+```lisp
+(=> (bor $x (bor $x $y))
+    (bor $x $y))
+```
+
+### Preconditions and Unquoting
+
+Let's reconsider our first example optimization:
+
+```lisp
+(=> (imul $x 2)
+    (ishl $x 1))
+```
+
+This optimization is a little too specific. Here is another version of this
+optimization that we'd like to support:
+
+```lisp
+(=> (imul $x 4)
+    (ishl $x 2))
+```
+
+We don't want to have to write out all instances of this general class of
+optimizations! That would be a lot of repetition and could also bloat the size
+of our generated peephole optimizer's matching automata.
+
+Instead, we can generalize this optimization by matching any multiplication by a
+power-of-two constant `$C` and replacing it with a shift left of `log2($C)`.
+
+First, rather than match `2` directly, we want to match any constant `$C`:
+
+```lisp
+(imul $x $C)
+```
+
+Note that variables begin with lowercase letters, while constants begin with
+uppercase letters. Both the constant pattern `$C` and variable pattern `$x` will
+match `5`, but only the variable pattern `$x` will match a whole sub-expression
+like `(iadd 1 2)`. The constant pattern `$C` only matches constant values.
+
+Next, we augment our left-hand side's pattern with a **precondition** that the
+constant `$C` must be a power of two. Preconditions are introduced by wrapping
+a pattern in the `when` form:
+
+```lisp
+;; Our new left-hand side, augmenting a pattern with a precondition!
+(when
+  ;; The pattern matching multiplication by a constant value.
+  (imul $x $C)
+
+  ;; The precondition that $C must be a power of two.
+  (is-power-of-two $C))
+```
+
+In the right-hand side, we use **unquoting** to perform compile-time evaluation
+of `log2($C)`. Unquoting is done with the `$(...)` form:
+
+```lisp
+;; Our new right-hand side, using unquoting to do compile-time evaluation of
+;; constants that were matched and bound in the left-hand side!
+(ishl $x $(log2 $C)) +``` + +Finally, here is the general optimization putting our new left-hand and +right-hand sides together: + +```lisp +(=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $(log2 $C))) +``` + +### Bit Widths + +Similar to how Cranelift's instructions are bit-width polymorphic, `peepmatic` +optimizations are also bit-width polymorphic. Unless otherwise specified, a +pattern will match expressions manipulating `i32`s just the same as expressions +manipulating `i64`s, etc... An optimization that doesn't constrain its pattern's +bit widths must be valid for all bit widths: + +* 1 +* 8 +* 16 +* 32 +* 64 +* 128 + +To constrain an optimization to only match `i32`s, for example, you can use the +`bit-width` precondition: + +```lisp +(=> (when (iadd $x $y) + (bit-width $x 32) + (bit-width $y 32)) + ...) +``` + +Alternatively, you can ascribe a type to an operation by putting the type inside +curly brackets after the operator, like this: + +```lisp +(=> (when (sextend{i64} (ireduce{i32} $x)) + (bit-width $x 64)) + (sshr (ishl $x 32) 32)) +``` + +## Implementation + +Peepmatic has roughly four phases: + +1. Parsing +2. Type Checking +3. Linearization +4. Automatization + +(I say "roughly" because there are a couple micro-passes that happen after +linearization and before automatization. But those are the four main phases.) + +### Parsing + +Parsing transforms the DSL source text into an abstract syntax tree (AST). + +We use [the `wast` crate][wast]. It gives us nicely formatted errors with source +context, as well as some other generally nice-to-have parsing infrastructure. + +Relevant source files: + +* `src/parser.rs` +* `src/ast.rs` + +[wast]: https://crates.io/crates/wast + +### Type Checking + +Type checking operates on the AST. It checks that types and bit widths in the +optimizations are all valid. For example, it ensures that the type and bit width +of an optimization's left-hand side is the same as its right-hand side, because +it doesn't make sense to replace an integer expression with a boolean +expression. + +After type checking is complete, certain AST nodes are assigned a type and bit +width, that are later used in linearization and when matching and applying +optimizations. + +We walk the AST and gather type constraints. Every constraint is associated with +a span in the source file. We hand these constraints off to Z3. In the case that +there are type errors (i.e. Z3 returns `unsat`), we get the constraints that are +in conflict with each othe via `z3::Solver::get_unsat_core` and report the type +errors to the user, with the source context, thanks to the constraints' +associated spans. + +Using Z3 not only makes implementing type checking easier than it otherwise +would be, but makes it that much easier to extend type checking with searching +for counterexample inputs in the future. That is, inputs for which the RHS is +not equivalent to the LHS, implying that the optimization is unsound. + +Relevant source files: + +* `src/verify.rs` + +### Linearization + +Linearization takes the AST of optimizations and converts each optimization into +a linear form. The goal is to make automaton construction easier in the +automatization step, as well as simplifying the language to make matching and +applying optimizations easier. + +Each optimization's left-hand side is converted into a sequence of + +* match operation, +* path to the instruction/value/immediate to which the operation is applied, and +* expected result of the operation. 
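+
+To make this concrete, the sketch below shows roughly the shape of the data
+the linear form carries. It is illustrative only: the authoritative
+definitions live in `crates/runtime/src/linear.rs`, and the names and field
+types here (such as the `u32` stand-ins) are approximations rather than the
+actual API.
+
+```rust
+/// A path into the candidate instruction tree, e.g. `0.1` for the root
+/// instruction's second operand. (Sketch only; real paths are interned.)
+struct PathId(u32);
+
+/// A single match operation, applied at a particular path.
+enum MatchOp {
+    /// What is the opcode of the instruction at `path`?
+    Opcode { path: PathId },
+    /// Is the value at `path` a constant?
+    IsConst { path: PathId },
+    // ... plus IsPowerOfTwo, BitWidth, Eq, IntegerValue, Nop, etc.
+}
+
+/// Placeholder for the right-hand side build actions described below.
+enum Action {}
+
+/// One step of a linearized left-hand side: a match operation, the result we
+/// expect from it, and the RHS build actions to accumulate when it matches.
+struct Increment {
+    operation: MatchOp,
+    expected: Option<u32>,
+    actions: Vec<Action>,
+}
+
+/// A whole linearized optimization is a chain of increments.
+struct Optimization {
+    increments: Vec<Increment>,
+}
+```
+
+When these increments are later combined into an automaton, `operation` labels
+a state, `expected` labels the transition out of that state, and `actions` are
+accumulated along that transition (see the Automatization section below).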
+
+All match operations must yield their expected results for the optimization to
+be applicable to an instruction sequence.
+
+Each optimization's right-hand side is converted into a sequence of build
+actions. These are commands that describe how to construct the right-hand side,
+given that the left-hand side has been matched.
+
+Relevant source files:
+
+* `src/linearize.rs`
+* `src/linear_passes.rs`
+* `crates/runtime/src/linear.rs`
+
+### Automatization
+
+Automatization takes a set of linear optimizations and combines them into a
+transducer automaton. This automaton is the final, compiled form of the
+peephole optimizations. The goal is to de-duplicate as much as we can across
+all the linear optimizations, producing as compact and cache-friendly a
+representation as we can.
+
+A plain automaton can only tell you whether it matches an input string; it can
+be thought of as a compact representation of a set of strings. A transducer is
+a type of automaton that doesn't just match input strings, but can also map
+them to output values; it can be thought of as a compact representation of a
+dictionary or map. By using transducers, we de-duplicate not only the prefixes
+and suffixes of the match operations, but also the right-hand side build
+actions.
+
+Each state in the emitted transducer is associated with a match operation and
+path. The transitions out of that state are over the result of the match
+operation. Each transition optionally accumulates some RHS build actions. By the
+time we reach a final state, the RHS build actions are complete and can be
+interpreted to apply the matched optimization.
+
+The relevant source files for constructing the transducer automaton are:
+
+* `src/automatize.rs`
+* `crates/automata/src/lib.rs`
+
+The relevant source files for the runtime that interprets the transducer and
+applies optimizations are:
+
+* `crates/runtime/src/optimizations.rs`
+* `crates/runtime/src/optimizer.rs`
+
+## References
+
+I found these resources helpful when designing `peepmatic`:
+
+* [Extending tree pattern matching for application to peephole
+  optimizations](https://pure.tue.nl/ws/portalfiles/portal/125543109/Thesis_JanekvOirschot.pdf)
+  by van Oirschot
+
+* [Interpreted Pattern Match Execution for
+  MLIR](https://drive.google.com/drive/folders/1hb_sXbdMbIz95X-aaa6Vf5wSYRwsJuve)
+  by Jeff Niu
+
+* [Direct Construction of Minimal Acyclic Subsequential
+  Transducers](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.24.3698&rep=rep1&type=pdf)
+  by Mihov and Maurel
+
+* [Index 1,600,000,000 Keys with Automata and
+  Rust](https://blog.burntsushi.net/transducers/) and [the `fst`
+  crate](https://crates.io/crates/fst) by Andrew Gallant
+
+## Acknowledgments
+
+Thanks to [Jubi Taneja], [Dan Gohman], [John Regehr], and [Nuno Lopes] for their
+input in design discussions and for sharing helpful resources!
+ +[Jubi Taneja]: https://www.cs.utah.edu/~jubi/ +[Dan Gohman]: https://github.com/sunfishcode +[John Regehr]: https://www.cs.utah.edu/~regehr/ +[Nuno Lopes]: http://web.ist.utl.pt/nuno.lopes/ diff --git a/cranelift/peepmatic/examples/mul-by-pow2.peepmatic b/cranelift/peepmatic/examples/mul-by-pow2.peepmatic new file mode 100644 index 0000000000..cca78a372f --- /dev/null +++ b/cranelift/peepmatic/examples/mul-by-pow2.peepmatic @@ -0,0 +1,3 @@ +(=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $(log2 $C))) diff --git a/cranelift/peepmatic/examples/preopt.peepmatic b/cranelift/peepmatic/examples/preopt.peepmatic new file mode 100644 index 0000000000..22523f0a01 --- /dev/null +++ b/cranelift/peepmatic/examples/preopt.peepmatic @@ -0,0 +1,193 @@ +;; Apply basic simplifications. +;; +;; This folds constants with arithmetic to form `_imm` instructions, and other +;; minor simplifications. +;; +;; Doesn't apply some simplifications if the native word width (in bytes) is +;; smaller than the controlling type's width of the instruction. This would +;; result in an illegal instruction that would likely be expanded back into an +;; instruction on smaller types with the same initial opcode, creating +;; unnecessary churn. + +;; Binary instructions whose second argument is constant. +(=> (when (iadd $x $C) + (fits-in-native-word $C)) + (iadd_imm $C $x)) +(=> (when (imul $x $C) + (fits-in-native-word $C)) + (imul_imm $C $x)) +(=> (when (sdiv $x $C) + (fits-in-native-word $C)) + (sdiv_imm $C $x)) +(=> (when (udiv $x $C) + (fits-in-native-word $C)) + (udiv_imm $C $x)) +(=> (when (srem $x $C) + (fits-in-native-word $C)) + (srem_imm $C $x)) +(=> (when (urem $x $C) + (fits-in-native-word $C)) + (urem_imm $C $x)) +(=> (when (band $x $C) + (fits-in-native-word $C)) + (band_imm $C $x)) +(=> (when (bor $x $C) + (fits-in-native-word $C)) + (bor_imm $C $x)) +(=> (when (bxor $x $C) + (fits-in-native-word $C)) + (bxor_imm $C $x)) +(=> (when (rotl $x $C) + (fits-in-native-word $C)) + (rotl_imm $C $x)) +(=> (when (rotr $x $C) + (fits-in-native-word $C)) + (rotr_imm $C $x)) +(=> (when (ishl $x $C) + (fits-in-native-word $C)) + (ishl_imm $C $x)) +(=> (when (ushr $x $C) + (fits-in-native-word $C)) + (ushr_imm $C $x)) +(=> (when (sshr $x $C) + (fits-in-native-word $C)) + (sshr_imm $C $x)) +(=> (when (isub $x $C) + (fits-in-native-word $C)) + (iadd_imm $(neg $C) $x)) +(=> (when (ifcmp $x $C) + (fits-in-native-word $C)) + (ifcmp_imm $C $x)) +(=> (when (icmp $cond $x $C) + (fits-in-native-word $C)) + (icmp_imm $cond $C $x)) + +;; Binary instructions whose first operand is constant. +(=> (when (iadd $C $x) + (fits-in-native-word $C)) + (iadd_imm $C $x)) +(=> (when (imul $C $x) + (fits-in-native-word $C)) + (imul_imm $C $x)) +(=> (when (band $C $x) + (fits-in-native-word $C)) + (band_imm $C $x)) +(=> (when (bor $C $x) + (fits-in-native-word $C)) + (bor_imm $C $x)) +(=> (when (bxor $C $x) + (fits-in-native-word $C)) + (bxor_imm $C $x)) +(=> (when (isub $C $x) + (fits-in-native-word $C)) + (irsub_imm $C $x)) + +;; Unary instructions whose operand is constant. +(=> (adjust_sp_down $C) (adjust_sp_down_imm $C)) + +;; Fold `(binop_imm $C1 (binop_imm $C2 $x))` into `(binop_imm $(binop $C2 $C1) $x)`. 
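+;;
+;; For example (illustrative), `(iadd_imm 2 (iadd_imm 3 $x))` folds into
+;; `(iadd_imm 5 $x)`: the unquoted `$(iadd $C1 $C2)` is evaluated when the
+;; replacement is built.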
+(=> (iadd_imm $C1 (iadd_imm $C2 $x)) (iadd_imm $(iadd $C1 $C2) $x)) +(=> (imul_imm $C1 (imul_imm $C2 $x)) (imul_imm $(imul $C1 $C2) $x)) +(=> (bor_imm $C1 (bor_imm $C2 $x)) (bor_imm $(bor $C1 $C2) $x)) +(=> (band_imm $C1 (band_imm $C2 $x)) (band_imm $(band $C1 $C2) $x)) +(=> (bxor_imm $C1 (bxor_imm $C2 $x)) (bxor_imm $(bxor $C1 $C2) $x)) + +;; Remove operations that are no-ops. +(=> (iadd_imm 0 $x) $x) +(=> (imul_imm 1 $x) $x) +(=> (sdiv_imm 1 $x) $x) +(=> (udiv_imm 1 $x) $x) +(=> (bor_imm 0 $x) $x) +(=> (band_imm -1 $x) $x) +(=> (bxor_imm 0 $x) $x) +(=> (rotl_imm 0 $x) $x) +(=> (rotr_imm 0 $x) $x) +(=> (ishl_imm 0 $x) $x) +(=> (ushr_imm 0 $x) $x) +(=> (sshr_imm 0 $x) $x) + +;; Replace with zero. +(=> (imul_imm 0 $x) 0) +(=> (band_imm 0 $x) 0) + +;; Replace with negative 1. +(=> (bor_imm -1 $x) -1) + +;; Transform `[(x << N) >> N]` into a (un)signed-extending move. +;; +;; i16 -> i8 -> i16 +(=> (when (ushr_imm 8 (ishl_imm 8 $x)) + (bit-width $x 16)) + (uextend{i16} (ireduce{i8} $x))) +(=> (when (sshr_imm 8 (ishl_imm 8 $x)) + (bit-width $x 16)) + (sextend{i16} (ireduce{i8} $x))) +;; i32 -> i8 -> i32 +(=> (when (ushr_imm 24 (ishl_imm 24 $x)) + (bit-width $x 32)) + (uextend{i32} (ireduce{i8} $x))) +(=> (when (sshr_imm 24 (ishl_imm 24 $x)) + (bit-width $x 32)) + (sextend{i32} (ireduce{i8} $x))) +;; i32 -> i16 -> i32 +(=> (when (ushr_imm 16 (ishl_imm 16 $x)) + (bit-width $x 32)) + (uextend{i32} (ireduce{i16} $x))) +(=> (when (sshr_imm 16 (ishl_imm 16 $x)) + (bit-width $x 32)) + (sextend{i32} (ireduce{i16} $x))) +;; i64 -> i8 -> i64 +(=> (when (ushr_imm 56 (ishl_imm 56 $x)) + (bit-width $x 64)) + (uextend{i64} (ireduce{i8} $x))) +(=> (when (sshr_imm 56 (ishl_imm 56 $x)) + (bit-width $x 64)) + (sextend{i64} (ireduce{i8} $x))) +;; i64 -> i16 -> i64 +(=> (when (ushr_imm 48 (ishl_imm 48 $x)) + (bit-width $x 64)) + (uextend{i64} (ireduce{i16} $x))) +(=> (when (sshr_imm 48 (ishl_imm 48 $x)) + (bit-width $x 64)) + (sextend{i64} (ireduce{i16} $x))) +;; i64 -> i32 -> i64 +(=> (when (ushr_imm 32 (ishl_imm 32 $x)) + (bit-width $x 64)) + (uextend{i64} (ireduce{i32} $x))) +(=> (when (sshr_imm 32 (ishl_imm 32 $x)) + (bit-width $x 64)) + (sextend{i64} (ireduce{i32} $x))) + +;; Fold away redundant `bint` instructions that accept both integer and boolean +;; arguments. +(=> (select (bint $x) $y $z) (select $x $y $z)) +(=> (brz (bint $x)) (brz $x)) +(=> (brnz (bint $x)) (brnz $x)) +(=> (trapz (bint $x)) (trapz $x)) +(=> (trapnz (bint $x)) (trapnz $x)) + +;; Fold comparisons into branch operations when possible. +;; +;; This matches against operations which compare against zero, then use the +;; result in a `brz` or `brnz` branch. It folds those two operations into a +;; single `brz` or `brnz`. +(=> (brnz (icmp_imm ne 0 $x)) (brnz $x)) +(=> (brz (icmp_imm ne 0 $x)) (brz $x)) +(=> (brnz (icmp_imm eq 0 $x)) (brz $x)) +(=> (brz (icmp_imm eq 0 $x)) (brnz $x)) + +;; Division and remainder by constants. +;; +;; Note that this section is incomplete, and a bunch of related optimizations +;; are still hand-coded in `simple_preopt.rs`. + +;; (Division by one is handled above.) + +;; Remainder by one is zero. +(=> (urem_imm 1 $x) 0) +(=> (srem_imm 1 $x) 0) + +;; Division by a power of two -> shift right. 
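+;; (This rule is for unsigned division only: signed division by a power of two
+;; is not a plain arithmetic shift, since negative dividends need a rounding
+;; adjustment, so it presumably remains among the hand-coded optimizations in
+;; `simple_preopt.rs` noted above.)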
+(=> (when (udiv_imm $C $x) + (is-power-of-two $C)) + (ushr_imm $(log2 $C) $x)) diff --git a/cranelift/peepmatic/examples/redundant-bor.peepmatic b/cranelift/peepmatic/examples/redundant-bor.peepmatic new file mode 100644 index 0000000000..d8d6f4a144 --- /dev/null +++ b/cranelift/peepmatic/examples/redundant-bor.peepmatic @@ -0,0 +1,13 @@ +;; Remove redundant bitwise OR instructions that are no-ops. + +(=> (bor $x (bor $x $y)) + (bor $x $y)) + +(=> (bor $y (bor $x $y)) + (bor $x $y)) + +(=> (bor (bor $x $y) $x) + (bor $x $y)) + +(=> (bor (bor $x $y) $y) + (bor $x $y)) diff --git a/cranelift/peepmatic/examples/redundant-bor.png b/cranelift/peepmatic/examples/redundant-bor.png new file mode 100644 index 0000000000..dab873e6cf Binary files /dev/null and b/cranelift/peepmatic/examples/redundant-bor.png differ diff --git a/cranelift/peepmatic/examples/simple.peepmatic b/cranelift/peepmatic/examples/simple.peepmatic new file mode 100644 index 0000000000..a90e0c1a69 --- /dev/null +++ b/cranelift/peepmatic/examples/simple.peepmatic @@ -0,0 +1,11 @@ +(=> (bor $x (bor $x $y)) + (bor $x $y)) + +(=> (bor $y (bor $x $y)) + (bor $x $y)) + +(=> (bor (bor $x $y) $x) + (bor $x $y)) + +(=> (bor (bor $x $y) $y) + (bor $x $y)) diff --git a/cranelift/peepmatic/src/ast.rs b/cranelift/peepmatic/src/ast.rs new file mode 100644 index 0000000000..daac871553 --- /dev/null +++ b/cranelift/peepmatic/src/ast.rs @@ -0,0 +1,508 @@ +//! Abstract syntax tree type definitions. +//! +//! This file makes fairly heavy use of macros, which are defined in the +//! `peepmatic_macro` crate that lives at `crates/macro`. Notably, the following +//! traits are all derived via `derive(Ast)`: +//! +//! * `Span` -- access the `wast::Span` where an AST node was parsed from. For +//! `struct`s, there must be a `span: wast::Span` field, because the macro +//! always generates an implementation that returns `self.span` for +//! `struct`s. For `enum`s, every variant must have a single, unnamed field +//! which implements the `Span` trait. The macro will generate code to return +//! the span of whatever variant it is. +//! +//! * `ChildNodes` -- get each of the child AST nodes that a given node +//! references. Some fields in an AST type aren't actually considered an AST +//! node (like spans) and these are ignored via the `#[peepmatic(skip_child)]` +//! attribute. Some fields contain multiple AST nodes (like vectors of +//! operands) and these are flattened with `#[peepmatic(flatten)]`. +//! +//! * `From<&'a Self> for DynAstRef<'a>` -- convert a particular AST type into +//! `DynAstRef`, which is an `enum` of all the different kinds of AST nodes. + +use peepmatic_macro::Ast; +use peepmatic_runtime::{ + operator::{Operator, UnquoteOperator}, + r#type::{BitWidth, Type}, +}; +use std::cell::Cell; +use std::hash::{Hash, Hasher}; +use std::marker::PhantomData; +use wast::Id; + +/// A reference to any AST node. +#[derive(Debug, Clone, Copy)] +pub enum DynAstRef<'a> { + /// A reference to an `Optimizations`. + Optimizations(&'a Optimizations<'a>), + + /// A reference to an `Optimization`. + Optimization(&'a Optimization<'a>), + + /// A reference to an `Lhs`. + Lhs(&'a Lhs<'a>), + + /// A reference to an `Rhs`. + Rhs(&'a Rhs<'a>), + + /// A reference to a `Pattern`. + Pattern(&'a Pattern<'a>), + + /// A reference to a `Precondition`. + Precondition(&'a Precondition<'a>), + + /// A reference to a `ConstraintOperand`. + ConstraintOperand(&'a ConstraintOperand<'a>), + + /// A reference to a `ValueLiteral`. 
+ ValueLiteral(&'a ValueLiteral<'a>), + + /// A reference to a `Constant`. + Constant(&'a Constant<'a>), + + /// A reference to a `PatternOperation`. + PatternOperation(&'a Operation<'a, Pattern<'a>>), + + /// A reference to a `Variable`. + Variable(&'a Variable<'a>), + + /// A reference to an `Integer`. + Integer(&'a Integer<'a>), + + /// A reference to a `Boolean`. + Boolean(&'a Boolean<'a>), + + /// A reference to a `ConditionCode`. + ConditionCode(&'a ConditionCode<'a>), + + /// A reference to an `Unquote`. + Unquote(&'a Unquote<'a>), + + /// A reference to an `RhsOperation`. + RhsOperation(&'a Operation<'a, Rhs<'a>>), +} + +impl<'a, 'b> ChildNodes<'a, 'b> for DynAstRef<'a> { + fn child_nodes(&'b self, sink: &mut impl Extend>) { + match self { + Self::Optimizations(x) => x.child_nodes(sink), + Self::Optimization(x) => x.child_nodes(sink), + Self::Lhs(x) => x.child_nodes(sink), + Self::Rhs(x) => x.child_nodes(sink), + Self::Pattern(x) => x.child_nodes(sink), + Self::Precondition(x) => x.child_nodes(sink), + Self::ConstraintOperand(x) => x.child_nodes(sink), + Self::ValueLiteral(x) => x.child_nodes(sink), + Self::Constant(x) => x.child_nodes(sink), + Self::PatternOperation(x) => x.child_nodes(sink), + Self::Variable(x) => x.child_nodes(sink), + Self::Integer(x) => x.child_nodes(sink), + Self::Boolean(x) => x.child_nodes(sink), + Self::ConditionCode(x) => x.child_nodes(sink), + Self::Unquote(x) => x.child_nodes(sink), + Self::RhsOperation(x) => x.child_nodes(sink), + } + } +} + +/// A trait implemented by all AST nodes. +/// +/// All AST nodes can: +/// +/// * Enumerate their children via `ChildNodes`. +/// +/// * Give you the `wast::Span` where they were defined. +/// +/// * Be converted into a `DynAstRef`. +/// +/// This trait is blanked implemented for everything that does those three +/// things, and in practice those three thrings are all implemented by the +/// `derive(Ast)` macro. +pub trait Ast<'a>: 'a + ChildNodes<'a, 'a> + Span +where + DynAstRef<'a>: From<&'a Self>, +{ +} + +impl<'a, T> Ast<'a> for T +where + T: 'a + ?Sized + ChildNodes<'a, 'a> + Span, + DynAstRef<'a>: From<&'a Self>, +{ +} + +/// Enumerate the child AST nodes of a given node. +pub trait ChildNodes<'a, 'b> { + /// Get each of this AST node's children, in order. + fn child_nodes(&'b self, sink: &mut impl Extend>); +} + +/// A trait for getting the span where an AST node was defined. +pub trait Span { + /// Get the span where this AST node was defined. + fn span(&self) -> wast::Span; +} + +/// A set of optimizations. +/// +/// This is the root AST node. +#[derive(Debug, Ast)] +pub struct Optimizations<'a> { + /// Where these `Optimizations` were defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The optimizations. + #[peepmatic(flatten)] + pub optimizations: Vec>, +} + +/// A complete optimization: a left-hand side to match against and a right-hand +/// side replacement. +#[derive(Debug, Ast)] +pub struct Optimization<'a> { + /// Where this `Optimization` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The left-hand side that matches when this optimization applies. + pub lhs: Lhs<'a>, + + /// The new sequence of instructions to replace an old sequence that matches + /// the left-hand side with. + pub rhs: Rhs<'a>, +} + +/// A left-hand side describes what is required for a particular optimization to +/// apply. 
+/// +/// A left-hand side has two parts: a structural pattern for describing +/// candidate instruction sequences, and zero or more preconditions that add +/// additional constraints upon instruction sequences matched by the pattern. +#[derive(Debug, Ast)] +pub struct Lhs<'a> { + /// Where this `Lhs` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// A pattern that describes sequences of instructions to match. + pub pattern: Pattern<'a>, + + /// Additional constraints that a match must satisfy in addition to + /// structually matching the pattern, e.g. some constant must be a power of + /// two. + #[peepmatic(flatten)] + pub preconditions: Vec>, +} + +/// A structural pattern, potentially with wildcard variables for matching whole +/// subtrees. +#[derive(Debug, Ast)] +pub enum Pattern<'a> { + /// A specific value. These are written as `1234` or `0x1234` or `true` or + /// `false`. + ValueLiteral(ValueLiteral<'a>), + + /// A constant that matches any constant value. This subsumes value + /// patterns. These are upper-case identifiers like `$C`. + Constant(Constant<'a>), + + /// An operation pattern with zero or more operand patterns. These are + /// s-expressions like `(iadd $x $y)`. + Operation(Operation<'a, Pattern<'a>>), + + /// A variable that matches any kind of subexpression. This subsumes all + /// other patterns. These are lower-case identifiers like `$x`. + Variable(Variable<'a>), +} + +/// An integer or boolean value literal. +#[derive(Debug, Ast)] +pub enum ValueLiteral<'a> { + /// An integer value. + Integer(Integer<'a>), + + /// A boolean value: `true` or `false`. + Boolean(Boolean<'a>), + + /// A condition code: `eq`, `ne`, etc... + ConditionCode(ConditionCode<'a>), +} + +/// An integer literal. +#[derive(Debug, PartialEq, Eq, Ast)] +pub struct Integer<'a> { + /// Where this `Integer` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The integer value. + /// + /// Note that although Cranelift allows 128 bits wide values, the widest + /// supported constants as immediates are 64 bits. + #[peepmatic(skip_child)] + pub value: i64, + + /// The bit width of this integer. + /// + /// This is either a fixed bit width, or polymorphic over the width of the + /// optimization. + /// + /// This field is initialized from `None` to `Some` by the type checking + /// pass in `src/verify.rs`. + #[peepmatic(skip_child)] + pub bit_width: Cell>, + + #[allow(missing_docs)] + #[peepmatic(skip_child)] + pub marker: PhantomData<&'a ()>, +} + +impl Hash for Integer<'_> { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + let Integer { + span, + value, + bit_width, + marker: _, + } = self; + span.hash(state); + value.hash(state); + let bit_width = bit_width.get(); + bit_width.hash(state); + } +} + +/// A boolean literal. +#[derive(Debug, PartialEq, Eq, Ast)] +pub struct Boolean<'a> { + /// Where this `Boolean` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The boolean value. + #[peepmatic(skip_child)] + pub value: bool, + + /// The bit width of this boolean. + /// + /// This is either a fixed bit width, or polymorphic over the width of the + /// optimization. + /// + /// This field is initialized from `None` to `Some` by the type checking + /// pass in `src/verify.rs`. 
+ #[peepmatic(skip_child)] + pub bit_width: Cell>, + + #[allow(missing_docs)] + #[peepmatic(skip_child)] + pub marker: PhantomData<&'a ()>, +} + +impl Hash for Boolean<'_> { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + let Boolean { + span, + value, + bit_width, + marker: _, + } = self; + span.hash(state); + value.hash(state); + let bit_width = bit_width.get(); + bit_width.hash(state); + } +} + +/// A condition code. +#[derive(Debug, Ast)] +pub struct ConditionCode<'a> { + /// Where this `ConditionCode` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The actual condition code. + #[peepmatic(skip_child)] + pub cc: peepmatic_runtime::cc::ConditionCode, + + #[allow(missing_docs)] + #[peepmatic(skip_child)] + pub marker: PhantomData<&'a ()>, +} + +/// A symbolic constant. +/// +/// These are identifiers containing uppercase letters: `$C`, `$MY-CONST`, +/// `$CONSTANT1`. +#[derive(Debug, Ast)] +pub struct Constant<'a> { + /// Where this `Constant` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// This constant's identifier. + #[peepmatic(skip_child)] + pub id: Id<'a>, +} + +/// A variable that matches any subtree. +/// +/// Duplicate uses of the same variable constrain each occurrence's match to +/// being the same as each other occurrence as well, e.g. `(iadd $x $x)` matches +/// `(iadd 5 5)` but not `(iadd 1 2)`. +#[derive(Debug, Ast)] +pub struct Variable<'a> { + /// Where this `Variable` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// This variable's identifier. + #[peepmatic(skip_child)] + pub id: Id<'a>, +} + +/// An operation with an operator, and operands of type `T`. +#[derive(Debug, Ast)] +#[peepmatic(no_into_dyn_node)] +pub struct Operation<'a, T> +where + T: 'a + Ast<'a>, + DynAstRef<'a>: From<&'a T>, +{ + /// The span where this operation was written. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The operator for this operation, e.g. `imul` or `iadd`. + #[peepmatic(skip_child)] + pub operator: Operator, + + /// An optional ascribed or inferred type for the operator. + #[peepmatic(skip_child)] + pub r#type: Cell>, + + /// This operation's operands. + /// + /// When `Operation` is used in a pattern, these are the sub-patterns for + /// the operands. When `Operation is used in a right-hand side replacement, + /// these are the sub-replacements for the operands. + #[peepmatic(flatten)] + pub operands: Vec, + + #[allow(missing_docs)] + #[peepmatic(skip_child)] + pub marker: PhantomData<&'a ()>, +} + +impl<'a> From<&'a Operation<'a, Pattern<'a>>> for DynAstRef<'a> { + #[inline] + fn from(o: &'a Operation<'a, Pattern<'a>>) -> DynAstRef<'a> { + DynAstRef::PatternOperation(o) + } +} + +impl<'a> From<&'a Operation<'a, Rhs<'a>>> for DynAstRef<'a> { + #[inline] + fn from(o: &'a Operation<'a, Rhs<'a>>) -> DynAstRef<'a> { + DynAstRef::RhsOperation(o) + } +} + +/// A precondition adds additional constraints to a pattern, such as "$C must be +/// a power of two". +#[derive(Debug, Ast)] +pub struct Precondition<'a> { + /// Where this `Precondition` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The constraint operator. + #[peepmatic(skip_child)] + pub constraint: Constraint, + + /// The operands of the constraint. + #[peepmatic(flatten)] + pub operands: Vec>, +} + +/// Contraint operators. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum Constraint { + /// Is the operand a power of two? + IsPowerOfTwo, + + /// Check the bit width of a value. 
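+    /// For example, the precondition `(bit-width $x 32)` requires that `$x`
+    /// be 32 bits wide.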
+ BitWidth, + + /// Does the argument fit within our target architecture's native word size? + FitsInNativeWord, +} + +/// An operand of a precondition's constraint. +#[derive(Debug, Ast)] +pub enum ConstraintOperand<'a> { + /// A value literal operand. + ValueLiteral(ValueLiteral<'a>), + + /// A constant operand. + Constant(Constant<'a>), + + /// A variable operand. + Variable(Variable<'a>), +} + +/// The right-hand side of an optimization that contains the instructions to +/// replace any matched left-hand side with. +#[derive(Debug, Ast)] +pub enum Rhs<'a> { + /// A value literal right-hand side. + ValueLiteral(ValueLiteral<'a>), + + /// A constant right-hand side (the constant must have been matched and + /// bound in the left-hand side's pattern). + Constant(Constant<'a>), + + /// A variable right-hand side (the variable must have been matched and + /// bound in the left-hand side's pattern). + Variable(Variable<'a>), + + /// An unquote expression that is evaluated while replacing the left-hand + /// side with the right-hand side. The result of the evaluation is used in + /// the replacement. + Unquote(Unquote<'a>), + + /// A compound right-hand side consisting of an operation and subsequent + /// right-hand side operands. + Operation(Operation<'a, Rhs<'a>>), +} + +/// An unquote operation. +/// +/// Rather than replaciong a left-hand side, these are evaluated and then the +/// result of the evaluation replaces the left-hand side. This allows for +/// compile-time computation while replacing a matched left-hand side with a +/// right-hand side. +/// +/// For example, given the unqouted right-hand side `$(log2 $C)`, we replace any +/// instructions that match its left-hand side with the compile-time result of +/// `log2($C)` (the left-hand side must match and bind the constant `$C`). +#[derive(Debug, Ast)] +pub struct Unquote<'a> { + /// Where this `Unquote` was defined. + #[peepmatic(skip_child)] + pub span: wast::Span, + + /// The operator for this unquote operation. + #[peepmatic(skip_child)] + pub operator: UnquoteOperator, + + /// The operands for this unquote operation. + #[peepmatic(flatten)] + pub operands: Vec>, +} diff --git a/cranelift/peepmatic/src/automatize.rs b/cranelift/peepmatic/src/automatize.rs new file mode 100644 index 0000000000..3310bef118 --- /dev/null +++ b/cranelift/peepmatic/src/automatize.rs @@ -0,0 +1,31 @@ +//! Compile a set of linear optimizations into an automaton. + +use peepmatic_automata::{Automaton, Builder}; +use peepmatic_runtime::linear; + +/// Construct an automaton from a set of linear optimizations. +pub fn automatize( + opts: &linear::Optimizations, +) -> Automaton, linear::MatchOp, Vec> { + debug_assert!(crate::linear_passes::is_sorted_lexicographically(opts)); + + let mut builder = Builder::, linear::MatchOp, Vec>::new(); + + for opt in &opts.optimizations { + let mut insertion = builder.insert(); + for inc in &opt.increments { + // Ensure that this state's associated data is this increment's + // match operation. + if let Some(op) = insertion.get_state_data() { + assert_eq!(*op, inc.operation); + } else { + insertion.set_state_data(inc.operation); + } + + insertion.next(inc.expected, inc.actions.clone()); + } + insertion.finish(); + } + + builder.finish() +} diff --git a/cranelift/peepmatic/src/dot_fmt.rs b/cranelift/peepmatic/src/dot_fmt.rs new file mode 100644 index 0000000000..a2c75de02c --- /dev/null +++ b/cranelift/peepmatic/src/dot_fmt.rs @@ -0,0 +1,142 @@ +//! Formatting a peephole optimizer's automata for GraphViz Dot. +//! +//! 
See also `crates/automata/src/dot.rs`. + +use peepmatic_automata::dot::DotFmt; +use peepmatic_runtime::{ + cc::ConditionCode, + integer_interner::{IntegerId, IntegerInterner}, + linear, + operator::Operator, + paths::{PathId, PathInterner}, +}; +use std::convert::TryFrom; +use std::io::{self, Write}; + +#[derive(Debug)] +pub(crate) struct PeepholeDotFmt<'a>(pub(crate) &'a PathInterner, pub(crate) &'a IntegerInterner); + +impl DotFmt, linear::MatchOp, Vec> for PeepholeDotFmt<'_> { + fn fmt_transition( + &self, + w: &mut impl Write, + from: Option<&linear::MatchOp>, + input: &Option, + _to: Option<&linear::MatchOp>, + ) -> io::Result<()> { + let from = from.expect("we should have match op for every state"); + if let Some(x) = input { + match from { + linear::MatchOp::Opcode { .. } => { + let opcode = + Operator::try_from(*x).expect("we shouldn't generate non-opcode edges"); + write!(w, "{}", opcode) + } + linear::MatchOp::ConditionCode { .. } => { + let cc = + ConditionCode::try_from(*x).expect("we shouldn't generate non-CC edges"); + write!(w, "{}", cc) + } + linear::MatchOp::IntegerValue { .. } => { + let x = self.1.lookup(IntegerId(*x)); + write!(w, "{}", x) + } + _ => write!(w, "{}", x), + } + } else { + write!(w, "(else)") + } + } + + fn fmt_state(&self, w: &mut impl Write, op: &linear::MatchOp) -> io::Result<()> { + use linear::MatchOp::*; + + write!(w, r#""#)?; + + let p = p(self.0); + match op { + Opcode { path } => write!(w, "opcode @ {}", p(path))?, + IsConst { path } => write!(w, "is-const? @ {}", p(path))?, + IsPowerOfTwo { path } => write!(w, "is-power-of-two? @ {}", p(path))?, + BitWidth { path } => write!(w, "bit-width @ {}", p(path))?, + FitsInNativeWord { path } => write!(w, "fits-in-native-word @ {}", p(path))?, + Eq { path_a, path_b } => write!(w, "{} == {}", p(path_a), p(path_b))?, + IntegerValue { path } => write!(w, "integer-value @ {}", p(path))?, + BooleanValue { path } => write!(w, "boolean-value @ {}", p(path))?, + ConditionCode { path } => write!(w, "condition-code @ {}", p(path))?, + Nop => write!(w, "nop")?, + } + + writeln!(w, "") + } + + fn fmt_output(&self, w: &mut impl Write, actions: &Vec) -> io::Result<()> { + use linear::Action::*; + + if actions.is_empty() { + return writeln!(w, "(no output)"); + } + + write!(w, r#""#)?; + + let p = p(self.0); + + for a in actions { + match a { + GetLhs { path } => write!(w, "get-lhs @ {}
", p(path))?, + UnaryUnquote { operator, operand } => { + write!(w, "eval {} $rhs{}
", operator, operand.0)? + } + BinaryUnquote { operator, operands } => write!( + w, + "eval {} $rhs{}, $rhs{}
", + operator, operands[0].0, operands[1].0, + )?, + MakeIntegerConst { + value, + bit_width: _, + } => write!(w, "make {}
", self.1.lookup(*value))?, + MakeBooleanConst { + value, + bit_width: _, + } => write!(w, "make {}
", value)?, + MakeConditionCode { cc } => write!(w, "{}
", cc)?, + MakeUnaryInst { + operand, + operator, + r#type: _, + } => write!(w, "make {} $rhs{}
", operator, operand.0,)?, + MakeBinaryInst { + operator, + operands, + r#type: _, + } => write!( + w, + "make {} $rhs{}, $rhs{}
", + operator, operands[0].0, operands[1].0, + )?, + MakeTernaryInst { + operator, + operands, + r#type: _, + } => write!( + w, + "make {} $rhs{}, $rhs{}, $rhs{}
", + operator, operands[0].0, operands[1].0, operands[2].0, + )?, + } + } + + writeln!(w, "
") + } +} + +fn p<'a>(paths: &'a PathInterner) -> impl Fn(&PathId) -> String + 'a { + move |path: &PathId| { + let mut s = vec![]; + for b in paths.lookup(*path).0 { + s.push(b.to_string()); + } + s.join(".") + } +} diff --git a/cranelift/peepmatic/src/lib.rs b/cranelift/peepmatic/src/lib.rs new file mode 100755 index 0000000000..0cf4147db8 --- /dev/null +++ b/cranelift/peepmatic/src/lib.rs @@ -0,0 +1,165 @@ +/*! + +`peepmatic` is a DSL and compiler for generating peephole optimizers. + +The user writes a set of optimizations in the DSL, and then `peepmatic` compiles +the set of optimizations into an efficient peephole optimizer. + + */ + +#![deny(missing_docs)] +#![deny(missing_debug_implementations)] + +mod ast; +mod automatize; +mod dot_fmt; +mod linear_passes; +mod linearize; +mod parser; +mod traversals; +mod verify; +pub use self::{ + ast::*, automatize::*, linear_passes::*, linearize::*, parser::*, traversals::*, verify::*, +}; + +use peepmatic_runtime::PeepholeOptimizations; +use std::fs; +use std::path::Path; + +/// Compile the given DSL file into a compact peephole optimizations automaton! +/// +/// ## Example +/// +/// ```no_run +/// # fn main() -> anyhow::Result<()> { +/// use std::path::Path; +/// +/// let peep_opts = peepmatic::compile_file(Path::new( +/// "path/to/optimizations.peepmatic" +/// ))?; +/// +/// // Use the peephole optimizations or serialize them into bytes here... +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Visualizing the Peephole Optimizer's Automaton +/// +/// To visualize (or debug) the peephole optimizer's automaton, set the +/// `PEEPMATIC_DOT` environment variable to a file path. A [GraphViz +/// Dot]((https://graphviz.gitlab.io/_pages/pdf/dotguide.pdf)) file showing the +/// peephole optimizer's automaton will be written to that file path. +pub fn compile_file(filename: &Path) -> anyhow::Result { + let source = fs::read_to_string(filename)?; + compile_str(&source, filename) +} + +/// Compile the given DSL source text down into a compact peephole optimizations +/// automaton. +/// +/// This is like [compile_file][crate::compile_file] but you bring your own file +/// I/O. +/// +/// The `filename` parameter is used to provide better error messages. +/// +/// ## Example +/// +/// ```no_run +/// # fn main() -> anyhow::Result<()> { +/// use std::path::Path; +/// +/// let peep_opts = peepmatic::compile_str( +/// " +/// (=> (iadd $x 0) $x) +/// (=> (imul $x 0) 0) +/// (=> (imul $x 1) $x) +/// ", +/// Path::new("my-optimizations"), +/// )?; +/// +/// // Use the peephole optimizations or serialize them into bytes here... +/// # Ok(()) +/// # } +/// ``` +/// +/// ## Visualizing the Peephole Optimizer's Automaton +/// +/// To visualize (or debug) the peephole optimizer's automaton, set the +/// `PEEPMATIC_DOT` environment variable to a file path. A [GraphViz +/// Dot]((https://graphviz.gitlab.io/_pages/pdf/dotguide.pdf)) file showing the +/// peephole optimizer's automaton will be written to that file path. 
+pub fn compile_str(source: &str, filename: &Path) -> anyhow::Result { + let buf = wast::parser::ParseBuffer::new(source).map_err(|mut e| { + e.set_path(filename); + e.set_text(source); + e + })?; + + let opts = wast::parser::parse::(&buf).map_err(|mut e| { + e.set_path(filename); + e.set_text(source); + e + })?; + + verify(&opts).map_err(|mut e| { + e.set_path(filename); + e.set_text(source); + e + })?; + + let mut opts = crate::linearize(&opts); + sort_least_to_most_general(&mut opts); + remove_unnecessary_nops(&mut opts); + match_in_same_order(&mut opts); + sort_lexicographically(&mut opts); + + let automata = automatize(&opts); + let paths = opts.paths; + let integers = opts.integers; + + if let Ok(path) = std::env::var("PEEPMATIC_DOT") { + let f = dot_fmt::PeepholeDotFmt(&paths, &integers); + if let Err(e) = automata.write_dot_file(&f, &path) { + panic!( + "failed to write GraphViz Dot file to PEEPMATIC_DOT={}; error: {}", + path, e + ); + } + } + + Ok(PeepholeOptimizations { + paths, + integers, + automata, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_compiles(path: &str) { + match compile_file(Path::new(path)) { + Ok(_) => return, + Err(e) => { + eprintln!("error: {}", e); + panic!("error: {}", e); + } + } + } + + #[test] + fn compile_redundant_bor() { + assert_compiles("examples/redundant-bor.peepmatic"); + } + + #[test] + fn mul_by_pow2() { + assert_compiles("examples/mul-by-pow2.peepmatic"); + } + + #[test] + fn compile_preopt() { + assert_compiles("examples/preopt.peepmatic"); + } +} diff --git a/cranelift/peepmatic/src/linear_passes.rs b/cranelift/peepmatic/src/linear_passes.rs new file mode 100644 index 0000000000..cac1ca3d19 --- /dev/null +++ b/cranelift/peepmatic/src/linear_passes.rs @@ -0,0 +1,444 @@ +//! Passes over the linear IR. + +use peepmatic_runtime::{ + linear, + paths::{PathId, PathInterner}, +}; +use std::cmp::Ordering; + +/// Sort a set of optimizations from least to most general. +/// +/// This helps us ensure that we always match the least-general (aka +/// most-specific) optimization that we can for a particular instruction +/// sequence. +/// +/// For example, if we have both of these optimizations: +/// +/// ```lisp +/// (=> (imul $C $x) +/// (imul_imm $C $x)) +/// +/// (=> (when (imul $C $x)) +/// (is-power-of-two $C)) +/// (ishl $x $C)) +/// ``` +/// +/// and we are matching `(imul 4 (..))`, then we want to apply the second +/// optimization, because it is more specific than the first. +pub fn sort_least_to_most_general(opts: &mut linear::Optimizations) { + let linear::Optimizations { + ref mut optimizations, + ref paths, + .. + } = opts; + + // NB: we *cannot* use an unstable sort here, because we want deterministic + // compilation of optimizations to automata. + optimizations.sort_by(|a, b| compare_optimization_generality(paths, a, b)); + debug_assert!(is_sorted_by_generality(opts)); +} + +/// Sort the linear optimizations lexicographically. +/// +/// This sort order is required for automata construction. +pub fn sort_lexicographically(opts: &mut linear::Optimizations) { + let linear::Optimizations { + ref mut optimizations, + ref paths, + .. + } = opts; + + // NB: we *cannot* use an unstable sort here, same as above. 
+ optimizations + .sort_by(|a, b| compare_optimizations(paths, a, b, |a_len, b_len| a_len.cmp(&b_len))); +} + +fn compare_optimizations( + paths: &PathInterner, + a: &linear::Optimization, + b: &linear::Optimization, + compare_lengths: impl Fn(usize, usize) -> Ordering, +) -> Ordering { + for (a, b) in a.increments.iter().zip(b.increments.iter()) { + let c = compare_match_op_generality(paths, a.operation, b.operation); + if c != Ordering::Equal { + return c; + } + + let c = a.expected.cmp(&b.expected).reverse(); + if c != Ordering::Equal { + return c; + } + } + + compare_lengths(a.increments.len(), b.increments.len()) +} + +fn compare_optimization_generality( + paths: &PathInterner, + a: &linear::Optimization, + b: &linear::Optimization, +) -> Ordering { + compare_optimizations(paths, a, b, |a_len, b_len| { + // If they shared equivalent prefixes, then compare lengths and invert the + // result because longer patterns are less general than shorter patterns. + a_len.cmp(&b_len).reverse() + }) +} + +fn compare_match_op_generality( + paths: &PathInterner, + a: linear::MatchOp, + b: linear::MatchOp, +) -> Ordering { + use linear::MatchOp::*; + match (a, b) { + (Opcode { path: a }, Opcode { path: b }) => compare_paths(paths, a, b), + (Opcode { .. }, _) => Ordering::Less, + (_, Opcode { .. }) => Ordering::Greater, + + (IntegerValue { path: a }, IntegerValue { path: b }) => compare_paths(paths, a, b), + (IntegerValue { .. }, _) => Ordering::Less, + (_, IntegerValue { .. }) => Ordering::Greater, + + (BooleanValue { path: a }, BooleanValue { path: b }) => compare_paths(paths, a, b), + (BooleanValue { .. }, _) => Ordering::Less, + (_, BooleanValue { .. }) => Ordering::Greater, + + (ConditionCode { path: a }, ConditionCode { path: b }) => compare_paths(paths, a, b), + (ConditionCode { .. }, _) => Ordering::Less, + (_, ConditionCode { .. }) => Ordering::Greater, + + (IsConst { path: a }, IsConst { path: b }) => compare_paths(paths, a, b), + (IsConst { .. }, _) => Ordering::Less, + (_, IsConst { .. }) => Ordering::Greater, + + ( + Eq { + path_a: pa1, + path_b: pb1, + }, + Eq { + path_a: pa2, + path_b: pb2, + }, + ) => compare_paths(paths, pa1, pa2).then(compare_paths(paths, pb1, pb2)), + (Eq { .. }, _) => Ordering::Less, + (_, Eq { .. }) => Ordering::Greater, + + (IsPowerOfTwo { path: a }, IsPowerOfTwo { path: b }) => compare_paths(paths, a, b), + (IsPowerOfTwo { .. }, _) => Ordering::Less, + (_, IsPowerOfTwo { .. }) => Ordering::Greater, + + (BitWidth { path: a }, BitWidth { path: b }) => compare_paths(paths, a, b), + (BitWidth { .. }, _) => Ordering::Less, + (_, BitWidth { .. }) => Ordering::Greater, + + (FitsInNativeWord { path: a }, FitsInNativeWord { path: b }) => compare_paths(paths, a, b), + (FitsInNativeWord { .. }, _) => Ordering::Less, + (_, FitsInNativeWord { .. }) => Ordering::Greater, + + (Nop, Nop) => Ordering::Equal, + } +} + +fn compare_paths(paths: &PathInterner, a: PathId, b: PathId) -> Ordering { + if a == b { + Ordering::Equal + } else { + let a = paths.lookup(a); + let b = paths.lookup(b); + a.0.cmp(&b.0) + } +} + +/// Are the given optimizations sorted from least to most general? +pub(crate) fn is_sorted_by_generality(opts: &linear::Optimizations) -> bool { + opts.optimizations + .windows(2) + .all(|w| compare_optimization_generality(&opts.paths, &w[0], &w[1]) <= Ordering::Equal) +} + +/// Are the given optimizations sorted lexicographically? 
+pub(crate) fn is_sorted_lexicographically(opts: &linear::Optimizations) -> bool { + opts.optimizations.windows(2).all(|w| { + compare_optimizations(&opts.paths, &w[0], &w[1], |a_len, b_len| a_len.cmp(&b_len)) + <= Ordering::Equal + }) +} + +/// Ensure that we emit match operations in a consistent order. +/// +/// There are many linear optimizations, each of which have their own sequence +/// of match operations that need to be tested. But when interpreting the +/// automata against some instructions, we only perform a single sequence of +/// match operations, and at any given moment, we only want one match operation +/// to interpret next. This means that two optimizations that are next to each +/// other in the sorting must have their shared prefixes diverge on an +/// **expected result edge**, not on which match operation to preform next. And +/// if they have zero shared prefix, then we need to create one, that +/// immediately divereges on the expected result. +/// +/// For example, consider these two patterns that don't have any shared prefix: +/// +/// ```lisp +/// (=> (iadd $x $y) ...) +/// (=> $C ...) +/// ``` +/// +/// These produce the following linear match operations and expected results: +/// +/// ```text +/// opcode @ 0 --iadd--> +/// is-const? @ 0 --true--> +/// ``` +/// +/// In order to ensure that we only have one match operation to interpret at any +/// given time when evaluating the automata, this pass transforms the second +/// optimization so that it shares a prefix match operation, but diverges on the +/// expected result: +/// +/// ```text +/// opcode @ 0 --iadd--> +/// opcode @ 0 --(else)--> is-const? @ 0 --true--> +/// ``` +pub fn match_in_same_order(opts: &mut linear::Optimizations) { + assert!(!opts.optimizations.is_empty()); + + let mut prefix = vec![]; + + for opt in &mut opts.optimizations { + assert!(!opt.increments.is_empty()); + + let mut old_increments = opt.increments.iter().peekable(); + let mut new_increments = vec![]; + + for (last_op, last_expected) in &prefix { + match old_increments.peek() { + None => { + break; + } + Some(inc) if *last_op == inc.operation => { + let inc = old_increments.next().unwrap(); + new_increments.push(inc.clone()); + if inc.expected != *last_expected { + break; + } + } + Some(_) => { + new_increments.push(linear::Increment { + operation: *last_op, + expected: None, + actions: vec![], + }); + if last_expected.is_some() { + break; + } + } + } + } + + new_increments.extend(old_increments.cloned()); + assert!(new_increments.len() >= opt.increments.len()); + opt.increments = new_increments; + + prefix.clear(); + prefix.extend( + opt.increments + .iter() + .map(|inc| (inc.operation, inc.expected)), + ); + } + + // Should still be sorted after this pass. + debug_assert!(is_sorted_by_generality(&opts)); +} + +/// 99.99% of nops are unnecessary; remove them. +/// +/// They're only needed for when a LHS pattern is just a variable, and that's +/// it. However, it is easier to have basically unused nop matching operations +/// for the DSL's edge-cases than it is to try and statically eliminate their +/// existence completely. So we just emit nop match operations for all variable +/// patterns, and then in this post-processing pass, we fuse them and their +/// actions with their preceding increment. 
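+///
+/// For example (illustrative), the variable-only pattern `(=> $x 0)`
+/// linearizes to a single `nop` increment, which is kept. By contrast, the
+/// `nop` increments emitted for the `$x` and `$y` operands of
+/// `(=> (iadd $x $y) 0)` carry no match information of their own, so this pass
+/// removes them and appends their actions to the preceding `opcode` increment.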
+pub fn remove_unnecessary_nops(opts: &mut linear::Optimizations) { + for opt in &mut opts.optimizations { + if opt.increments.len() < 2 { + debug_assert!(!opt.increments.is_empty()); + continue; + } + + for i in (1..opt.increments.len()).rev() { + if let linear::MatchOp::Nop = opt.increments[i].operation { + let nop = opt.increments.remove(i); + opt.increments[i - 1].actions.extend(nop.actions); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::*; + use linear::MatchOp::*; + use peepmatic_runtime::{operator::Operator, paths::*}; + + macro_rules! sorts_to { + ($test_name:ident, $source:expr, $make_expected:expr) => { + #[test] + fn $test_name() { + let buf = wast::parser::ParseBuffer::new($source).expect("should lex OK"); + + let opts = match wast::parser::parse::(&buf) { + Ok(opts) => opts, + Err(mut e) => { + e.set_path(std::path::Path::new(stringify!($test_name))); + e.set_text($source); + eprintln!("{}", e); + panic!("should parse OK") + } + }; + + if let Err(mut e) = crate::verify(&opts) { + e.set_path(std::path::Path::new(stringify!($test_name))); + e.set_text($source); + eprintln!("{}", e); + panic!("should verify OK") + } + + let mut opts = crate::linearize(&opts); + sort_least_to_most_general(&mut opts); + + let linear::Optimizations { + mut paths, + mut integers, + optimizations, + } = opts; + + let actual: Vec> = optimizations + .iter() + .map(|o| { + o.increments + .iter() + .map(|i| (i.operation, i.expected)) + .collect() + }) + .collect(); + + let mut p = |p: &[u8]| paths.intern(Path::new(&p)); + let mut i = |i: u64| Some(integers.intern(i).into()); + let expected = $make_expected(&mut p, &mut i); + + assert_eq!(expected, actual); + } + }; + } + + sorts_to!( + test_sort_least_to_most_general, + " +(=> $x 0) +(=> (iadd $x $y) 0) +(=> (iadd $x $x) 0) +(=> (iadd $x $C) 0) +(=> (when (iadd $x $C) (is-power-of-two $C)) 0) +(=> (when (iadd $x $C) (bit-width $x 32)) 0) +(=> (iadd $x 42) 0) +(=> (iadd $x (iadd $y $z)) 0) +", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> Option| vec![ + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (Opcode { path: p(&[0, 1]) }, Some(Operator::Iadd as _)), + (Nop, None), + (Nop, None), + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (IntegerValue { path: p(&[0, 1]) }, i(42)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (IsConst { path: p(&[0, 1]) }, Some(1)), + (IsPowerOfTwo { path: p(&[0, 1]) }, Some(1)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (IsConst { path: p(&[0, 1]) }, Some(1)), + (BitWidth { path: p(&[0, 0]) }, Some(32)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (IsConst { path: p(&[0, 1]) }, Some(1)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + ( + Eq { + path_a: p(&[0, 1]), + path_b: p(&[0, 0]), + }, + Some(1) + ) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (Nop, None), + ], + vec![(Nop, None)] + ] + ); + + sorts_to!( + expected_edges_are_sorted, + " +(=> (iadd 0 $x) $x) +(=> (iadd $x 0) $x) +(=> (imul 1 $x) $x) +(=> (imul $x 1) $x) +(=> (imul 2 $x) (ishl $x 1)) +(=> (imul $x 2) (ishl $x 1)) +", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> Option| vec![ + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Imul as _)), + (IntegerValue { path: p(&[0, 0]) }, i(2)), + (Nop, None) + ], + vec![ + 
(Opcode { path: p(&[0]) }, Some(Operator::Imul as _)), + (IntegerValue { path: p(&[0, 0]) }, i(1)), + (Nop, None) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Imul as _)), + (Nop, None), + (IntegerValue { path: p(&[0, 1]) }, i(2)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Imul as _)), + (Nop, None), + (IntegerValue { path: p(&[0, 1]) }, i(1)) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (IntegerValue { path: p(&[0, 0]) }, i(0)), + (Nop, None) + ], + vec![ + (Opcode { path: p(&[0]) }, Some(Operator::Iadd as _)), + (Nop, None), + (IntegerValue { path: p(&[0, 1]) }, i(0)) + ] + ] + ); +} diff --git a/cranelift/peepmatic/src/linearize.rs b/cranelift/peepmatic/src/linearize.rs new file mode 100644 index 0000000000..75205c93b9 --- /dev/null +++ b/cranelift/peepmatic/src/linearize.rs @@ -0,0 +1,831 @@ +//! Convert an AST into its linear equivalent. +//! +//! Convert each optimization's left-hand side into a linear series of match +//! operations. This makes it easy to create an automaton, because automatas +//! typically deal with a linear sequence of inputs. The optimization's +//! right-hand side is built incrementally inside actions that are taken on +//! transitions between match operations. +//! +//! See `crates/runtime/src/linear.rs` for the linear datatype definitions. +//! +//! ## Example +//! +//! As an example, if we linearize this optimization: +//! +//! ```lisp +//! (=> (when (imul $x $C) +//! (is-power-of-two $C)) +//! (ishl $x $(log2 C))) +//! ``` +//! +//! Then we should get the following linear chain of "increments": +//! +//! ```ignore +//! [ +//! // ( Match Operation, Expected Value, Actions ) +//! ( Opcode@0, imul, [$x = GetLhs@0.0, $C = GetLhs@0.1, ...] ), +//! ( IsConst(C), true, [] ), +//! ( IsPowerOfTwo(C), true, [] ), +//! ] +//! ``` +//! +//! Each increment will essentially become a state and a transition out of that +//! state in the final automata, along with the actions to perform when taking +//! that transition. The actions record the scope of matches from the left-hand +//! side and also incrementally build the right-hand side's instructions. (Note +//! that we've elided the actions that build up the optimization's right-hand +//! side in this example.) +//! +//! ## General Principles +//! +//! Here are the general principles that linearization should adhere to: +//! +//! * Actions should be pushed as early in the optimization's increment chain as +//! they can be. This means the tail has fewer side effects, and is therefore +//! more likely to be share-able with other optimizations in the automata that +//! we build. +//! +//! * RHS actions cannot reference matches from the LHS until they've been +//! defined. And finally, an RHS operation's operands must be defined before +//! the RHS operation itself. In general, definitions must come before uses! +//! +//! * Shorter increment chains are better! This means fewer tests when matching +//! left-hand sides, and a more-compact, more-cache-friendly automata, and +//! ultimately, a faster automata. +//! +//! * An increment's match operation should be a switch rather than a predicate +//! that returns a boolean. For example, we switch on an instruction's opcode, +//! rather than ask whether this operation is an `imul`. This allows for more +//! prefix sharing in the automata, which (again) makes it more compact and +//! more cache friendly. +//! +//! ## Implementation Overview +//! +//! We emit match operations for a left-hand side's pattern structure, followed +//! 
by match operations for its preconditions on that structure. This ensures
+//! that anything bound in the pattern is defined before it is used in a
+//! precondition.
+//!
+//! Within matching the pattern structure, we emit matching operations in a
+//! pre-order traversal of the pattern. This ensures that we've already matched
+//! an operation before we consider its operands, and therefore we already know
+//! the operands exist. See `PatternPreOrder` for details.
+//!
+//! As we define the match operations for a pattern, we remember the path where
+//! each LHS id first occurred. These will later be reused when building the RHS
+//! actions. See `LhsIdToPath` for details.
+//!
+//! After we've generated the match operations and the expected results of
+//! those match operations, we generate the right-hand side actions. The
+//! right-hand side is built up in a post-order traversal, so that operands are
+//! defined before they are used. See `RhsPostOrder` and `RhsBuilder` for
+//! details.
+//!
+//! Finally, see `linearize_optimization` for the main function that
+//! translates an AST optimization into a linear optimization.
+
+use crate::ast::*;
+use crate::traversals::Dfs;
+use peepmatic_runtime::{
+    integer_interner::IntegerInterner,
+    linear,
+    paths::{Path, PathId, PathInterner},
+};
+use std::collections::BTreeMap;
+use wast::Id;
+
+/// Translate the given AST optimizations into linear optimizations.
+pub fn linearize(opts: &Optimizations) -> linear::Optimizations {
+    let mut optimizations = vec![];
+    let mut paths = PathInterner::new();
+    let mut integers = IntegerInterner::new();
+    for opt in &opts.optimizations {
+        let lin_opt = linearize_optimization(&mut paths, &mut integers, opt);
+        optimizations.push(lin_opt);
+    }
+    linear::Optimizations {
+        optimizations,
+        paths,
+        integers,
+    }
+}
+
+/// Translate an AST optimization into a linear optimization!
+fn linearize_optimization(
+    paths: &mut PathInterner,
+    integers: &mut IntegerInterner,
+    opt: &Optimization,
+) -> linear::Optimization {
+    let mut increments: Vec<linear::Increment> = vec![];
+
+    let mut lhs_id_to_path = LhsIdToPath::new();
+
+    // We do a pre-order traversal of the LHS because we don't know whether a
+    // child actually exists to match on until we've matched its parent, and we
+    // don't want to emit matching operations on things that might not exist!
+    let mut patterns = PatternPreOrder::new(&opt.lhs.pattern);
+    while let Some((path, pattern)) = patterns.next(paths) {
+        // Create the matching parts of an `Increment` for this part of the
+        // pattern, without any actions yet.
+        let (operation, expected) = pattern.to_linear_match_op(integers, &lhs_id_to_path, path);
+        increments.push(linear::Increment {
+            operation,
+            expected,
+            actions: vec![],
+        });
+
+        lhs_id_to_path.remember_path_to_pattern_ids(pattern, path);
+
+        // Some operations require type ascriptions for us to infer the correct
+        // bit width of their results: `ireduce`, `sextend`, `uextend`, etc.
+        // When there is such a type ascription in the pattern, insert another
+        // increment that checks the instruction-being-matched's bit width.
+        if let Pattern::Operation(Operation { r#type, .. }) = pattern {
+            if let Some(w) = r#type.get().and_then(|ty| ty.bit_width.fixed_width()) {
+                increments.push(linear::Increment {
+                    operation: linear::MatchOp::BitWidth { path },
+                    expected: Some(w as u32),
+                    actions: vec![],
+                });
+            }
+        }
+    }
+
+    // Now that we've added all the increments for the LHS pattern, add the
+    // increments for its preconditions.
+ for pre in &opt.lhs.preconditions { + increments.push(pre.to_linear_increment(&lhs_id_to_path)); + } + + assert!(!increments.is_empty()); + + // Finally, generate the RHS-building actions and attach them to the first increment. + let mut rhs_builder = RhsBuilder::new(&opt.rhs); + rhs_builder.add_rhs_build_actions(integers, &lhs_id_to_path, &mut increments[0].actions); + + linear::Optimization { increments } +} + +/// A post-order, depth-first traversal of right-hand sides. +/// +/// Does not maintain any extra state about the traversal, such as where in the +/// tree each yielded node comes from. +struct RhsPostOrder<'a> { + dfs: Dfs<'a>, +} + +impl<'a> RhsPostOrder<'a> { + fn new(rhs: &'a Rhs<'a>) -> Self { + Self { dfs: Dfs::new(rhs) } + } +} + +impl<'a> Iterator for RhsPostOrder<'a> { + type Item = &'a Rhs<'a>; + + fn next(&mut self) -> Option<&'a Rhs<'a>> { + use crate::traversals::TraversalEvent as TE; + loop { + match self.dfs.next()? { + (TE::Exit, DynAstRef::Rhs(rhs)) => return Some(rhs), + _ => continue, + } + } + } +} + +/// A pre-order, depth-first traversal of left-hand side patterns. +/// +/// Keeps track of the path to each pattern, and yields it along side the +/// pattern AST node. +struct PatternPreOrder<'a> { + last_child: Option, + path: Vec, + dfs: Dfs<'a>, +} + +impl<'a> PatternPreOrder<'a> { + fn new(pattern: &'a Pattern<'a>) -> Self { + Self { + last_child: None, + path: vec![], + dfs: Dfs::new(pattern), + } + } + + fn next(&mut self, paths: &mut PathInterner) -> Option<(PathId, &'a Pattern<'a>)> { + use crate::traversals::TraversalEvent as TE; + loop { + match self.dfs.next()? { + (TE::Enter, DynAstRef::Pattern(pattern)) => { + let last_child = self.last_child.take(); + self.path.push(match last_child { + None => 0, + Some(c) => { + assert!( + c < std::u8::MAX, + "operators must have less than or equal u8::MAX arity" + ); + c + 1 + } + }); + let path = paths.intern(Path(&self.path)); + return Some((path, pattern)); + } + (TE::Exit, DynAstRef::Pattern(_)) => { + self.last_child = Some( + self.path + .pop() + .expect("should always have a non-empty path during traversal"), + ); + } + _ => {} + } + } + } +} + +/// A map from left-hand side identifiers to the path in the left-hand side +/// where they first occurred. +struct LhsIdToPath<'a> { + id_to_path: BTreeMap<&'a str, PathId>, +} + +impl<'a> LhsIdToPath<'a> { + /// Construct a new, empty `LhsIdToPath`. + fn new() -> Self { + Self { + id_to_path: Default::default(), + } + } + + /// Have we already seen the given identifier? + fn get_first_occurrence(&self, id: &Id) -> Option { + self.id_to_path.get(id.name()).copied() + } + + /// Get the path within the left-hand side pattern where we first saw the + /// given AST id. + /// + /// ## Panics + /// + /// Panics if the given AST id has not already been canonicalized. + fn unwrap_first_occurrence(&self, id: &Id) -> PathId { + self.id_to_path[id.name()] + } + + /// Remember the path to any LHS ids used in the given pattern. + fn remember_path_to_pattern_ids(&mut self, pattern: &'a Pattern<'a>, path: PathId) { + match pattern { + // If this is the first time we've seen an identifier defined on the + // left-hand side, remember it. + Pattern::Variable(Variable { id, .. }) | Pattern::Constant(Constant { id, .. }) => { + self.id_to_path.entry(id.name()).or_insert(path); + } + _ => {} + } + } +} + +/// An `RhsBuilder` emits the actions for building the right-hand side +/// instructions. 
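+///
+/// For example (illustrative, mirroring the `mul_by_pow2_into_shift` test
+/// below), the right-hand side `(ishl $x $C)` is built by three actions:
+///
+/// ```text
+/// GetLhs @ 0.0                          ;; $x
+/// GetLhs @ 0.1                          ;; $C
+/// MakeBinaryInst ishl [RhsId(0), RhsId(1)]
+/// ```
+///
+/// where the `MakeBinaryInst`'s operands refer to the results of the two
+/// earlier `GetLhs` actions.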
+struct RhsBuilder<'a> { + // We do a post order traversal of the RHS because an RHS instruction cannot + // be created until after all of its operands are created. + rhs_post_order: RhsPostOrder<'a>, + + // A map from a right-hand side's span to its `linear::RhsId`. This is used + // by RHS-construction actions to reference operands. In practice the + // `RhsId` is roughly equivalent to its index in the post-order traversal of + // the RHS. + rhs_span_to_id: BTreeMap, +} + +impl<'a> RhsBuilder<'a> { + /// Create a new builder for the given right-hand side. + fn new(rhs: &'a Rhs<'a>) -> Self { + let rhs_post_order = RhsPostOrder::new(rhs); + let rhs_span_to_id = Default::default(); + Self { + rhs_post_order, + rhs_span_to_id, + } + } + + /// Get the `linear::RhsId` for the given right-hand side. + /// + /// ## Panics + /// + /// Panics if we haven't already emitted the action for building this RHS's + /// instruction. + fn get_rhs_id(&self, rhs: &Rhs) -> linear::RhsId { + self.rhs_span_to_id[&rhs.span()] + } + + /// Create actions for building up this right-hand side of an optimization. + /// + /// Because we are walking the right-hand side with a post-order traversal, + /// we know that we already created an instruction's operands that are + /// defined in the right-hand side, before we get to the parent instruction. + fn add_rhs_build_actions( + &mut self, + integers: &mut IntegerInterner, + lhs_id_to_path: &LhsIdToPath, + actions: &mut Vec, + ) { + while let Some(rhs) = self.rhs_post_order.next() { + actions.push(self.rhs_to_linear_action(integers, lhs_id_to_path, rhs)); + let id = linear::RhsId(self.rhs_span_to_id.len() as u32); + self.rhs_span_to_id.insert(rhs.span(), id); + } + } + + fn rhs_to_linear_action( + &self, + integers: &mut IntegerInterner, + lhs_id_to_path: &LhsIdToPath, + rhs: &Rhs, + ) -> linear::Action { + match rhs { + Rhs::ValueLiteral(ValueLiteral::Integer(i)) => linear::Action::MakeIntegerConst { + value: integers.intern(i.value as u64), + bit_width: i + .bit_width + .get() + .expect("should be initialized after type checking"), + }, + Rhs::ValueLiteral(ValueLiteral::Boolean(b)) => linear::Action::MakeBooleanConst { + value: b.value, + bit_width: b + .bit_width + .get() + .expect("should be initialized after type checking"), + }, + Rhs::ValueLiteral(ValueLiteral::ConditionCode(ConditionCode { cc, .. })) => { + linear::Action::MakeConditionCode { cc: *cc } + } + Rhs::Variable(Variable { id, .. }) | Rhs::Constant(Constant { id, .. 
}) => { + let path = lhs_id_to_path.unwrap_first_occurrence(id); + linear::Action::GetLhs { path } + } + Rhs::Unquote(unq) => match unq.operands.len() { + 1 => linear::Action::UnaryUnquote { + operator: unq.operator, + operand: self.get_rhs_id(&unq.operands[0]), + }, + 2 => linear::Action::BinaryUnquote { + operator: unq.operator, + operands: [ + self.get_rhs_id(&unq.operands[0]), + self.get_rhs_id(&unq.operands[1]), + ], + }, + n => unreachable!("no unquote operators of arity {}", n), + }, + Rhs::Operation(op) => match op.operands.len() { + 1 => linear::Action::MakeUnaryInst { + operator: op.operator, + r#type: op + .r#type + .get() + .expect("should be initialized after type checking"), + operand: self.get_rhs_id(&op.operands[0]), + }, + 2 => linear::Action::MakeBinaryInst { + operator: op.operator, + r#type: op + .r#type + .get() + .expect("should be initialized after type checking"), + operands: [ + self.get_rhs_id(&op.operands[0]), + self.get_rhs_id(&op.operands[1]), + ], + }, + 3 => linear::Action::MakeTernaryInst { + operator: op.operator, + r#type: op + .r#type + .get() + .expect("should be initialized after type checking"), + operands: [ + self.get_rhs_id(&op.operands[0]), + self.get_rhs_id(&op.operands[1]), + self.get_rhs_id(&op.operands[2]), + ], + }, + n => unreachable!("no instructions of arity {}", n), + }, + } + } +} + +impl Precondition<'_> { + /// Convert this precondition into a `linear::Increment`. + fn to_linear_increment(&self, lhs_id_to_path: &LhsIdToPath) -> linear::Increment { + match self.constraint { + Constraint::IsPowerOfTwo => { + let id = match &self.operands[0] { + ConstraintOperand::Constant(Constant { id, .. }) => id, + _ => unreachable!("checked in verification"), + }; + let path = lhs_id_to_path.unwrap_first_occurrence(&id); + linear::Increment { + operation: linear::MatchOp::IsPowerOfTwo { path }, + expected: Some(1), + actions: vec![], + } + } + Constraint::BitWidth => { + let id = match &self.operands[0] { + ConstraintOperand::Constant(Constant { id, .. }) + | ConstraintOperand::Variable(Variable { id, .. }) => id, + _ => unreachable!("checked in verification"), + }; + let path = lhs_id_to_path.unwrap_first_occurrence(&id); + + let width = match &self.operands[1] { + ConstraintOperand::ValueLiteral(ValueLiteral::Integer(Integer { + value, + .. + })) => *value, + _ => unreachable!("checked in verification"), + }; + debug_assert!(width <= 128); + debug_assert!((width as u8).is_power_of_two()); + + linear::Increment { + operation: linear::MatchOp::BitWidth { path }, + expected: Some(width as u32), + actions: vec![], + } + } + Constraint::FitsInNativeWord => { + let id = match &self.operands[0] { + ConstraintOperand::Constant(Constant { id, .. }) + | ConstraintOperand::Variable(Variable { id, .. }) => id, + _ => unreachable!("checked in verification"), + }; + let path = lhs_id_to_path.unwrap_first_occurrence(&id); + linear::Increment { + operation: linear::MatchOp::FitsInNativeWord { path }, + expected: Some(1), + actions: vec![], + } + } + } + } +} + +impl Pattern<'_> { + /// Convert this pattern into its linear match operation and the expected + /// result of that operation. + /// + /// NB: these mappings to expected values need to stay sync'd with the + /// runtime! + fn to_linear_match_op( + &self, + integers: &mut IntegerInterner, + lhs_id_to_path: &LhsIdToPath, + path: PathId, + ) -> (linear::MatchOp, Option) { + match self { + Pattern::ValueLiteral(ValueLiteral::Integer(Integer { value, .. 
})) => ( + linear::MatchOp::IntegerValue { path }, + Some(integers.intern(*value as u64).into()), + ), + Pattern::ValueLiteral(ValueLiteral::Boolean(Boolean { value, .. })) => { + (linear::MatchOp::BooleanValue { path }, Some(*value as u32)) + } + Pattern::ValueLiteral(ValueLiteral::ConditionCode(ConditionCode { cc, .. })) => { + (linear::MatchOp::ConditionCode { path }, Some(*cc as u32)) + } + Pattern::Constant(Constant { id, .. }) => { + if let Some(path_b) = lhs_id_to_path.get_first_occurrence(id) { + debug_assert!(path != path_b); + ( + linear::MatchOp::Eq { + path_a: path, + path_b, + }, + Some(1), + ) + } else { + (linear::MatchOp::IsConst { path }, Some(1)) + } + } + Pattern::Variable(Variable { id, .. }) => { + if let Some(path_b) = lhs_id_to_path.get_first_occurrence(id) { + debug_assert!(path != path_b); + ( + linear::MatchOp::Eq { + path_a: path, + path_b, + }, + Some(1), + ) + } else { + (linear::MatchOp::Nop, None) + } + } + Pattern::Operation(op) => (linear::MatchOp::Opcode { path }, Some(op.operator as u32)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use peepmatic_runtime::{ + integer_interner::IntegerId, + linear::{Action::*, MatchOp::*}, + operator::Operator, + r#type::{BitWidth, Kind, Type}, + }; + + macro_rules! linearizes_to { + ($name:ident, $source:expr, $make_expected:expr $(,)* ) => { + #[test] + fn $name() { + let buf = wast::parser::ParseBuffer::new($source).expect("should lex OK"); + + let opts = match wast::parser::parse::(&buf) { + Ok(opts) => opts, + Err(mut e) => { + e.set_path(std::path::Path::new(stringify!($name))); + e.set_text($source); + eprintln!("{}", e); + panic!("should parse OK") + } + }; + + assert_eq!( + opts.optimizations.len(), + 1, + "`linearizes_to!` only supports a single optimization; split the big test into \ + multiple small tests" + ); + + if let Err(mut e) = crate::verify(&opts) { + e.set_path(std::path::Path::new(stringify!($name))); + e.set_text($source); + eprintln!("{}", e); + panic!("should verify OK") + } + + let mut paths = PathInterner::new(); + let mut p = |p: &[u8]| paths.intern(Path::new(&p)); + + let mut integers = IntegerInterner::new(); + let mut i = |i: u64| integers.intern(i); + + #[allow(unused_variables)] + let expected = $make_expected(&mut p, &mut i); + dbg!(&expected); + + let actual = linearize_optimization(&mut paths, &mut integers, &opts.optimizations[0]); + dbg!(&actual); + + assert_eq!(expected, actual); + } + }; + } + + linearizes_to!( + mul_by_pow2_into_shift, + " +(=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $C)) + ", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![ + linear::Increment { + operation: Opcode { path: p(&[0]) }, + expected: Some(Operator::Imul as _), + actions: vec![ + GetLhs { path: p(&[0, 0]) }, + GetLhs { path: p(&[0, 1]) }, + MakeBinaryInst { + operator: Operator::Ishl, + r#type: Type { + kind: Kind::Int, + bit_width: BitWidth::Polymorphic, + }, + operands: [linear::RhsId(0), linear::RhsId(1)], + }, + ], + }, + linear::Increment { + operation: Nop, + expected: None, + actions: vec![], + }, + linear::Increment { + operation: IsConst { path: p(&[0, 1]) }, + expected: Some(1), + actions: vec![], + }, + linear::Increment { + operation: IsPowerOfTwo { path: p(&[0, 1]) }, + expected: Some(1), + actions: vec![], + }, + ], + } + }, + ); + + linearizes_to!( + variable_pattern_id_optimization, + "(=> $x $x)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + 
linear::Optimization { + increments: vec![linear::Increment { + operation: Nop, + expected: None, + actions: vec![GetLhs { path: p(&[0]) }], + }], + } + }, + ); + + linearizes_to!( + constant_pattern_id_optimization, + "(=> $C $C)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![linear::Increment { + operation: IsConst { path: p(&[0]) }, + expected: Some(1), + actions: vec![GetLhs { path: p(&[0]) }], + }], + } + }, + ); + + linearizes_to!( + boolean_literal_id_optimization, + "(=> true true)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![linear::Increment { + operation: BooleanValue { path: p(&[0]) }, + expected: Some(1), + actions: vec![MakeBooleanConst { + value: true, + bit_width: BitWidth::Polymorphic, + }], + }], + } + }, + ); + + linearizes_to!( + number_literal_id_optimization, + "(=> 5 5)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![linear::Increment { + operation: IntegerValue { path: p(&[0]) }, + expected: Some(i(5).into()), + actions: vec![MakeIntegerConst { + value: i(5), + bit_width: BitWidth::Polymorphic, + }], + }], + } + }, + ); + + linearizes_to!( + operation_id_optimization, + "(=> (iconst $C) (iconst $C))", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![ + linear::Increment { + operation: Opcode { path: p(&[0]) }, + expected: Some(Operator::Iconst as _), + actions: vec![ + GetLhs { path: p(&[0, 0]) }, + MakeUnaryInst { + operator: Operator::Iconst, + r#type: Type { + kind: Kind::Int, + bit_width: BitWidth::Polymorphic, + }, + operand: linear::RhsId(0), + }, + ], + }, + linear::Increment { + operation: IsConst { path: p(&[0, 0]) }, + expected: Some(1), + actions: vec![], + }, + ], + } + }, + ); + + linearizes_to!( + redundant_bor, + "(=> (bor $x (bor $x $y)) (bor $x $y))", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![ + linear::Increment { + operation: Opcode { path: p(&[0]) }, + expected: Some(Operator::Bor as _), + actions: vec![ + GetLhs { path: p(&[0, 0]) }, + GetLhs { + path: p(&[0, 1, 1]), + }, + MakeBinaryInst { + operator: Operator::Bor, + r#type: Type { + kind: Kind::Int, + bit_width: BitWidth::Polymorphic, + }, + operands: [linear::RhsId(0), linear::RhsId(1)], + }, + ], + }, + linear::Increment { + operation: Nop, + expected: None, + actions: vec![], + }, + linear::Increment { + operation: Opcode { path: p(&[0, 1]) }, + expected: Some(Operator::Bor as _), + actions: vec![], + }, + linear::Increment { + operation: Eq { + path_a: p(&[0, 1, 0]), + path_b: p(&[0, 0]), + }, + expected: Some(1), + actions: vec![], + }, + linear::Increment { + operation: Nop, + expected: None, + actions: vec![], + }, + ], + } + }, + ); + + linearizes_to!( + large_integers, + // u64::MAX + "(=> 18446744073709551615 0)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { + linear::Optimization { + increments: vec![linear::Increment { + operation: IntegerValue { path: p(&[0]) }, + expected: Some(i(std::u64::MAX).into()), + actions: vec![MakeIntegerConst { + value: i(0), + bit_width: BitWidth::Polymorphic, + }], + }], + } + } + ); + + linearizes_to!( + ireduce_with_type_ascription, + "(=> (ireduce{i32} $x) 0)", + |p: &mut dyn FnMut(&[u8]) -> PathId, i: &mut dyn FnMut(u64) -> IntegerId| { 
+ linear::Optimization { + increments: vec![ + linear::Increment { + operation: Opcode { path: p(&[0]) }, + expected: Some(Operator::Ireduce as _), + actions: vec![MakeIntegerConst { + value: i(0), + bit_width: BitWidth::ThirtyTwo, + }], + }, + linear::Increment { + operation: linear::MatchOp::BitWidth { path: p(&[0]) }, + expected: Some(32), + actions: vec![], + }, + linear::Increment { + operation: Nop, + expected: None, + actions: vec![], + }, + ], + } + } + ); +} diff --git a/cranelift/peepmatic/src/parser.rs b/cranelift/peepmatic/src/parser.rs new file mode 100644 index 0000000000..19ad49017c --- /dev/null +++ b/cranelift/peepmatic/src/parser.rs @@ -0,0 +1,932 @@ +/*! + +This module implements parsing the DSL text format. It implements the +`wast::Parse` trait for all of our AST types. + +The grammar for the DSL is given below: + +```ebnf + ::= * + + ::= '(' '=>' ')' + + ::= + | '(' 'when' * ')' + + ::= + | + | > + | + + ::= + | + + ::= 'true' | 'false' + +> ::= '(' [] * ')' + + ::= '(' * ')' + + ::= + | + | + + ::= + | + | + | + | > + + ::= '$' '(' * ')' + + ::= + | +``` + + */ + +use crate::ast::*; +use peepmatic_runtime::r#type::Type; +use std::cell::Cell; +use std::marker::PhantomData; +use wast::{ + parser::{Cursor, Parse, Parser, Peek, Result as ParseResult}, + Id, LParen, +}; + +mod tok { + use wast::{custom_keyword, custom_reserved}; + + custom_keyword!(bit_width = "bit-width"); + custom_reserved!(dollar = "$"); + custom_keyword!(r#false = "false"); + custom_keyword!(fits_in_native_word = "fits-in-native-word"); + custom_keyword!(is_power_of_two = "is-power-of-two"); + custom_reserved!(left_curly = "{"); + custom_keyword!(log2); + custom_keyword!(neg); + custom_reserved!(replace = "=>"); + custom_reserved!(right_curly = "}"); + custom_keyword!(r#true = "true"); + custom_keyword!(when); + + custom_keyword!(eq); + custom_keyword!(ne); + custom_keyword!(slt); + custom_keyword!(ult); + custom_keyword!(sge); + custom_keyword!(uge); + custom_keyword!(sgt); + custom_keyword!(ugt); + custom_keyword!(sle); + custom_keyword!(ule); + custom_keyword!(of); + custom_keyword!(nof); +} + +impl<'a> Parse<'a> for Optimizations<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + let mut optimizations = vec![]; + while !p.is_empty() { + optimizations.push(p.parse()?); + } + Ok(Optimizations { + span, + optimizations, + }) + } +} + +impl<'a> Parse<'a> for Optimization<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + p.parens(|p| { + p.parse::()?; + let lhs = p.parse()?; + let rhs = p.parse()?; + Ok(Optimization { span, lhs, rhs }) + }) + } +} + +impl<'a> Parse<'a> for Lhs<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + let mut preconditions = vec![]; + if p.peek::() && p.peek2::() { + p.parens(|p| { + p.parse::()?; + let pattern = p.parse()?; + while p.peek::() { + preconditions.push(p.parse()?); + } + Ok(Lhs { + span, + pattern, + preconditions, + }) + }) + } else { + let span = p.cur_span(); + let pattern = p.parse()?; + Ok(Lhs { + span, + pattern, + preconditions, + }) + } + } +} + +impl<'a> Parse<'a> for Pattern<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + if p.peek::() { + return Ok(Pattern::ValueLiteral(p.parse()?)); + } + if p.peek::() { + return Ok(Pattern::Constant(p.parse()?)); + } + if p.peek::>() { + return Ok(Pattern::Operation(p.parse()?)); + } + if p.peek::() { + return Ok(Pattern::Variable(p.parse()?)); + } + Err(p.error("expected a left-hand side pattern")) + } +} + +impl<'a> Peek for 
Pattern<'a> { + fn peek(c: Cursor) -> bool { + ValueLiteral::peek(c) + || Constant::peek(c) + || Variable::peek(c) + || Operation::::peek(c) + } + + fn display() -> &'static str { + "left-hand side pattern" + } +} + +impl<'a> Parse<'a> for ValueLiteral<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + if let Ok(b) = p.parse::() { + return Ok(ValueLiteral::Boolean(b)); + } + if let Ok(i) = p.parse::() { + return Ok(ValueLiteral::Integer(i)); + } + if let Ok(cc) = p.parse::() { + return Ok(ValueLiteral::ConditionCode(cc)); + } + Err(p.error("expected an integer or boolean or condition code literal")) + } +} + +impl<'a> Peek for ValueLiteral<'a> { + fn peek(c: Cursor) -> bool { + c.integer().is_some() || Boolean::peek(c) || ConditionCode::peek(c) + } + + fn display() -> &'static str { + "value literal" + } +} + +impl<'a> Parse<'a> for Integer<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + p.step(|c| { + if let Some((i, rest)) = c.integer() { + let (s, base) = i.val(); + let val = i64::from_str_radix(s, base) + .or_else(|_| u128::from_str_radix(s, base).map(|i| i as i64)); + return match val { + Ok(value) => Ok(( + Integer { + span, + value, + bit_width: Default::default(), + marker: PhantomData, + }, + rest, + )), + Err(_) => Err(c.error("invalid integer: out of range")), + }; + } + Err(c.error("expected an integer")) + }) + } +} + +impl<'a> Parse<'a> for Boolean<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + if p.parse::().is_ok() { + return Ok(Boolean { + span, + value: true, + bit_width: Default::default(), + marker: PhantomData, + }); + } + if p.parse::().is_ok() { + return Ok(Boolean { + span, + value: false, + bit_width: Default::default(), + marker: PhantomData, + }); + } + Err(p.error("expected `true` or `false`")) + } +} + +impl<'a> Peek for Boolean<'a> { + fn peek(c: Cursor) -> bool { + ::peek(c) || ::peek(c) + } + + fn display() -> &'static str { + "boolean `true` or `false`" + } +} + +impl<'a> Parse<'a> for ConditionCode<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + + macro_rules! parse_cc { + ( $( $token:ident => $cc:ident, )* ) => { + $( + if p.peek::() { + p.parse::()?; + return Ok(Self { + span, + cc: peepmatic_runtime::cc::ConditionCode::$cc, + marker: PhantomData, + }); + } + )* + } + } + + parse_cc! { + eq => Eq, + ne => Ne, + slt => Slt, + ult => Ult, + sge => Sge, + uge => Uge, + sgt => Sgt, + ugt => Ugt, + sle => Sle, + ule => Ule, + of => Of, + nof => Nof, + } + + Err(p.error("expected a condition code")) + } +} + +impl<'a> Peek for ConditionCode<'a> { + fn peek(c: Cursor) -> bool { + macro_rules! peek_cc { + ( $( $token:ident, )* ) => { + false $( || ::peek(c) )* + } + } + + peek_cc! 
{ + eq, + ne, + slt, + ult, + sge, + uge, + sgt, + ugt, + sle, + ule, + of, + nof, + } + } + + fn display() -> &'static str { + "condition code" + } +} + +impl<'a> Parse<'a> for Constant<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + let id = Id::parse(p)?; + if id + .name() + .chars() + .all(|c| !c.is_alphabetic() || c.is_uppercase()) + { + Ok(Constant { span, id }) + } else { + let upper = id + .name() + .chars() + .flat_map(|c| c.to_uppercase()) + .collect::(); + Err(p.error(format!( + "symbolic constants must start with an upper-case letter like ${}", + upper + ))) + } + } +} + +impl<'a> Peek for Constant<'a> { + fn peek(c: Cursor) -> bool { + if let Some((id, _rest)) = c.id() { + id.chars().all(|c| !c.is_alphabetic() || c.is_uppercase()) + } else { + false + } + } + + fn display() -> &'static str { + "symbolic constant" + } +} + +impl<'a> Parse<'a> for Variable<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + let id = Id::parse(p)?; + if id + .name() + .chars() + .all(|c| !c.is_alphabetic() || c.is_lowercase()) + { + Ok(Variable { span, id }) + } else { + let lower = id + .name() + .chars() + .flat_map(|c| c.to_lowercase()) + .collect::(); + Err(p.error(format!( + "variables must start with an lower-case letter like ${}", + lower + ))) + } + } +} + +impl<'a> Peek for Variable<'a> { + fn peek(c: Cursor) -> bool { + if let Some((id, _rest)) = c.id() { + id.chars().all(|c| !c.is_alphabetic() || c.is_lowercase()) + } else { + false + } + } + + fn display() -> &'static str { + "variable" + } +} + +impl<'a, T> Parse<'a> for Operation<'a, T> +where + T: 'a + Ast<'a> + Peek + Parse<'a>, + DynAstRef<'a>: From<&'a T>, +{ + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + p.parens(|p| { + let operator = p.parse()?; + + let r#type = Cell::new(if p.peek::() { + p.parse::()?; + let ty = p.parse::()?; + p.parse::()?; + Some(ty) + } else { + None + }); + + let mut operands = vec![]; + while p.peek::() { + operands.push(p.parse()?); + } + Ok(Operation { + span, + operator, + r#type, + operands, + marker: PhantomData, + }) + }) + } +} + +impl<'a, T> Peek for Operation<'a, T> +where + T: 'a + Ast<'a>, + DynAstRef<'a>: From<&'a T>, +{ + fn peek(c: Cursor) -> bool { + wast::LParen::peek(c) + } + + fn display() -> &'static str { + "operation" + } +} + +impl<'a> Parse<'a> for Precondition<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + p.parens(|p| { + let constraint = p.parse()?; + let mut operands = vec![]; + while p.peek::() { + operands.push(p.parse()?); + } + Ok(Precondition { + span, + constraint, + operands, + }) + }) + } +} + +impl<'a> Parse<'a> for Constraint { + fn parse(p: Parser<'a>) -> ParseResult { + if p.peek::() { + p.parse::()?; + return Ok(Constraint::IsPowerOfTwo); + } + if p.peek::() { + p.parse::()?; + return Ok(Constraint::BitWidth); + } + if p.peek::() { + p.parse::()?; + return Ok(Constraint::FitsInNativeWord); + } + Err(p.error("expected a precondition constraint")) + } +} + +impl<'a> Parse<'a> for ConstraintOperand<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + if p.peek::() { + return Ok(ConstraintOperand::ValueLiteral(p.parse()?)); + } + if p.peek::() { + return Ok(ConstraintOperand::Constant(p.parse()?)); + } + if p.peek::() { + return Ok(ConstraintOperand::Variable(p.parse()?)); + } + Err(p.error("expected an operand for precondition constraint")) + } +} + +impl<'a> Peek for ConstraintOperand<'a> { + fn peek(c: Cursor) -> bool { + ValueLiteral::peek(c) || 
Constant::peek(c) || Variable::peek(c) + } + + fn display() -> &'static str { + "operand for a precondition constraint" + } +} + +impl<'a> Parse<'a> for Rhs<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + if p.peek::() { + return Ok(Rhs::ValueLiteral(p.parse()?)); + } + if p.peek::() { + return Ok(Rhs::Constant(p.parse()?)); + } + if p.peek::() { + return Ok(Rhs::Variable(p.parse()?)); + } + if p.peek::() { + return Ok(Rhs::Unquote(p.parse()?)); + } + if p.peek::>() { + return Ok(Rhs::Operation(p.parse()?)); + } + Err(p.error("expected a right-hand side replacement")) + } +} + +impl<'a> Peek for Rhs<'a> { + fn peek(c: Cursor) -> bool { + ValueLiteral::peek(c) + || Constant::peek(c) + || Variable::peek(c) + || Unquote::peek(c) + || Operation::::peek(c) + } + + fn display() -> &'static str { + "right-hand side replacement" + } +} + +impl<'a> Parse<'a> for Unquote<'a> { + fn parse(p: Parser<'a>) -> ParseResult { + let span = p.cur_span(); + p.parse::()?; + p.parens(|p| { + let operator = p.parse()?; + let mut operands = vec![]; + while p.peek::() { + operands.push(p.parse()?); + } + Ok(Unquote { + span, + operator, + operands, + }) + }) + } +} + +impl<'a> Peek for Unquote<'a> { + fn peek(c: Cursor) -> bool { + tok::dollar::peek(c) + } + + fn display() -> &'static str { + "unquote expression" + } +} + +#[cfg(test)] +mod test { + use super::*; + use peepmatic_runtime::operator::Operator; + + macro_rules! test_parse { + ( + $( + $name:ident < $ast:ty > { + $( ok { $( $ok:expr , )* } )* + $( err { $( $err:expr , )* } )* + } + )* + ) => { + $( + #[test] + #[allow(non_snake_case)] + fn $name() { + $( + $({ + let input = $ok; + let buf = wast::parser::ParseBuffer::new(input).unwrap_or_else(|e| { + panic!("should lex OK, got error:\n\n{}\n\nInput:\n\n{}", e, input) + }); + if let Err(e) = wast::parser::parse::<$ast>(&buf) { + panic!( + "expected to parse OK, got error:\n\n{}\n\nInput:\n\n{}", + e, input + ); + } + })* + )* + + $( + $({ + let input = $err; + let buf = wast::parser::ParseBuffer::new(input).unwrap_or_else(|e| { + panic!("should lex OK, got error:\n\n{}\n\nInput:\n\n{}", e, input) + }); + if let Ok(ast) = wast::parser::parse::<$ast>(&buf) { + panic!( + "expected a parse error, got:\n\n{:?}\n\nInput:\n\n{}", + ast, input + ); + } + })* + )* + } + )* + } + } + + test_parse! 
{ + parse_boolean { + ok { + "true", + "false", + } + err { + "", + "t", + "tr", + "tru", + "truezzz", + "f", + "fa", + "fal", + "fals", + "falsezzz", + } + } + parse_cc { + ok { + "eq", + "ne", + "slt", + "ult", + "sge", + "uge", + "sgt", + "ugt", + "sle", + "ule", + "of", + "nof", + } + err { + "", + "neq", + } + } + parse_constant { + ok { + "$C", + "$C1", + "$C2", + "$X", + "$Y", + "$SOME-CONSTANT", + "$SOME_CONSTANT", + } + err { + "", + "zzz", + "$", + "$variable", + "$Some-Constant", + "$Some_Constant", + "$Some_constant", + } + } + parse_constraint { + ok { + "is-power-of-two", + "bit-width", + "fits-in-native-word", + } + err { + "", + "iadd", + "imul", + } + } + parse_constraint_operand { + ok { + "1234", + "true", + "$CONSTANT", + "$variable", + } + err { + "", + "is-power-of-two", + "(is-power-of-two $C)", + "(iadd 1 2)", + } + } + parse_integer { + ok { + "0", + "1", + "12", + "123", + "1234", + "12345", + "123456", + "1234567", + "12345678", + "123456789", + "1234567890", + "0x0", + "0x1", + "0x12", + "0x123", + "0x1234", + "0x12345", + "0x123456", + "0x1234567", + "0x12345678", + "0x123456789", + "0x123456789a", + "0x123456789ab", + "0x123456789abc", + "0x123456789abcd", + "0x123456789abcde", + "0x123456789abcdef", + "0xffff_ffff_ffff_ffff", + } + err { + "", + "abcdef", + "01234567890abcdef", + "0xgggg", + "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + } + } + parse_lhs { + ok { + "(when (imul $C1 $C2) (is-power-of-two $C1) (is-power-of-two $C2))", + "(when (imul $x $C) (is-power-of-two $C))", + "(imul $x $y)", + "(imul $x)", + "(imul)", + "$C", + "$x", + } + err { + "", + "()", + "abc", + } + } + parse_operation_pattern> { + ok { + "(iadd)", + "(iadd 1)", + "(iadd 1 2)", + "(iadd $x $C)", + "(iadd{i32} $x $y)", + "(icmp eq $x $y)", + } + err { + "", + "()", + "$var", + "$CONST", + "(ishl $x $(log2 $C))", + } + } + parse_operation_rhs> { + ok { + "(iadd)", + "(iadd 1)", + "(iadd 1 2)", + "(ishl $x $(log2 $C))", + } + err { + "", + "()", + "$var", + "$CONST", + } + } + parse_operator { + ok { + "bor", + "iadd", + "iadd_imm", + "iconst", + "imul", + "imul_imm", + "ishl", + "sdiv", + "sdiv_imm", + "sshr", + } + err { + "", + "iadd.i32", + "iadd{i32}", + } + } + parse_optimization { + ok { + "(=> (when (iadd $x $C) (is-power-of-two $C) (is-power-of-two $C)) (iadd $C $x))", + "(=> (when (iadd $x $C)) (iadd $C $x))", + "(=> (iadd $x $C) (iadd $C $x))", + } + err { + "", + "()", + "(=>)", + "(=> () ())", + } + } + parse_optimizations { + ok { + "", + r#" + ;; Canonicalize `a + (b + c)` into `(a + b) + c`. + (=> (iadd $a (iadd $b $c)) + (iadd (iadd $a $b) $c)) + + ;; Combine a `const` and an `iadd` into a `iadd_imm`. + (=> (iadd (iconst $C) $x) + (iadd_imm $C $x)) + + ;; When `C` is a power of two, replace `x * C` with `x << log2(C)`. 
+ (=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $(log2 $C))) + "#, + } + } + parse_pattern { + ok { + "1234", + "$C", + "$x", + "(iadd $x $y)", + } + err { + "", + "()", + "abc", + } + } + parse_precondition { + ok { + "(is-power-of-two)", + "(is-power-of-two $C)", + "(is-power-of-two $C1 $C2)", + } + err { + "", + "1234", + "()", + "$var", + "$CONST", + } + } + parse_rhs { + ok { + "5", + "$C", + "$x", + "$(log2 $C)", + "(iadd $x 1)", + } + err { + "", + "()", + } + } + parse_unquote { + ok { + "$(log2)", + "$(log2 $C)", + "$(log2 $C1 1)", + "$(neg)", + "$(neg $C)", + "$(neg $C 1)", + } + err { + "", + "(log2 $C)", + "$()", + } + } + parse_value_literal { + ok { + "12345", + "true", + } + err { + "", + "'c'", + "\"hello\"", + "12.34", + } + } + parse_variable { + ok { + "$v", + "$v1", + "$v2", + "$x", + "$y", + "$some-var", + "$another_var", + } + err { + "zzz", + "$", + "$CONSTANT", + "$fooBar", + } + } + } +} diff --git a/cranelift/peepmatic/src/traversals.rs b/cranelift/peepmatic/src/traversals.rs new file mode 100644 index 0000000000..5eb4101c37 --- /dev/null +++ b/cranelift/peepmatic/src/traversals.rs @@ -0,0 +1,278 @@ +//! Traversals over the AST. + +use crate::ast::*; + +/// A low-level DFS traversal event: either entering or exiting the traversal of +/// an AST node. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TraversalEvent { + /// Entering traversal of an AST node. + /// + /// Processing an AST node upon this event corresponds to a pre-order + /// DFS traversal. + Enter, + + /// Exiting traversal of an AST node. + /// + /// Processing an AST node upon this event corresponds to a post-order DFS + /// traversal. + Exit, +} + +/// A depth-first traversal of an AST. +/// +/// This is a fairly low-level traversal type, and is intended to be used as a +/// building block for making specific pre-order or post-order traversals for +/// whatever problem is at hand. +/// +/// This implementation is not recursive, and exposes an `Iterator` interface +/// that yields pairs of `(TraversalEvent, DynAstRef)` items. +/// +/// The traversal can walk a whole set of `Optimization`s or just a subtree of +/// the AST, because the `new` constructor takes anything that can convert into +/// a `DynAstRef`. +#[derive(Debug, Clone)] +pub struct Dfs<'a> { + stack: Vec<(TraversalEvent, DynAstRef<'a>)>, +} + +impl<'a> Dfs<'a> { + /// Construct a new `Dfs` traversal starting at the given `start` AST node. + pub fn new(start: impl Into>) -> Self { + let start = start.into(); + Dfs { + stack: vec![ + (TraversalEvent::Exit, start), + (TraversalEvent::Enter, start), + ], + } + } + + /// Peek at the next traversal event and AST node pair, if any. 
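+    ///
+    /// The pair is the same one that `next` would yield: handling nodes on
+    /// `Enter` events gives a pre-order traversal, while handling them on
+    /// `Exit` events gives a post-order traversal (for example, `RhsPostOrder`
+    /// in `src/linearize.rs` keeps only the `Exit` events for `Rhs` nodes).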
+ pub fn peek(&self) -> Option<(TraversalEvent, DynAstRef<'a>)> { + self.stack.last().cloned() + } +} + +impl<'a> Iterator for Dfs<'a> { + type Item = (TraversalEvent, DynAstRef<'a>); + + fn next(&mut self) -> Option<(TraversalEvent, DynAstRef<'a>)> { + let (event, node) = self.stack.pop()?; + if let TraversalEvent::Enter = event { + let mut enqueue_children = EnqueueChildren(self); + node.child_nodes(&mut enqueue_children) + } + return Some((event, node)); + + struct EnqueueChildren<'a, 'b>(&'b mut Dfs<'a>) + where + 'a: 'b; + + impl<'a, 'b> Extend> for EnqueueChildren<'a, 'b> + where + 'a: 'b, + { + fn extend>>(&mut self, iter: T) { + let iter = iter.into_iter(); + + let (min, max) = iter.size_hint(); + self.0.stack.reserve(max.unwrap_or(min) * 2); + + let start = self.0.stack.len(); + + for node in iter { + self.0.stack.push((TraversalEvent::Enter, node)); + self.0.stack.push((TraversalEvent::Exit, node)); + } + + // Reverse to make it so that we visit children in order + // (e.g. operands are visited in order). + self.0.stack[start..].reverse(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use DynAstRef::*; + + #[test] + fn test_dfs_traversal() { + let input = " +(=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $(log2 $C))) +"; + let buf = wast::parser::ParseBuffer::new(input).expect("input should lex OK"); + let ast = match wast::parser::parse::(&buf) { + Ok(ast) => ast, + Err(e) => panic!("expected to parse OK, got error:\n\n{}", e), + }; + + let mut dfs = Dfs::new(&ast); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Optimizations(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Optimization(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Lhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, PatternOperation(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Variable(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Variable(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, PatternOperation(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Pattern(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Precondition(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, ConstraintOperand(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, ConstraintOperand(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Precondition(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + 
Some((TraversalEvent::Exit, Lhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, RhsOperation(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Variable(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Variable(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Unquote(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Enter, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Constant(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Unquote(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, RhsOperation(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Rhs(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Optimization(..))) + )); + assert!(matches!( + dbg!(dfs.next()), + Some((TraversalEvent::Exit, Optimizations(..))) + )); + assert!(dfs.next().is_none()); + } +} diff --git a/cranelift/peepmatic/src/verify.rs b/cranelift/peepmatic/src/verify.rs new file mode 100644 index 0000000000..03865458ca --- /dev/null +++ b/cranelift/peepmatic/src/verify.rs @@ -0,0 +1,1433 @@ +//! Verification and type checking of optimizations. +//! +//! For type checking, we compile the AST's type constraints down into Z3 +//! variables and assertions. If Z3 finds the assertions satisfiable, then we're +//! done! If it finds them unsatisfiable, we use the `get_unsat_core` method to +//! get the minimal subset of assertions that are in conflict, and report a +//! best-effort type error message with them. These messages aren't perfect, but +//! they're Good Enough when embedded in the source text via our tracking of +//! `wast::Span`s. +//! +//! Verifying that there aren't any counter-examples (inputs for which the LHS +//! and RHS produce different results) for a particular optimization is not +//! implemented yet. + +use crate::ast::{Span as _, *}; +use crate::traversals::{Dfs, TraversalEvent}; +use peepmatic_runtime::{ + operator::{Operator, TypingContext as TypingContextTrait}, + r#type::{BitWidth, Kind, Type}, +}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::hash::Hash; +use std::iter; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::path::Path; +use wast::{Error as WastError, Id, Span}; +use z3::ast::Ast; + +/// A verification or type checking error. 
+#[derive(Debug)] +pub struct VerifyError { + errors: Vec, +} + +impl fmt::Display for VerifyError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for e in &self.errors { + writeln!(f, "{}\n", e)?; + } + Ok(()) + } +} + +impl std::error::Error for VerifyError {} + +impl From for VerifyError { + fn from(e: WastError) -> Self { + VerifyError { + errors: vec![e.into()], + } + } +} + +impl From for VerifyError { + fn from(e: anyhow::Error) -> Self { + VerifyError { errors: vec![e] } + } +} + +impl VerifyError { + /// To provide a more useful error this function can be used to extract + /// relevant textual information about this error into the error itself. + /// + /// The `contents` here should be the full text of the original file being + /// parsed, and this will extract a sub-slice as necessary to render in the + /// `Display` implementation later on. + pub fn set_text(&mut self, contents: &str) { + for e in &mut self.errors { + if let Some(e) = e.downcast_mut::() { + e.set_text(contents); + } + } + } + + /// To provide a more useful error this function can be used to set + /// the file name that this error is associated with. + /// + /// The `path` here will be stored in this error and later rendered in the + /// `Display` implementation. + pub fn set_path(&mut self, path: &Path) { + for e in &mut self.errors { + if let Some(e) = e.downcast_mut::() { + e.set_path(path); + } + } + } +} + +/// Either `Ok(T)` or `Err(VerifyError)`. +pub type VerifyResult = Result; + +/// Verify and type check a set of optimizations. +pub fn verify(opts: &Optimizations) -> VerifyResult<()> { + if opts.optimizations.is_empty() { + return Err(anyhow::anyhow!("no optimizations").into()); + } + + verify_unique_left_hand_sides(opts)?; + + let z3 = &z3::Context::new(&z3::Config::new()); + for opt in &opts.optimizations { + verify_optimization(z3, opt)?; + } + Ok(()) +} + +/// Check that every LHS in the given optimizations is unique. +/// +/// If there were duplicates, then it would be nondeterministic which one we +/// applied and would make automata construction more difficult. It is better to +/// check for duplicates and reject them if found. +fn verify_unique_left_hand_sides(opts: &Optimizations) -> VerifyResult<()> { + let mut lefts = HashMap::new(); + for opt in &opts.optimizations { + let canon_lhs = canonicalized_lhs_key(&opt.lhs); + let existing = lefts.insert(canon_lhs, opt.lhs.span()); + if let Some(span) = existing { + return Err(VerifyError { + errors: vec![ + anyhow::anyhow!("error: two optimizations cannot have the same left-hand side"), + WastError::new(span, "note: first use of this left-hand side".into()).into(), + WastError::new( + opt.lhs.span(), + "note: second use of this left-hand side".into(), + ) + .into(), + ], + }); + } + } + Ok(()) +} + +/// When checking for duplicate left-hand sides, we need to consider patterns +/// that are duplicates up to renaming identifiers. For example, these LHSes +/// should be considered duplicates of each other: +/// +/// ```lisp +/// (=> (iadd $x $y) ...) +/// (=> (iadd $a $b) ...) +/// ``` +/// +/// This function creates an opaque, canonicalized hash key for left-hand sides +/// that sees through identifier renaming. 
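+///
+/// Concretely (a sketch of the scheme rather than the exact key layout), the
+/// key is built by a pre-order walk that replaces each identifier with the
+/// order of its first occurrence, so in the example above `$x`/`$a` both
+/// become variable 0 and `$y`/`$b` both become variable 1, making the two
+/// keys hash and compare as equal.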
+fn canonicalized_lhs_key(lhs: &Lhs) -> impl Hash + Eq { + let mut var_to_canon = HashMap::new(); + let mut const_to_canon = HashMap::new(); + let mut canonicalized = vec![]; + + for (event, ast) in Dfs::new(lhs) { + if event != TraversalEvent::Enter { + continue; + } + use CanonicalBit::*; + canonicalized.push(match ast { + DynAstRef::Lhs(_) => Other("Lhs"), + DynAstRef::Pattern(_) => Other("Pattern"), + DynAstRef::ValueLiteral(_) => Other("ValueLiteral"), + DynAstRef::Integer(i) => Integer(i.value), + DynAstRef::Boolean(b) => Boolean(b.value), + DynAstRef::ConditionCode(cc) => ConditionCode(cc.cc), + DynAstRef::PatternOperation(o) => Operation(o.operator, o.r#type.get()), + DynAstRef::Precondition(p) => Precondition(p.constraint), + DynAstRef::ConstraintOperand(_) => Other("ConstraintOperand"), + DynAstRef::Variable(Variable { id, .. }) => { + let new_id = var_to_canon.len() as u32; + let canon_id = var_to_canon.entry(id).or_insert(new_id); + Var(*canon_id) + } + DynAstRef::Constant(Constant { id, .. }) => { + let new_id = const_to_canon.len() as u32; + let canon_id = const_to_canon.entry(id).or_insert(new_id); + Const(*canon_id) + } + other => unreachable!("unreachable ast node: {:?}", other), + }); + } + + return canonicalized; + + #[derive(Hash, PartialEq, Eq)] + enum CanonicalBit { + Var(u32), + Const(u32), + Integer(i64), + Boolean(bool), + ConditionCode(peepmatic_runtime::cc::ConditionCode), + Operation(Operator, Option), + Precondition(Constraint), + Other(&'static str), + } +} + +pub(crate) struct TypingContext<'a> { + z3: &'a z3::Context, + type_kind_sort: z3::DatatypeSort<'a>, + solver: z3::Solver<'a>, + + // The type of the root of the optimization. Initialized when collecting + // type constraints. + root_ty: Option>, + + // See the comments above `enter_operation_scope`. + operation_scope: HashMap<&'static str, TypeVar<'a>>, + + // A map from identifiers to the type variable describing its type. + id_to_type_var: HashMap, TypeVar<'a>>, + + // A list of type constraints, the span of the AST node where the constraint + // originates from, and an optional message to be displayed if the + // constraint is not satisfied. + constraints: Vec<(z3::ast::Bool<'a>, Span, Option>)>, + + // Keep track of AST nodes that need to have their types assigned to + // them. For these AST nodes, we know what bit width to use when + // interpreting peephole optimization actions. + boolean_literals: Vec<(&'a Boolean<'a>, TypeVar<'a>)>, + integer_literals: Vec<(&'a Integer<'a>, TypeVar<'a>)>, + rhs_operations: Vec<(&'a Operation<'a, Rhs<'a>>, TypeVar<'a>)>, +} + +impl<'a> TypingContext<'a> { + fn new(z3: &'a z3::Context) -> Self { + let type_kind_sort = z3::DatatypeBuilder::new(z3) + .variant("int", &[]) + .variant("bool", &[]) + .variant("cpu_flags", &[]) + .variant("cc", &[]) + .variant("void", &[]) + .finish("TypeKind"); + TypingContext { + z3, + solver: z3::Solver::new(z3), + root_ty: None, + operation_scope: Default::default(), + id_to_type_var: Default::default(), + type_kind_sort, + constraints: vec![], + boolean_literals: Default::default(), + integer_literals: Default::default(), + rhs_operations: Default::default(), + } + } + + fn init_root_type(&mut self, span: Span, root_ty: TypeVar<'a>) { + assert!(self.root_ty.is_none()); + + // Make sure the root is a valid kind, i.e. not a condition code. 
+ let is_int = self.is_int(&root_ty); + let is_bool = self.is_bool(&root_ty); + let is_void = self.is_void(&root_ty); + let is_cpu_flags = self.is_cpu_flags(&root_ty); + self.constraints.push(( + is_int.or(&[&is_bool, &is_void, &is_cpu_flags]), + span, + Some( + "the root of an optimization must be an integer, a boolean, void, or CPU flags" + .into(), + ), + )); + + self.root_ty = Some(root_ty); + } + + fn new_type_var(&self) -> TypeVar<'a> { + let kind = + z3::ast::Datatype::fresh_const(self.z3, "type-var-kind", &self.type_kind_sort.sort); + let width = z3::ast::BV::fresh_const(self.z3, "type-var-width", 8); + TypeVar { kind, width } + } + + fn get_or_create_type_var_for_id(&mut self, id: Id<'a>) -> TypeVar<'a> { + if let Some(ty) = self.id_to_type_var.get(&id) { + ty.clone() + } else { + // Note: can't use the entry API because we reborrow `self` here. + let ty = self.new_type_var(); + self.id_to_type_var.insert(id, ty.clone()); + ty + } + } + + fn get_type_var_for_id(&mut self, id: Id<'a>) -> VerifyResult> { + if let Some(ty) = self.id_to_type_var.get(&id) { + Ok(ty.clone()) + } else { + Err(WastError::new(id.span(), format!("unknown identifier: ${}", id.name())).into()) + } + } + + // The `#[peepmatic]` macro for operations allows defining operations' types + // like `(iNN, iNN) -> iNN` where `iNN` all refer to the same integer type + // variable that must have the same bit width. But other operations might + // *also* have that type signature but be instantiated at a different bit + // width. We don't want to mix up which `iNN` variables are and aren't the + // same. We use this method to track scopes within which all uses of `iNN` + // and similar refer to the same type variables. + fn enter_operation_scope<'b>( + &'b mut self, + ) -> impl DerefMut> + Drop + 'b { + assert!(self.operation_scope.is_empty()); + return Scope(self); + + struct Scope<'a, 'b>(&'b mut TypingContext<'a>) + where + 'a: 'b; + + impl<'a, 'b> Deref for Scope<'a, 'b> + where + 'a: 'b, + { + type Target = TypingContext<'a>; + fn deref(&self) -> &TypingContext<'a> { + self.0 + } + } + + impl<'a, 'b> DerefMut for Scope<'a, 'b> + where + 'a: 'b, + { + fn deref_mut(&mut self) -> &mut TypingContext<'a> { + self.0 + } + } + + impl Drop for Scope<'_, '_> { + fn drop(&mut self) { + self.0.operation_scope.clear(); + } + } + } + + fn remember_boolean_literal(&mut self, b: &'a Boolean<'a>, ty: TypeVar<'a>) { + self.assert_is_bool(b.span, &ty); + self.boolean_literals.push((b, ty)); + } + + fn remember_integer_literal(&mut self, i: &'a Integer<'a>, ty: TypeVar<'a>) { + self.assert_is_integer(i.span, &ty); + self.integer_literals.push((i, ty)); + } + + fn remember_rhs_operation(&mut self, op: &'a Operation<'a, Rhs<'a>>, ty: TypeVar<'a>) { + self.rhs_operations.push((op, ty)); + } + + fn is_int(&self, ty: &TypeVar<'a>) -> z3::ast::Bool<'a> { + self.type_kind_sort.variants[0] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap() + } + + fn is_bool(&self, ty: &TypeVar<'a>) -> z3::ast::Bool<'a> { + self.type_kind_sort.variants[1] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap() + } + + fn is_cpu_flags(&self, ty: &TypeVar<'a>) -> z3::ast::Bool<'a> { + self.type_kind_sort.variants[2] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap() + } + + fn is_condition_code(&self, ty: &TypeVar<'a>) -> z3::ast::Bool<'a> { + self.type_kind_sort.variants[3] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap() + } + + fn is_void(&self, ty: &TypeVar<'a>) -> 
+        self.type_kind_sort.variants[4]
+            .tester
+            .apply(&[&ty.kind.clone().into()])
+            .as_bool()
+            .unwrap()
+    }
+
+    fn assert_is_integer(&mut self, span: Span, ty: &TypeVar<'a>) {
+        self.constraints.push((
+            self.is_int(ty),
+            span,
+            Some("type error: expected integer".into()),
+        ));
+    }
+
+    fn assert_is_bool(&mut self, span: Span, ty: &TypeVar<'a>) {
+        self.constraints.push((
+            self.is_bool(ty),
+            span,
+            Some("type error: expected bool".into()),
+        ));
+    }
+
+    fn assert_is_cpu_flags(&mut self, span: Span, ty: &TypeVar<'a>) {
+        self.constraints.push((
+            self.is_cpu_flags(ty),
+            span,
+            Some("type error: expected CPU flags".into()),
+        ));
+    }
+
+    fn assert_is_cc(&mut self, span: Span, ty: &TypeVar<'a>) {
+        self.constraints.push((
+            self.is_condition_code(ty),
+            span,
+            Some("type error: expected condition code".into()),
+        ));
+    }
+
+    fn assert_is_void(&mut self, span: Span, ty: &TypeVar<'a>) {
+        self.constraints.push((
+            self.is_void(ty),
+            span,
+            Some("type error: expected void".into()),
+        ));
+    }
+
+    fn assert_bit_width(&mut self, span: Span, ty: &TypeVar<'a>, width: u8) {
+        debug_assert!(width == 0 || width.is_power_of_two());
+        let width_var = z3::ast::BV::from_i64(self.z3, width as i64, 8);
+        let is_width = width_var._eq(&ty.width);
+        self.constraints.push((
+            is_width,
+            span,
+            Some(format!("type error: expected bit width = {}", width).into()),
+        ));
+    }
+
+    fn assert_bit_width_lt(&mut self, span: Span, a: &TypeVar<'a>, b: &TypeVar<'a>) {
+        self.constraints.push((
+            a.width.bvult(&b.width),
+            span,
+            Some("type error: expected narrower bit width".into()),
+        ));
+    }
+
+    fn assert_bit_width_gt(&mut self, span: Span, a: &TypeVar<'a>, b: &TypeVar<'a>) {
+        self.constraints.push((
+            a.width.bvugt(&b.width),
+            span,
+            Some("type error: expected wider bit width".into()),
+        ));
+    }
+
+    fn assert_type_eq(
+        &mut self,
+        span: Span,
+        lhs: &TypeVar<'a>,
+        rhs: &TypeVar<'a>,
+        msg: Option<Cow<'static, str>>,
+    ) {
+        self.constraints
+            .push((lhs.kind._eq(&rhs.kind), span, msg.clone()));
+        self.constraints
+            .push((lhs.width._eq(&rhs.width), span, msg));
+    }
+
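+    // Check all of the collected constraints at once. Each constraint is
+    // asserted with a fresh "tracker" Boolean so that, if the constraints turn
+    // out to be unsatisfiable, Z3's unsat core tells us which constraints (and
+    // therefore which spans and messages) to report as type errors.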
+    fn type_check(&self, span: Span) -> VerifyResult<()> {
+        let trackers = iter::repeat_with(|| z3::ast::Bool::fresh_const(self.z3, "type-constraint"))
+            .take(self.constraints.len())
+            .collect::<Vec<_>>();
+
+        let mut tracker_to_diagnostics = HashMap::with_capacity(self.constraints.len());
+
+        for (constraint_data, tracker) in self.constraints.iter().zip(trackers) {
+            let (constraint, span, msg) = constraint_data;
+            self.solver.assert_and_track(constraint, &tracker);
+            tracker_to_diagnostics.insert(tracker, (*span, msg.clone()));
+        }
+
+        match self.solver.check() {
+            z3::SatResult::Sat => Ok(()),
+            z3::SatResult::Unsat => {
+                let core = self.solver.get_unsat_core();
+                if core.is_empty() {
+                    return Err(WastError::new(
+                        span,
+                        "z3 determined the type constraints for this optimization are \
+                         unsatisfiable, meaning there is a type error, but z3 did not give us any \
+                         additional information"
+                            .into(),
+                    )
+                    .into());
+                }
+
+                let mut errors = core
+                    .iter()
+                    .map(|tracker| {
+                        let (span, msg) = &tracker_to_diagnostics[tracker];
+                        (
+                            *span,
+                            WastError::new(
+                                *span,
+                                msg.clone().unwrap_or("type error".into()).into(),
+                            )
+                            .into(),
+                        )
+                    })
+                    .collect::<Vec<_>>();
+                errors.sort_by_key(|(span, _)| *span);
+                let errors = errors.into_iter().map(|(_, e)| e).collect();
+
+                Err(VerifyError { errors })
+            }
+            z3::SatResult::Unknown => Err(anyhow::anyhow!(
+                "z3 returned 'unknown' when evaluating type constraints: {}",
+                self.solver
+                    .get_reason_unknown()
+                    .unwrap_or_else(|| "".into())
+            )
+            .into()),
+        }
+    }
+
+    fn assign_types(&mut self) -> VerifyResult<()> {
+        for (int, ty) in mem::replace(&mut self.integer_literals, vec![]) {
+            let width = self.ty_var_to_width(&ty)?;
+            int.bit_width.set(Some(width));
+        }
+
+        for (b, ty) in mem::replace(&mut self.boolean_literals, vec![]) {
+            let width = self.ty_var_to_width(&ty)?;
+            b.bit_width.set(Some(width));
+        }
+
+        for (op, ty) in mem::replace(&mut self.rhs_operations, vec![]) {
+            let kind = self.op_ty_var_to_kind(&ty);
+            let bit_width = match kind {
+                Kind::CpuFlags | Kind::Void => BitWidth::One,
+                Kind::Int | Kind::Bool => self.ty_var_to_width(&ty)?,
+            };
+            debug_assert!(op.r#type.get().is_none());
+            op.r#type.set(Some(Type { kind, bit_width }));
+        }
+
+        Ok(())
+    }
+
+    fn ty_var_to_width(&self, ty_var: &TypeVar<'a>) -> VerifyResult<BitWidth> {
+        // Doing solver push/pops apparently clears out the model, so we have to
+        // re-check each time to ensure that it exists, and Z3 doesn't helpfully
+        // abort the process for us. This should be fast, since the solver
+        // remembers inferences from earlier checks.
+        assert_eq!(self.solver.check(), z3::SatResult::Sat);
+
+        // Check if there is more than one satisfying assignment to
+        // `ty_var`'s width variable. If so, then it must be polymorphic. If
+        // not, then it must have a fixed value.
+        let model = self.solver.get_model();
+        let width_var = model.eval(&ty_var.width).unwrap();
+        let bit_width: u8 = width_var.as_u64().unwrap().try_into().unwrap();
+
+        self.solver.push();
+        self.solver.assert(&ty_var.width._eq(&width_var).not());
+        let is_polymorphic = match self.solver.check() {
+            z3::SatResult::Sat => true,
+            z3::SatResult::Unsat => false,
+            z3::SatResult::Unknown => panic!("Z3 cannot determine bit width of type"),
+        };
+        self.solver.pop(1);
+
+        if is_polymorphic {
+            // If something is polymorphic over bit widths, it must be
+            // polymorphic over the same bit width as the whole
+            // optimization.
+            //
+            // TODO: We should have a better model for bit-width
+            // polymorphism. The current setup works for all the use cases we
+            // currently care about, and is relatively easy to implement when
+            // matching and constructing the RHS, but is a bit ad-hoc. Maybe
+            // allow each LHS variable a polymorphic bit width, augment the AST
+            // with that info, and later emit match ops as necessary to express
+            // their relative constraints?
*hand waves* + self.solver.push(); + self.solver + .assert(&ty_var.width._eq(&self.root_ty.as_ref().unwrap().width)); + match self.solver.check() { + z3::SatResult::Sat => {} + z3::SatResult::Unsat => { + return Err(anyhow::anyhow!( + "AST node is bit width polymorphic, but not over the optimization's root \ + width" + ) + .into()) + } + z3::SatResult::Unknown => panic!("Z3 cannot determine bit width of type"), + }; + self.solver.pop(1); + + Ok(BitWidth::Polymorphic) + } else { + Ok(BitWidth::try_from(bit_width).unwrap()) + } + } + + fn op_ty_var_to_kind(&self, ty_var: &TypeVar<'a>) -> Kind { + for (predicate, kind) in [ + (Self::is_int as fn(_, _) -> _, Kind::Int), + (Self::is_bool, Kind::Bool), + (Self::is_cpu_flags, Kind::CpuFlags), + (Self::is_void, Kind::Void), + ] + .iter() + { + self.solver.push(); + self.solver.assert(&predicate(self, ty_var)); + match self.solver.check() { + z3::SatResult::Sat => { + self.solver.pop(1); + return *kind; + } + z3::SatResult::Unsat => { + self.solver.pop(1); + continue; + } + z3::SatResult::Unknown => panic!("Z3 cannot determine the type's kind"), + } + } + + // This would only happen if given a `TypeVar` whose kind was a + // condition code, but we only use this function for RHS operations, + // which cannot be condition codes. + panic!("cannot convert type variable's kind to `peepmatic_runtime::type::Kind`") + } +} + +impl<'a> TypingContextTrait<'a> for TypingContext<'a> { + type TypeVariable = TypeVar<'a>; + + fn cc(&mut self, span: Span) -> TypeVar<'a> { + let ty = self.new_type_var(); + self.assert_is_cc(span, &ty); + ty + } + + fn bNN(&mut self, span: Span) -> TypeVar<'a> { + if let Some(ty) = self.operation_scope.get("bNN") { + return ty.clone(); + } + + let ty = self.new_type_var(); + self.assert_is_bool(span, &ty); + self.operation_scope.insert("bNN", ty.clone()); + ty + } + + fn iNN(&mut self, span: Span) -> TypeVar<'a> { + if let Some(ty) = self.operation_scope.get("iNN") { + return ty.clone(); + } + + let ty = self.new_type_var(); + self.assert_is_integer(span, &ty); + self.operation_scope.insert("iNN", ty.clone()); + ty + } + + fn iMM(&mut self, span: Span) -> TypeVar<'a> { + if let Some(ty) = self.operation_scope.get("iMM") { + return ty.clone(); + } + + let ty = self.new_type_var(); + self.assert_is_integer(span, &ty); + self.operation_scope.insert("iMM", ty.clone()); + ty + } + + fn cpu_flags(&mut self, span: Span) -> TypeVar<'a> { + if let Some(ty) = self.operation_scope.get("cpu_flags") { + return ty.clone(); + } + + let ty = self.new_type_var(); + self.assert_is_cpu_flags(span, &ty); + self.assert_bit_width(span, &ty, 1); + self.operation_scope.insert("cpu_flags", ty.clone()); + ty + } + + fn b1(&mut self, span: Span) -> TypeVar<'a> { + let b1 = self.new_type_var(); + self.assert_is_bool(span, &b1); + self.assert_bit_width(span, &b1, 1); + b1 + } + + fn void(&mut self, span: Span) -> TypeVar<'a> { + let void = self.new_type_var(); + self.assert_is_void(span, &void); + self.assert_bit_width(span, &void, 0); + void + } + + fn bool_or_int(&mut self, span: Span) -> TypeVar<'a> { + let ty = self.new_type_var(); + let is_int = self.type_kind_sort.variants[0] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap(); + let is_bool = self.type_kind_sort.variants[1] + .tester + .apply(&[&ty.kind.clone().into()]) + .as_bool() + .unwrap(); + self.constraints.push(( + is_int.or(&[&is_bool]), + span, + Some("type error: must be either an int or a bool type".into()), + )); + ty + } + + fn any_t(&mut self, _span: Span) -> 
TypeVar<'a> { + if let Some(ty) = self.operation_scope.get("any_t") { + return ty.clone(); + } + + let ty = self.new_type_var(); + self.operation_scope.insert("any_t", ty.clone()); + ty + } +} + +#[derive(Clone)] +pub(crate) struct TypeVar<'a> { + kind: z3::ast::Datatype<'a>, + width: z3::ast::BV<'a>, +} + +fn verify_optimization(z3: &z3::Context, opt: &Optimization) -> VerifyResult<()> { + let mut context = TypingContext::new(z3); + collect_type_constraints(&mut context, opt)?; + context.type_check(opt.span)?; + context.assign_types()?; + + // TODO: add another pass here to check for counter-examples to this + // optimization, i.e. inputs where the LHS and RHS are not equivalent. + + Ok(()) +} + +fn collect_type_constraints<'a>( + context: &mut TypingContext<'a>, + opt: &'a Optimization<'a>, +) -> VerifyResult<()> { + use crate::traversals::TraversalEvent as TE; + + let lhs_ty = context.new_type_var(); + context.init_root_type(opt.lhs.span, lhs_ty.clone()); + + let rhs_ty = context.new_type_var(); + context.assert_type_eq( + opt.span, + &lhs_ty, + &rhs_ty, + Some("type error: the left-hand side and right-hand side must have the same type".into()), + ); + + // A stack of type variables that we are constraining as we traverse the + // AST. Operations push new type variables for their operands' expected + // types, and exiting a `Pattern` in the traversal pops them off. + let mut expected_types = vec![lhs_ty]; + + // Build up the type constraints for the left-hand side. + for (event, node) in Dfs::new(&opt.lhs) { + match (event, node) { + (TE::Enter, DynAstRef::Pattern(Pattern::Constant(Constant { id, span }))) + | (TE::Enter, DynAstRef::Pattern(Pattern::Variable(Variable { id, span }))) => { + let id = context.get_or_create_type_var_for_id(*id); + context.assert_type_eq(*span, expected_types.last().unwrap(), &id, None); + } + (TE::Enter, DynAstRef::Pattern(Pattern::ValueLiteral(ValueLiteral::Integer(i)))) => { + let ty = expected_types.last().unwrap(); + context.remember_integer_literal(i, ty.clone()); + } + (TE::Enter, DynAstRef::Pattern(Pattern::ValueLiteral(ValueLiteral::Boolean(b)))) => { + let ty = expected_types.last().unwrap(); + context.remember_boolean_literal(b, ty.clone()); + } + ( + TE::Enter, + DynAstRef::Pattern(Pattern::ValueLiteral(ValueLiteral::ConditionCode(cc))), + ) => { + let ty = expected_types.last().unwrap(); + context.assert_is_cc(cc.span, ty); + } + (TE::Enter, DynAstRef::PatternOperation(op)) => { + let result_ty; + let mut operand_types = vec![]; + { + let mut scope = context.enter_operation_scope(); + result_ty = op.operator.result_type(&mut *scope, op.span); + op.operator + .immediate_types(&mut *scope, op.span, &mut operand_types); + op.operator + .param_types(&mut *scope, op.span, &mut operand_types); + } + + if op.operands.len() != operand_types.len() { + return Err(WastError::new( + op.span, + format!( + "Expected {} operands but found {}", + operand_types.len(), + op.operands.len() + ), + ) + .into()); + } + + for imm in op + .operands + .iter() + .take(op.operator.immediates_arity() as usize) + { + match imm { + Pattern::ValueLiteral(_) | + Pattern::Constant(_) | + Pattern::Variable(_) => continue, + Pattern::Operation(op) => return Err(WastError::new( + op.span, + "operations are invalid immediates; must be a value literal, constant, \ + or variable".into() + ).into()), + } + } + + match op.operator { + Operator::Ireduce | Operator::Uextend | Operator::Sextend => { + if op.r#type.get().is_none() { + return Err(WastError::new( + op.span, + "`ireduce`, 
`sextend`, and `uextend` require an ascribed type, \ + like `(sextend{i64} ...)`" + .into(), + ) + .into()); + } + } + _ => {} + } + + match op.operator { + Operator::Uextend | Operator::Sextend => { + context.assert_bit_width_gt(op.span, &result_ty, &operand_types[0]); + } + Operator::Ireduce => { + context.assert_bit_width_lt(op.span, &result_ty, &operand_types[0]); + } + _ => {} + } + + if let Some(ty) = op.r#type.get() { + match ty.kind { + Kind::Bool => context.assert_is_bool(op.span, &result_ty), + Kind::Int => context.assert_is_integer(op.span, &result_ty), + Kind::Void => context.assert_is_void(op.span, &result_ty), + Kind::CpuFlags => { + unreachable!("no syntax for ascribing CPU flags types right now") + } + } + if let Some(w) = ty.bit_width.fixed_width() { + context.assert_bit_width(op.span, &result_ty, w); + } + } + + context.assert_type_eq(op.span, expected_types.last().unwrap(), &result_ty, None); + + operand_types.reverse(); + expected_types.extend(operand_types); + } + (TE::Exit, DynAstRef::Pattern(..)) => { + expected_types.pop().unwrap(); + } + (TE::Enter, DynAstRef::Precondition(pre)) => { + type_constrain_precondition(context, pre)?; + } + _ => continue, + } + } + + // We should have exited exactly as many patterns as we entered: one for the + // root pattern and the initial `lhs_ty`, and then the rest for the operands + // of pattern operations. + assert!(expected_types.is_empty()); + + // Collect the type constraints for the right-hand side. + expected_types.push(rhs_ty); + for (event, node) in Dfs::new(&opt.rhs) { + match (event, node) { + (TE::Enter, DynAstRef::Rhs(Rhs::ValueLiteral(ValueLiteral::Integer(i)))) => { + let ty = expected_types.last().unwrap(); + context.remember_integer_literal(i, ty.clone()); + } + (TE::Enter, DynAstRef::Rhs(Rhs::ValueLiteral(ValueLiteral::Boolean(b)))) => { + let ty = expected_types.last().unwrap(); + context.remember_boolean_literal(b, ty.clone()); + } + (TE::Enter, DynAstRef::Rhs(Rhs::ValueLiteral(ValueLiteral::ConditionCode(cc)))) => { + let ty = expected_types.last().unwrap(); + context.assert_is_cc(cc.span, ty); + } + (TE::Enter, DynAstRef::Rhs(Rhs::Constant(Constant { span, id }))) + | (TE::Enter, DynAstRef::Rhs(Rhs::Variable(Variable { span, id }))) => { + let id_ty = context.get_type_var_for_id(*id)?; + context.assert_type_eq(*span, expected_types.last().unwrap(), &id_ty, None); + } + (TE::Enter, DynAstRef::RhsOperation(op)) => { + let result_ty; + let mut operand_types = vec![]; + { + let mut scope = context.enter_operation_scope(); + result_ty = op.operator.result_type(&mut *scope, op.span); + op.operator + .immediate_types(&mut *scope, op.span, &mut operand_types); + op.operator + .param_types(&mut *scope, op.span, &mut operand_types); + } + + if op.operands.len() != operand_types.len() { + return Err(WastError::new( + op.span, + format!( + "Expected {} operands but found {}", + operand_types.len(), + op.operands.len() + ), + ) + .into()); + } + + for imm in op + .operands + .iter() + .take(op.operator.immediates_arity() as usize) + { + match imm { + Rhs::ValueLiteral(_) + | Rhs::Constant(_) + | Rhs::Variable(_) + | Rhs::Unquote(_) => continue, + Rhs::Operation(op) => return Err(WastError::new( + op.span, + "operations are invalid immediates; must be a value literal, unquote, \ + constant, or variable" + .into(), + ) + .into()), + } + } + + match op.operator { + Operator::Ireduce | Operator::Uextend | Operator::Sextend => { + if op.r#type.get().is_none() { + return Err(WastError::new( + op.span, + "`ireduce`, `sextend`, 
and `uextend` require an ascribed type, \ + like `(sextend{i64} ...)`" + .into(), + ) + .into()); + } + } + _ => {} + } + + match op.operator { + Operator::Uextend | Operator::Sextend => { + context.assert_bit_width_gt(op.span, &result_ty, &operand_types[0]); + } + Operator::Ireduce => { + context.assert_bit_width_lt(op.span, &result_ty, &operand_types[0]); + } + _ => {} + } + + if let Some(ty) = op.r#type.get() { + match ty.kind { + Kind::Bool => context.assert_is_bool(op.span, &result_ty), + Kind::Int => context.assert_is_integer(op.span, &result_ty), + Kind::Void => context.assert_is_void(op.span, &result_ty), + Kind::CpuFlags => { + unreachable!("no syntax for ascribing CPU flags types right now") + } + } + if let Some(w) = ty.bit_width.fixed_width() { + context.assert_bit_width(op.span, &result_ty, w); + } + } + + context.assert_type_eq(op.span, expected_types.last().unwrap(), &result_ty, None); + if op.r#type.get().is_none() { + context.remember_rhs_operation(op, result_ty); + } + + operand_types.reverse(); + expected_types.extend(operand_types); + } + (TE::Enter, DynAstRef::Unquote(unq)) => { + let result_ty; + let mut operand_types = vec![]; + { + let mut scope = context.enter_operation_scope(); + result_ty = unq.operator.result_type(&mut *scope, unq.span); + unq.operator + .immediate_types(&mut *scope, unq.span, &mut operand_types); + unq.operator + .param_types(&mut *scope, unq.span, &mut operand_types); + } + + if unq.operands.len() != operand_types.len() { + return Err(WastError::new( + unq.span, + format!( + "Expected {} unquote operands but found {}", + operand_types.len(), + unq.operands.len() + ), + ) + .into()); + } + + for operand in &unq.operands { + match operand { + Rhs::ValueLiteral(_) | Rhs::Constant(_) => continue, + Rhs::Variable(_) | Rhs::Unquote(_) | Rhs::Operation(_) => { + return Err(WastError::new( + operand.span(), + "unquote operands must be value literals or constants".into(), + ) + .into()); + } + } + } + + context.assert_type_eq(unq.span, expected_types.last().unwrap(), &result_ty, None); + + operand_types.reverse(); + expected_types.extend(operand_types); + } + (TE::Exit, DynAstRef::Rhs(..)) => { + expected_types.pop().unwrap(); + } + _ => continue, + } + } + + // Again, we should have popped off all the expected types when exiting + // `Rhs` nodes in the traversal. + assert!(expected_types.is_empty()); + + Ok(()) +} + +fn type_constrain_precondition<'a>( + context: &mut TypingContext<'a>, + pre: &Precondition<'a>, +) -> VerifyResult<()> { + match pre.constraint { + Constraint::BitWidth => { + if pre.operands.len() != 2 { + return Err(WastError::new( + pre.span, + format!( + "the `bit-width` precondition requires exactly 2 operands, found \ + {} operands", + pre.operands.len(), + ), + ) + .into()); + } + + let id = match pre.operands[0] { + ConstraintOperand::ValueLiteral(_) => { + return Err(anyhow::anyhow!( + "the `bit-width` precondition requires a constant or variable as \ + its first operand" + ) + .into()) + } + ConstraintOperand::Constant(Constant { id, .. }) + | ConstraintOperand::Variable(Variable { id, .. }) => id, + }; + + let width = match pre.operands[1] { + ConstraintOperand::ValueLiteral(ValueLiteral::Integer(Integer { + value, .. 
+ })) if value == 1 + || value == 8 + || value == 16 + || value == 32 + || value == 64 + || value == 128 => + { + value as u8 + } + ref op => return Err(WastError::new( + op.span(), + "the `bit-width` precondition requires a bit width of 1, 8, 16, 32, 64, or \ + 128" + .into(), + ) + .into()), + }; + + let ty = context.get_type_var_for_id(id)?; + context.assert_bit_width(pre.span, &ty, width); + Ok(()) + } + Constraint::IsPowerOfTwo => { + if pre.operands.len() != 1 { + return Err(WastError::new( + pre.span, + format!( + "the `is-power-of-two` precondition requires exactly 1 operand, found \ + {} operands", + pre.operands.len(), + ), + ) + .into()); + } + match &pre.operands[0] { + ConstraintOperand::Constant(Constant { id, .. }) => { + let ty = context.get_type_var_for_id(*id)?; + context.assert_is_integer(pre.span(), &ty); + Ok(()) + } + op => Err(WastError::new( + op.span(), + "`is-power-of-two` operands must be constant bindings".into(), + ) + .into()), + } + } + Constraint::FitsInNativeWord => { + if pre.operands.len() != 1 { + return Err(WastError::new( + pre.span, + format!( + "the `fits-in-native-word` precondition requires exactly 1 operand, found \ + {} operands", + pre.operands.len(), + ), + ) + .into()); + } + + match pre.operands[0] { + ConstraintOperand::ValueLiteral(_) => { + return Err(anyhow::anyhow!( + "the `fits-in-native-word` precondition requires a constant or variable as \ + its first operand" + ) + .into()) + } + ConstraintOperand::Constant(Constant { id, .. }) + | ConstraintOperand::Variable(Variable { id, .. }) => { + context.get_type_var_for_id(id)?; + Ok(()) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! verify_ok { + ($name:ident, $src:expr) => { + #[test] + fn $name() { + let buf = wast::parser::ParseBuffer::new($src).expect("should lex OK"); + let opts = match wast::parser::parse::(&buf) { + Ok(opts) => opts, + Err(mut e) => { + e.set_path(Path::new(stringify!($name))); + e.set_text($src); + eprintln!("{}", e); + panic!("should parse OK") + } + }; + match verify(&opts) { + Ok(_) => return, + Err(mut e) => { + e.set_path(Path::new(stringify!($name))); + e.set_text($src); + eprintln!("{}", e); + panic!("should verify OK") + } + } + } + }; + } + + macro_rules! 
verify_err { + ($name:ident, $src:expr) => { + #[test] + fn $name() { + let buf = wast::parser::ParseBuffer::new($src).expect("should lex OK"); + let opts = match wast::parser::parse::(&buf) { + Ok(opts) => opts, + Err(mut e) => { + e.set_path(Path::new(stringify!($name))); + e.set_text($src); + eprintln!("{}", e); + panic!("should parse OK") + } + }; + match verify(&opts) { + Ok(_) => panic!("expected a verification error, but it verified OK"), + Err(mut e) => { + e.set_path(Path::new(stringify!($name))); + e.set_text($src); + eprintln!("{}", e); + return; + } + } + } + }; + } + + verify_ok!(bool_0, "(=> true true)"); + verify_ok!(bool_1, "(=> false false)"); + verify_ok!(bool_2, "(=> true false)"); + verify_ok!(bool_3, "(=> false true)"); + + verify_err!(bool_is_not_int_0, "(=> true 42)"); + verify_err!(bool_is_not_int_1, "(=> 42 true)"); + + verify_ok!( + bit_width_0, + " +(=> (when (iadd $x $y) + (bit-width $x 32) + (bit-width $y 32)) + (iadd $x $y)) +" + ); + verify_err!( + bit_width_1, + " +(=> (when (iadd $x $y) + (bit-width $x 32) + (bit-width $y 64)) + (iadd $x $y)) +" + ); + verify_err!( + bit_width_2, + " +(=> (when (iconst $C) + (bit-width $C)) + 5) +" + ); + verify_err!( + bit_width_3, + " +(=> (when (iconst $C) + (bit-width 32 32)) + 5) +" + ); + verify_err!( + bit_width_4, + " +(=> (when (iconst $C) + (bit-width $C $C)) + 5) +" + ); + verify_err!( + bit_width_5, + " +(=> (when (iconst $C) + (bit-width $C2 32)) + 5) +" + ); + verify_err!( + bit_width_6, + " +(=> (when (iconst $C) + (bit-width $C2 33)) + 5) +" + ); + + verify_ok!( + is_power_of_two_0, + " +(=> (when (imul $x $C) + (is-power-of-two $C)) + (ishl $x $(log2 $C))) +" + ); + verify_err!( + is_power_of_two_1, + " +(=> (when (imul $x $C) + (is-power-of-two)) + 5) +" + ); + verify_err!( + is_power_of_two_2, + " +(=> (when (imul $x $C) + (is-power-of-two $C $C)) + 5) +" + ); + + verify_ok!(pattern_ops_0, "(=> (iadd $x $C) 5)"); + verify_err!(pattern_ops_1, "(=> (iadd $x) 5)"); + verify_err!(pattern_ops_2, "(=> (iadd $x $y $z) 5)"); + + verify_ok!(unquote_0, "(=> $C $(log2 $C))"); + verify_err!(unquote_1, "(=> (iadd $C $D) $(log2 $C $D))"); + verify_err!(unquote_2, "(=> $x $(log2))"); + verify_ok!(unquote_3, "(=> $C $(neg $C))"); + verify_err!(unquote_4, "(=> $x $(neg))"); + verify_err!(unquote_5, "(=> (iadd $x $y) $(neg $x $y))"); + verify_err!(unquote_6, "(=> $x $(neg $x))"); + + verify_ok!(rhs_0, "(=> $x (iadd $x (iconst 0)))"); + verify_err!(rhs_1, "(=> $x (iadd $x))"); + verify_err!(rhs_2, "(=> $x (iadd $x 0 0))"); + + verify_err!(no_optimizations, ""); + + verify_err!( + duplicate_left_hand_sides, + " +(=> (iadd $x $y) 0) +(=> (iadd $x $y) 1) +" + ); + verify_err!( + canonically_duplicate_left_hand_sides_0, + " +(=> (iadd $x $y) 0) +(=> (iadd $y $x) 1) +" + ); + verify_err!( + canonically_duplicate_left_hand_sides_1, + " +(=> (iadd $X $Y) 0) +(=> (iadd $Y $X) 1) +" + ); + verify_err!( + canonically_duplicate_left_hand_sides_2, + " +(=> (iadd $x $x) 0) +(=> (iadd $y $y) 1) +" + ); + + verify_ok!( + canonically_different_left_hand_sides_0, + " +(=> (iadd $x $C) 0) +(=> (iadd $C $x) 1) +" + ); + verify_ok!( + canonically_different_left_hand_sides_1, + " +(=> (iadd $x $x) 0) +(=> (iadd $x $y) 1) +" + ); + + verify_ok!( + fits_in_native_word_0, + "(=> (when (iadd $x $y) (fits-in-native-word $x)) 0)" + ); + verify_err!( + fits_in_native_word_1, + "(=> (when (iadd $x $y) (fits-in-native-word)) 0)" + ); + verify_err!( + fits_in_native_word_2, + "(=> (when (iadd $x $y) (fits-in-native-word $x $y)) 0)" + ); + verify_err!( + 
fits_in_native_word_3, + "(=> (when (iadd $x $y) (fits-in-native-word true)) 0)" + ); + + verify_err!(reduce_extend_0, "(=> (sextend (ireduce -1)) 0)"); + verify_err!(reduce_extend_1, "(=> (uextend (ireduce -1)) 0)"); + verify_ok!(reduce_extend_2, "(=> (sextend{i64} (ireduce{i32} -1)) 0)"); + verify_ok!(reduce_extend_3, "(=> (uextend{i64} (ireduce{i32} -1)) 0)"); + verify_err!(reduce_extend_4, "(=> (sextend{i64} (ireduce{i64} -1)) 0)"); + verify_err!(reduce_extend_5, "(=> (uextend{i64} (ireduce{i64} -1)) 0)"); + verify_err!(reduce_extend_6, "(=> (sextend{i32} (ireduce{i64} -1)) 0)"); + verify_err!(reduce_extend_7, "(=> (uextend{i32} (ireduce{i64} -1)) 0)"); + + verify_err!( + using_an_operation_as_an_immediate_in_lhs, + "(=> (iadd_imm (imul $x $y) $z) 0)" + ); + verify_err!( + using_an_operation_as_an_immediate_in_rhs, + "(=> (iadd (imul $x $y) $z) (iadd_imm (imul $x $y) $z))" + ); + + verify_err!( + using_a_condition_code_as_the_root_of_an_optimization, + "(=> eq eq)" + ); +}