From 3a3a284f6fd835339a1b51d9805d2fd2bbf195d4 Mon Sep 17 00:00:00 2001 From: Alexis Engelke Date: Sun, 3 Jan 2021 11:32:28 +0100 Subject: [PATCH] parseinstrs: Improve performance --- parseinstrs.py | 197 ++++++++++++++++++++++++------------------------- 1 file changed, 96 insertions(+), 101 deletions(-) diff --git a/parseinstrs.py b/parseinstrs.py index 299fa52..2b72dfe 100644 --- a/parseinstrs.py +++ b/parseinstrs.py @@ -1,54 +1,45 @@ #!/usr/bin/python3 import argparse -from binascii import unhexlify from collections import OrderedDict, defaultdict, namedtuple, Counter -from copy import copy -from enum import Enum, IntEnum -from itertools import accumulate, product +from enum import Enum +from itertools import product import struct from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString -def bitstruct(name, fields): - names, sizes = zip(*(field.split(":") for field in fields)) - sizes = tuple(map(int, sizes)) - offsets = (0,) + tuple(accumulate(sizes)) - class __class: - def __init__(self, **kwargs): - for name in names: - setattr(self, name, kwargs.get(name, 0)) - def _encode(self, length): - enc = 0 - for name, size, offset in zip(names, sizes, offsets): - enc += (getattr(self, name) & ((1 << size) - 1)) << offset - return enc.to_bytes(length, "little") - __class.__name__ = name - return __class - -InstrFlags = bitstruct("InstrFlags", [ - "modrm_idx:2", - "modreg_idx:2", - "vexreg_idx:2", - "zeroreg_idx:2", - "imm_idx:2", - "zeroreg_val:1", - "lock:1", - "imm_control:3", - "vsib:1", - "op0_size:2", - "op1_size:2", - "op2_size:2", - "op3_size:2", - "opsize:2", - "size_fix1:3", - "size_fix2:2", - "instr_width:1", - "op0_regty:3", - "op1_regty:3", - "op2_regty:3", - "_unused:6", - "ign66:1", -]) +INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[ + ("modrm_idx", 2), + ("modreg_idx", 2), + ("vexreg_idx", 2), + ("zeroreg_idx", 2), + ("imm_idx", 2), + ("zeroreg_val", 1), + ("lock", 1), + ("imm_control", 3), + ("vsib", 1), + ("op0_size", 2), + ("op1_size", 2), + ("op2_size", 2), + ("op3_size", 2), + ("opsize", 2), + ("size_fix1", 3), + ("size_fix2", 2), + ("instr_width", 1), + ("op0_regty", 3), + ("op1_regty", 3), + ("op2_regty", 3), + ("unused", 6), + ("ign66", 1), +][::-1]) +class InstrFlags(namedtuple("InstrFlags", INSTR_FLAGS_FIELDS)): + def __new__(cls, **kwargs): + init = {**{f: 0 for f in cls._fields}, **kwargs} + return super(InstrFlags, cls).__new__(cls, **init) + def _encode(self): + enc = 0 + for value, size in zip(self, INSTR_FLAGS_SIZES): + enc = enc << size | (value & ((1 << size) - 1)) + return enc ENCODINGS = { "NP": InstrFlags(), @@ -75,11 +66,11 @@ ENCODINGS = { "TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2), "RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3), - "RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4, imm_byte=1), - "RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3, imm_byte=1), + "RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4), + "RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3), "RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3), "VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3), - "VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4, imm_byte=1), + "VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4), "MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3), } @@ -157,7 +148,8 @@ class InstrDesc(NamedTuple): return cls(desc[5], desc[0], operands, frozenset(desc[6:])) def encode(self, ign66): - flags = copy(ENCODINGS[self.encoding]) + flags = ENCODINGS[self.encoding] + extraflags = {} opsz = set(self.OPKIND_SIZES[opkind.size] for opkind in self.operands) @@ -166,37 +158,37 @@ class InstrDesc(NamedTuple): if len(fixed) > 2 or (len(fixed) == 2 and not (1 <= fixed[1] <= 4)): raise Exception("invalid fixed operand sizes: %r"%fixed) sizes = (fixed + [1, 1])[:2] + [-2, -3] # See operand_sizes in decode.c. - flags.size_fix1 = sizes[0] - flags.size_fix2 = sizes[1] - 1 + extraflags["size_fix1"] = sizes[0] + extraflags["size_fix2"] = sizes[1] - 1 for i, opkind in enumerate(self.operands): sz = self.OPKIND_SIZES[opkind.size] reg_type = self.OPKIND_REGTYS.get(opkind.kind, 7) - setattr(flags, "op%d_size"%i, sizes.index(sz)) + extraflags["op%d_size"%i] = sizes.index(sz) if i < 3: - setattr(flags, "op%d_regty"%i, reg_type) + extraflags["op%d_regty"%i] = reg_type elif reg_type not in (7, 2): raise Exception("invalid regty for op 3, must be VEC") # Miscellaneous Flags - if "SIZE_8" in self.flags: flags.opsize = 1 - if "DEF64" in self.flags: flags.opsize = 2 - if "FORCE64" in self.flags: flags.opsize = 3 - if "INSTR_WIDTH" in self.flags: flags.instr_width = 1 - if "LOCK" in self.flags: flags.lock = 1 - if "VSIB" in self.flags: flags.vsib = 1 + if "SIZE_8" in self.flags: extraflags["opsize"] = 1 + if "DEF64" in self.flags: extraflags["opsize"] = 2 + if "FORCE64" in self.flags: extraflags["opsize"] = 3 + if "INSTR_WIDTH" in self.flags: extraflags["instr_width"] = 1 + if "LOCK" in self.flags: extraflags["lock"] = 1 + if "VSIB" in self.flags: extraflags["vsib"] = 1 if "USE66" not in self.flags and (ign66 or "IGN66" in self.flags): - flags.ign66 = 1 + extraflags["ign66"] = 1 if flags.imm_control >= 4: imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM) if ("IMM_8" in self.flags or imm_op.size == 1 or (imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)): - flags.imm_control |= 1 + extraflags["imm_control"] = flags.imm_control | 1 - enc = flags._encode(6) - enc = tuple(int.from_bytes(enc[i:i+2], "little") for i in range(0, 6, 2)) + enc = flags._replace(**extraflags)._encode() + enc = tuple((enc >> i) & 0xffff for i in range(0, 48, 16)) # First 2 bytes are the mnemonic, last 6 bytes are the encoding. return ("FDI_"+self.mnemonic,) + enc @@ -230,19 +222,6 @@ class TrieEntry(NamedTuple): def instr(cls, descidx): return cls(EntryKind.INSTR, (), descidx) - @property - def encode_length(self): - return len(self.items) - def encode(self, encode_item) -> Tuple[Union[int, str]]: - enc_items = (encode_item(item) if item else 0 for item in self.items) - return tuple(enc_items) - - def map(self, map_func): - mapped_items = (map_func(i, v) for i, v in enumerate(self.items)) - return TrieEntry(self.kind, tuple(mapped_items), self.descidx) - def update(self, idx, new_val): - return self.map(lambda i, v: new_val if i == idx else v) - import re opcode_regex = re.compile( r"^(?:(?P(?PVEX\.)?(?PNP|66|F2|F3|NFx)\." + @@ -348,16 +327,19 @@ class Table: for i in range(root_count): self.data["root%d"%i] = TrieEntry.table(EntryKind.TABLE_ROOT) self.descs = [] + self.descs_map = {} self.offsets = {} self.annotations = {} def _update_table(self, name, idx, entry_name, entry_val): + old = self.data[name] # Don't override existing entries. This only happens on invalid input, # e.g. when an opcode is specified twice. - if self.data[name].items[idx]: + if old.items[idx]: raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name)) self.data[entry_name] = entry_val - self.data[name] = self.data[name].update(idx, entry_name) + new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:] + self.data[name] = TrieEntry(old.kind, new_items, None) def add_opcode(self, opcode, instr_encoding, root_idx=0): name = "t{},{}".format(root_idx, format_opcode(opcode)) @@ -376,25 +358,38 @@ class Table: raise Exception("{}, have {}, want {}".format( name, self.data[tn].kind, kind)) - if instr_encoding not in self.descs: + desc_idx = self.descs_map.get(instr_encoding) + if desc_idx is None: + desc_idx = self.descs_map[instr_encoding] = len(self.descs) self.descs.append(instr_encoding) - desc_idx = self.descs.index(instr_encoding) self._update_table(tn, opcode[-1][1], name, TrieEntry.instr(desc_idx)) def deduplicate(self): - synonyms = True - while synonyms: - entries = {} # Mapping from entry to name + parents = defaultdict(set) + for name, entry in self.data.items(): + for child in entry.items: + parents[child].add(name) + + queue = list(self.data.keys()) + entries = {} # Mapping from entry to name + while queue: + # First find new synonyms synonyms = {} # Mapping from name to unique name - for name, entry in self.data.items(): - if entry in entries: - synonyms[name] = entries[entry] + for name in queue: + if self.data[name] in entries: + synonyms[name] = entries[self.data[name]] + del self.data[name] else: - entries[entry] = name - for name, entry in self.data.items(): - self.data[name] = entry.map(lambda _, v: synonyms.get(v, v)) - for key in synonyms: - del self.data[key] + entries[self.data[name]] = name + queue = set.union(set(), *(parents[n] for n in synonyms)) + # Update parents of found synonyms; parents will need to be checked + # again for synonyms in the next iteration. + for name in queue: + entry = self.data[name] + items = tuple(synonyms.get(v, v) for v in entry.items) + self.data[name] = entry._replace(items=items) + for child in items: + parents[child].add(name) def calc_offsets(self): current = 0 @@ -404,24 +399,26 @@ class Table: else: self.annotations[current] = "%s(%d)" % (name, entry.kind.value) self.offsets[name] = current - current += (entry.encode_length + 3) & ~3 + current += (len(entry.items) + 3) & ~3 if current >= 0x8000: raise Exception("maximum table size exceeded: {:x}".format(current)) - def encode_item(self, name): + def _encode_item(self, name): return (self.offsets[name] << 1) | self.data[name].kind.value def compile(self): self.calc_offsets() - ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].encode_length) + ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items) - data = () + data = [0] * (ordered[-1][0] + len(ordered[-1][1].items)) for off, entry in ordered: - data += (0,) * (off - len(data)) + entry.encode(self.encode_item) + for i, item in enumerate(entry.items, start=off): + if item is not None: + data[i] = self._encode_item(item) stats = dict(Counter(entry.kind for entry in self.data.values())) print("%d bytes" % (2*len(data)), stats) - return data, self.annotations, [self.offsets[k] for k in self.roots], self.descs + return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs def bytes_to_table(data, notes): strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data) @@ -586,7 +583,7 @@ def encode_table(entries): descs = "" alt_index = 0 - for mnem, variants in sorted(mnemonics.items()): + for mnem, variants in mnemonics.items(): dedup = [] for variant in variants: if not any(x[:3] == variant[:3] for x in dedup): @@ -601,8 +598,7 @@ def encode_table(entries): for idx, alt, (enc, immsz, tys_i, opc_s) in zip(indices, alt_list, dedup): descs += f"[{idx}] = {{ .enc = ENC_{enc}, .immsz = {immsz}, .tys = {tys_i:#x}, .opc = {opc_s}, .alt = {alt} }},\n" - mnemonics_list = sorted(mnemonics.keys()) - mnemonics_lut = {mnem: mnemonics_list.index(mnem) for mnem in mnemonics_list} + mnemonics_lut = {mnem: i for i, mnem in enumerate(sorted(mnemonics.keys()))} mnemonics_tab = "\n".join("FE_MNEMONIC(%s,%d)"%entry for entry in mnemonics_lut.items()) return mnemonics_tab, descs @@ -624,9 +620,8 @@ if __name__ == "__main__": entries.append((Opcode.parse(opcode_string), InstrDesc.parse(desc))) mnemonics = sorted({desc.mnemonic for _, desc in entries}) - mnemonics_lut = {name: mnemonics.index(name) for name in mnemonics} - decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e for e in mnemonics_lut.items()] + decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e[::-1] for e in enumerate(mnemonics)] args.decode_mnems.write("\n".join(decode_mnems_lines)) modes = [32, 64]