parseinstrs: Improve performance

This commit is contained in:
Alexis Engelke
2021-01-03 11:32:28 +01:00
parent 5a77c0e6eb
commit 3a3a284f6f

View File

@@ -1,54 +1,45 @@
#!/usr/bin/python3 #!/usr/bin/python3
import argparse import argparse
from binascii import unhexlify
from collections import OrderedDict, defaultdict, namedtuple, Counter from collections import OrderedDict, defaultdict, namedtuple, Counter
from copy import copy from enum import Enum
from enum import Enum, IntEnum from itertools import product
from itertools import accumulate, product
import struct import struct
from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString
def bitstruct(name, fields): INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[
names, sizes = zip(*(field.split(":") for field in fields)) ("modrm_idx", 2),
sizes = tuple(map(int, sizes)) ("modreg_idx", 2),
offsets = (0,) + tuple(accumulate(sizes)) ("vexreg_idx", 2),
class __class: ("zeroreg_idx", 2),
def __init__(self, **kwargs): ("imm_idx", 2),
for name in names: ("zeroreg_val", 1),
setattr(self, name, kwargs.get(name, 0)) ("lock", 1),
def _encode(self, length): ("imm_control", 3),
("vsib", 1),
("op0_size", 2),
("op1_size", 2),
("op2_size", 2),
("op3_size", 2),
("opsize", 2),
("size_fix1", 3),
("size_fix2", 2),
("instr_width", 1),
("op0_regty", 3),
("op1_regty", 3),
("op2_regty", 3),
("unused", 6),
("ign66", 1),
][::-1])
class InstrFlags(namedtuple("InstrFlags", INSTR_FLAGS_FIELDS)):
def __new__(cls, **kwargs):
init = {**{f: 0 for f in cls._fields}, **kwargs}
return super(InstrFlags, cls).__new__(cls, **init)
def _encode(self):
enc = 0 enc = 0
for name, size, offset in zip(names, sizes, offsets): for value, size in zip(self, INSTR_FLAGS_SIZES):
enc += (getattr(self, name) & ((1 << size) - 1)) << offset enc = enc << size | (value & ((1 << size) - 1))
return enc.to_bytes(length, "little") return enc
__class.__name__ = name
return __class
InstrFlags = bitstruct("InstrFlags", [
"modrm_idx:2",
"modreg_idx:2",
"vexreg_idx:2",
"zeroreg_idx:2",
"imm_idx:2",
"zeroreg_val:1",
"lock:1",
"imm_control:3",
"vsib:1",
"op0_size:2",
"op1_size:2",
"op2_size:2",
"op3_size:2",
"opsize:2",
"size_fix1:3",
"size_fix2:2",
"instr_width:1",
"op0_regty:3",
"op1_regty:3",
"op2_regty:3",
"_unused:6",
"ign66:1",
])
ENCODINGS = { ENCODINGS = {
"NP": InstrFlags(), "NP": InstrFlags(),
@@ -75,11 +66,11 @@ ENCODINGS = {
"TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2), "TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2),
"RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3), "RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3),
"RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4, imm_byte=1), "RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4),
"RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3, imm_byte=1), "RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3),
"RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3), "RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3),
"VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3), "VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3),
"VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4, imm_byte=1), "VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4),
"MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3), "MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3),
} }
@@ -157,7 +148,8 @@ class InstrDesc(NamedTuple):
return cls(desc[5], desc[0], operands, frozenset(desc[6:])) return cls(desc[5], desc[0], operands, frozenset(desc[6:]))
def encode(self, ign66): def encode(self, ign66):
flags = copy(ENCODINGS[self.encoding]) flags = ENCODINGS[self.encoding]
extraflags = {}
opsz = set(self.OPKIND_SIZES[opkind.size] for opkind in self.operands) opsz = set(self.OPKIND_SIZES[opkind.size] for opkind in self.operands)
@@ -166,37 +158,37 @@ class InstrDesc(NamedTuple):
if len(fixed) > 2 or (len(fixed) == 2 and not (1 <= fixed[1] <= 4)): if len(fixed) > 2 or (len(fixed) == 2 and not (1 <= fixed[1] <= 4)):
raise Exception("invalid fixed operand sizes: %r"%fixed) raise Exception("invalid fixed operand sizes: %r"%fixed)
sizes = (fixed + [1, 1])[:2] + [-2, -3] # See operand_sizes in decode.c. sizes = (fixed + [1, 1])[:2] + [-2, -3] # See operand_sizes in decode.c.
flags.size_fix1 = sizes[0] extraflags["size_fix1"] = sizes[0]
flags.size_fix2 = sizes[1] - 1 extraflags["size_fix2"] = sizes[1] - 1
for i, opkind in enumerate(self.operands): for i, opkind in enumerate(self.operands):
sz = self.OPKIND_SIZES[opkind.size] sz = self.OPKIND_SIZES[opkind.size]
reg_type = self.OPKIND_REGTYS.get(opkind.kind, 7) reg_type = self.OPKIND_REGTYS.get(opkind.kind, 7)
setattr(flags, "op%d_size"%i, sizes.index(sz)) extraflags["op%d_size"%i] = sizes.index(sz)
if i < 3: if i < 3:
setattr(flags, "op%d_regty"%i, reg_type) extraflags["op%d_regty"%i] = reg_type
elif reg_type not in (7, 2): elif reg_type not in (7, 2):
raise Exception("invalid regty for op 3, must be VEC") raise Exception("invalid regty for op 3, must be VEC")
# Miscellaneous Flags # Miscellaneous Flags
if "SIZE_8" in self.flags: flags.opsize = 1 if "SIZE_8" in self.flags: extraflags["opsize"] = 1
if "DEF64" in self.flags: flags.opsize = 2 if "DEF64" in self.flags: extraflags["opsize"] = 2
if "FORCE64" in self.flags: flags.opsize = 3 if "FORCE64" in self.flags: extraflags["opsize"] = 3
if "INSTR_WIDTH" in self.flags: flags.instr_width = 1 if "INSTR_WIDTH" in self.flags: extraflags["instr_width"] = 1
if "LOCK" in self.flags: flags.lock = 1 if "LOCK" in self.flags: extraflags["lock"] = 1
if "VSIB" in self.flags: flags.vsib = 1 if "VSIB" in self.flags: extraflags["vsib"] = 1
if "USE66" not in self.flags and (ign66 or "IGN66" in self.flags): if "USE66" not in self.flags and (ign66 or "IGN66" in self.flags):
flags.ign66 = 1 extraflags["ign66"] = 1
if flags.imm_control >= 4: if flags.imm_control >= 4:
imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM) imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM)
if ("IMM_8" in self.flags or imm_op.size == 1 or if ("IMM_8" in self.flags or imm_op.size == 1 or
(imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)): (imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)):
flags.imm_control |= 1 extraflags["imm_control"] = flags.imm_control | 1
enc = flags._encode(6) enc = flags._replace(**extraflags)._encode()
enc = tuple(int.from_bytes(enc[i:i+2], "little") for i in range(0, 6, 2)) enc = tuple((enc >> i) & 0xffff for i in range(0, 48, 16))
# First 2 bytes are the mnemonic, last 6 bytes are the encoding. # First 2 bytes are the mnemonic, last 6 bytes are the encoding.
return ("FDI_"+self.mnemonic,) + enc return ("FDI_"+self.mnemonic,) + enc
@@ -230,19 +222,6 @@ class TrieEntry(NamedTuple):
def instr(cls, descidx): def instr(cls, descidx):
return cls(EntryKind.INSTR, (), descidx) return cls(EntryKind.INSTR, (), descidx)
@property
def encode_length(self):
return len(self.items)
def encode(self, encode_item) -> Tuple[Union[int, str]]:
enc_items = (encode_item(item) if item else 0 for item in self.items)
return tuple(enc_items)
def map(self, map_func):
mapped_items = (map_func(i, v) for i, v in enumerate(self.items))
return TrieEntry(self.kind, tuple(mapped_items), self.descidx)
def update(self, idx, new_val):
return self.map(lambda i, v: new_val if i == idx else v)
import re import re
opcode_regex = re.compile( opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." + r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
@@ -348,16 +327,19 @@ class Table:
for i in range(root_count): for i in range(root_count):
self.data["root%d"%i] = TrieEntry.table(EntryKind.TABLE_ROOT) self.data["root%d"%i] = TrieEntry.table(EntryKind.TABLE_ROOT)
self.descs = [] self.descs = []
self.descs_map = {}
self.offsets = {} self.offsets = {}
self.annotations = {} self.annotations = {}
def _update_table(self, name, idx, entry_name, entry_val): def _update_table(self, name, idx, entry_name, entry_val):
old = self.data[name]
# Don't override existing entries. This only happens on invalid input, # Don't override existing entries. This only happens on invalid input,
# e.g. when an opcode is specified twice. # e.g. when an opcode is specified twice.
if self.data[name].items[idx]: if old.items[idx]:
raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name)) raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name))
self.data[entry_name] = entry_val self.data[entry_name] = entry_val
self.data[name] = self.data[name].update(idx, entry_name) new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:]
self.data[name] = TrieEntry(old.kind, new_items, None)
def add_opcode(self, opcode, instr_encoding, root_idx=0): def add_opcode(self, opcode, instr_encoding, root_idx=0):
name = "t{},{}".format(root_idx, format_opcode(opcode)) name = "t{},{}".format(root_idx, format_opcode(opcode))
@@ -376,25 +358,38 @@ class Table:
raise Exception("{}, have {}, want {}".format( raise Exception("{}, have {}, want {}".format(
name, self.data[tn].kind, kind)) name, self.data[tn].kind, kind))
if instr_encoding not in self.descs: desc_idx = self.descs_map.get(instr_encoding)
if desc_idx is None:
desc_idx = self.descs_map[instr_encoding] = len(self.descs)
self.descs.append(instr_encoding) self.descs.append(instr_encoding)
desc_idx = self.descs.index(instr_encoding)
self._update_table(tn, opcode[-1][1], name, TrieEntry.instr(desc_idx)) self._update_table(tn, opcode[-1][1], name, TrieEntry.instr(desc_idx))
def deduplicate(self): def deduplicate(self):
synonyms = True parents = defaultdict(set)
while synonyms: for name, entry in self.data.items():
for child in entry.items:
parents[child].add(name)
queue = list(self.data.keys())
entries = {} # Mapping from entry to name entries = {} # Mapping from entry to name
while queue:
# First find new synonyms
synonyms = {} # Mapping from name to unique name synonyms = {} # Mapping from name to unique name
for name, entry in self.data.items(): for name in queue:
if entry in entries: if self.data[name] in entries:
synonyms[name] = entries[entry] synonyms[name] = entries[self.data[name]]
del self.data[name]
else: else:
entries[entry] = name entries[self.data[name]] = name
for name, entry in self.data.items(): queue = set.union(set(), *(parents[n] for n in synonyms))
self.data[name] = entry.map(lambda _, v: synonyms.get(v, v)) # Update parents of found synonyms; parents will need to be checked
for key in synonyms: # again for synonyms in the next iteration.
del self.data[key] for name in queue:
entry = self.data[name]
items = tuple(synonyms.get(v, v) for v in entry.items)
self.data[name] = entry._replace(items=items)
for child in items:
parents[child].add(name)
def calc_offsets(self): def calc_offsets(self):
current = 0 current = 0
@@ -404,24 +399,26 @@ class Table:
else: else:
self.annotations[current] = "%s(%d)" % (name, entry.kind.value) self.annotations[current] = "%s(%d)" % (name, entry.kind.value)
self.offsets[name] = current self.offsets[name] = current
current += (entry.encode_length + 3) & ~3 current += (len(entry.items) + 3) & ~3
if current >= 0x8000: if current >= 0x8000:
raise Exception("maximum table size exceeded: {:x}".format(current)) raise Exception("maximum table size exceeded: {:x}".format(current))
def encode_item(self, name): def _encode_item(self, name):
return (self.offsets[name] << 1) | self.data[name].kind.value return (self.offsets[name] << 1) | self.data[name].kind.value
def compile(self): def compile(self):
self.calc_offsets() self.calc_offsets()
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].encode_length) ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items)
data = () data = [0] * (ordered[-1][0] + len(ordered[-1][1].items))
for off, entry in ordered: for off, entry in ordered:
data += (0,) * (off - len(data)) + entry.encode(self.encode_item) for i, item in enumerate(entry.items, start=off):
if item is not None:
data[i] = self._encode_item(item)
stats = dict(Counter(entry.kind for entry in self.data.values())) stats = dict(Counter(entry.kind for entry in self.data.values()))
print("%d bytes" % (2*len(data)), stats) print("%d bytes" % (2*len(data)), stats)
return data, self.annotations, [self.offsets[k] for k in self.roots], self.descs return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs
def bytes_to_table(data, notes): def bytes_to_table(data, notes):
strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data) strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data)
@@ -586,7 +583,7 @@ def encode_table(entries):
descs = "" descs = ""
alt_index = 0 alt_index = 0
for mnem, variants in sorted(mnemonics.items()): for mnem, variants in mnemonics.items():
dedup = [] dedup = []
for variant in variants: for variant in variants:
if not any(x[:3] == variant[:3] for x in dedup): if not any(x[:3] == variant[:3] for x in dedup):
@@ -601,8 +598,7 @@ def encode_table(entries):
for idx, alt, (enc, immsz, tys_i, opc_s) in zip(indices, alt_list, dedup): for idx, alt, (enc, immsz, tys_i, opc_s) in zip(indices, alt_list, dedup):
descs += f"[{idx}] = {{ .enc = ENC_{enc}, .immsz = {immsz}, .tys = {tys_i:#x}, .opc = {opc_s}, .alt = {alt} }},\n" descs += f"[{idx}] = {{ .enc = ENC_{enc}, .immsz = {immsz}, .tys = {tys_i:#x}, .opc = {opc_s}, .alt = {alt} }},\n"
mnemonics_list = sorted(mnemonics.keys()) mnemonics_lut = {mnem: i for i, mnem in enumerate(sorted(mnemonics.keys()))}
mnemonics_lut = {mnem: mnemonics_list.index(mnem) for mnem in mnemonics_list}
mnemonics_tab = "\n".join("FE_MNEMONIC(%s,%d)"%entry for entry in mnemonics_lut.items()) mnemonics_tab = "\n".join("FE_MNEMONIC(%s,%d)"%entry for entry in mnemonics_lut.items())
return mnemonics_tab, descs return mnemonics_tab, descs
@@ -624,9 +620,8 @@ if __name__ == "__main__":
entries.append((Opcode.parse(opcode_string), InstrDesc.parse(desc))) entries.append((Opcode.parse(opcode_string), InstrDesc.parse(desc)))
mnemonics = sorted({desc.mnemonic for _, desc in entries}) mnemonics = sorted({desc.mnemonic for _, desc in entries})
mnemonics_lut = {name: mnemonics.index(name) for name in mnemonics}
decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e for e in mnemonics_lut.items()] decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e[::-1] for e in enumerate(mnemonics)]
args.decode_mnems.write("\n".join(decode_mnems_lines)) args.decode_mnems.write("\n".join(decode_mnems_lines))
modes = [32, 64] modes = [32, 64]