parseinstrs: Improve performance

This commit is contained in:
Alexis Engelke
2021-01-03 11:32:28 +01:00
parent 5a77c0e6eb
commit 3a3a284f6f

View File

@@ -1,54 +1,45 @@
#!/usr/bin/python3
import argparse
from binascii import unhexlify
from collections import OrderedDict, defaultdict, namedtuple, Counter
from copy import copy
from enum import Enum, IntEnum
from itertools import accumulate, product
from enum import Enum
from itertools import product
import struct
from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString
def bitstruct(name, fields):
names, sizes = zip(*(field.split(":") for field in fields))
sizes = tuple(map(int, sizes))
offsets = (0,) + tuple(accumulate(sizes))
class __class:
def __init__(self, **kwargs):
for name in names:
setattr(self, name, kwargs.get(name, 0))
def _encode(self, length):
enc = 0
for name, size, offset in zip(names, sizes, offsets):
enc += (getattr(self, name) & ((1 << size) - 1)) << offset
return enc.to_bytes(length, "little")
__class.__name__ = name
return __class
InstrFlags = bitstruct("InstrFlags", [
"modrm_idx:2",
"modreg_idx:2",
"vexreg_idx:2",
"zeroreg_idx:2",
"imm_idx:2",
"zeroreg_val:1",
"lock:1",
"imm_control:3",
"vsib:1",
"op0_size:2",
"op1_size:2",
"op2_size:2",
"op3_size:2",
"opsize:2",
"size_fix1:3",
"size_fix2:2",
"instr_width:1",
"op0_regty:3",
"op1_regty:3",
"op2_regty:3",
"_unused:6",
"ign66:1",
])
INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[
("modrm_idx", 2),
("modreg_idx", 2),
("vexreg_idx", 2),
("zeroreg_idx", 2),
("imm_idx", 2),
("zeroreg_val", 1),
("lock", 1),
("imm_control", 3),
("vsib", 1),
("op0_size", 2),
("op1_size", 2),
("op2_size", 2),
("op3_size", 2),
("opsize", 2),
("size_fix1", 3),
("size_fix2", 2),
("instr_width", 1),
("op0_regty", 3),
("op1_regty", 3),
("op2_regty", 3),
("unused", 6),
("ign66", 1),
][::-1])
class InstrFlags(namedtuple("InstrFlags", INSTR_FLAGS_FIELDS)):
def __new__(cls, **kwargs):
init = {**{f: 0 for f in cls._fields}, **kwargs}
return super(InstrFlags, cls).__new__(cls, **init)
def _encode(self):
enc = 0
for value, size in zip(self, INSTR_FLAGS_SIZES):
enc = enc << size | (value & ((1 << size) - 1))
return enc
ENCODINGS = {
"NP": InstrFlags(),
@@ -75,11 +66,11 @@ ENCODINGS = {
"TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2),
"RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3),
"RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4, imm_byte=1),
"RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3, imm_byte=1),
"RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4),
"RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3),
"RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3),
"VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3),
"VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4, imm_byte=1),
"VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4),
"MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3),
}
@@ -157,7 +148,8 @@ class InstrDesc(NamedTuple):
return cls(desc[5], desc[0], operands, frozenset(desc[6:]))
def encode(self, ign66):
flags = copy(ENCODINGS[self.encoding])
flags = ENCODINGS[self.encoding]
extraflags = {}
opsz = set(self.OPKIND_SIZES[opkind.size] for opkind in self.operands)
@@ -166,37 +158,37 @@ class InstrDesc(NamedTuple):
if len(fixed) > 2 or (len(fixed) == 2 and not (1 <= fixed[1] <= 4)):
raise Exception("invalid fixed operand sizes: %r"%fixed)
sizes = (fixed + [1, 1])[:2] + [-2, -3] # See operand_sizes in decode.c.
flags.size_fix1 = sizes[0]
flags.size_fix2 = sizes[1] - 1
extraflags["size_fix1"] = sizes[0]
extraflags["size_fix2"] = sizes[1] - 1
for i, opkind in enumerate(self.operands):
sz = self.OPKIND_SIZES[opkind.size]
reg_type = self.OPKIND_REGTYS.get(opkind.kind, 7)
setattr(flags, "op%d_size"%i, sizes.index(sz))
extraflags["op%d_size"%i] = sizes.index(sz)
if i < 3:
setattr(flags, "op%d_regty"%i, reg_type)
extraflags["op%d_regty"%i] = reg_type
elif reg_type not in (7, 2):
raise Exception("invalid regty for op 3, must be VEC")
# Miscellaneous Flags
if "SIZE_8" in self.flags: flags.opsize = 1
if "DEF64" in self.flags: flags.opsize = 2
if "FORCE64" in self.flags: flags.opsize = 3
if "INSTR_WIDTH" in self.flags: flags.instr_width = 1
if "LOCK" in self.flags: flags.lock = 1
if "VSIB" in self.flags: flags.vsib = 1
if "SIZE_8" in self.flags: extraflags["opsize"] = 1
if "DEF64" in self.flags: extraflags["opsize"] = 2
if "FORCE64" in self.flags: extraflags["opsize"] = 3
if "INSTR_WIDTH" in self.flags: extraflags["instr_width"] = 1
if "LOCK" in self.flags: extraflags["lock"] = 1
if "VSIB" in self.flags: extraflags["vsib"] = 1
if "USE66" not in self.flags and (ign66 or "IGN66" in self.flags):
flags.ign66 = 1
extraflags["ign66"] = 1
if flags.imm_control >= 4:
imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM)
if ("IMM_8" in self.flags or imm_op.size == 1 or
(imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)):
flags.imm_control |= 1
extraflags["imm_control"] = flags.imm_control | 1
enc = flags._encode(6)
enc = tuple(int.from_bytes(enc[i:i+2], "little") for i in range(0, 6, 2))
enc = flags._replace(**extraflags)._encode()
enc = tuple((enc >> i) & 0xffff for i in range(0, 48, 16))
# First 2 bytes are the mnemonic, last 6 bytes are the encoding.
return ("FDI_"+self.mnemonic,) + enc
@@ -230,19 +222,6 @@ class TrieEntry(NamedTuple):
def instr(cls, descidx):
return cls(EntryKind.INSTR, (), descidx)
@property
def encode_length(self):
return len(self.items)
def encode(self, encode_item) -> Tuple[Union[int, str]]:
enc_items = (encode_item(item) if item else 0 for item in self.items)
return tuple(enc_items)
def map(self, map_func):
mapped_items = (map_func(i, v) for i, v in enumerate(self.items))
return TrieEntry(self.kind, tuple(mapped_items), self.descidx)
def update(self, idx, new_val):
return self.map(lambda i, v: new_val if i == idx else v)
import re
opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
@@ -348,16 +327,19 @@ class Table:
for i in range(root_count):
self.data["root%d"%i] = TrieEntry.table(EntryKind.TABLE_ROOT)
self.descs = []
self.descs_map = {}
self.offsets = {}
self.annotations = {}
def _update_table(self, name, idx, entry_name, entry_val):
old = self.data[name]
# Don't override existing entries. This only happens on invalid input,
# e.g. when an opcode is specified twice.
if self.data[name].items[idx]:
if old.items[idx]:
raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name))
self.data[entry_name] = entry_val
self.data[name] = self.data[name].update(idx, entry_name)
new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:]
self.data[name] = TrieEntry(old.kind, new_items, None)
def add_opcode(self, opcode, instr_encoding, root_idx=0):
name = "t{},{}".format(root_idx, format_opcode(opcode))
@@ -376,25 +358,38 @@ class Table:
raise Exception("{}, have {}, want {}".format(
name, self.data[tn].kind, kind))
if instr_encoding not in self.descs:
desc_idx = self.descs_map.get(instr_encoding)
if desc_idx is None:
desc_idx = self.descs_map[instr_encoding] = len(self.descs)
self.descs.append(instr_encoding)
desc_idx = self.descs.index(instr_encoding)
self._update_table(tn, opcode[-1][1], name, TrieEntry.instr(desc_idx))
def deduplicate(self):
synonyms = True
while synonyms:
entries = {} # Mapping from entry to name
parents = defaultdict(set)
for name, entry in self.data.items():
for child in entry.items:
parents[child].add(name)
queue = list(self.data.keys())
entries = {} # Mapping from entry to name
while queue:
# First find new synonyms
synonyms = {} # Mapping from name to unique name
for name, entry in self.data.items():
if entry in entries:
synonyms[name] = entries[entry]
for name in queue:
if self.data[name] in entries:
synonyms[name] = entries[self.data[name]]
del self.data[name]
else:
entries[entry] = name
for name, entry in self.data.items():
self.data[name] = entry.map(lambda _, v: synonyms.get(v, v))
for key in synonyms:
del self.data[key]
entries[self.data[name]] = name
queue = set.union(set(), *(parents[n] for n in synonyms))
# Update parents of found synonyms; parents will need to be checked
# again for synonyms in the next iteration.
for name in queue:
entry = self.data[name]
items = tuple(synonyms.get(v, v) for v in entry.items)
self.data[name] = entry._replace(items=items)
for child in items:
parents[child].add(name)
def calc_offsets(self):
current = 0
@@ -404,24 +399,26 @@ class Table:
else:
self.annotations[current] = "%s(%d)" % (name, entry.kind.value)
self.offsets[name] = current
current += (entry.encode_length + 3) & ~3
current += (len(entry.items) + 3) & ~3
if current >= 0x8000:
raise Exception("maximum table size exceeded: {:x}".format(current))
def encode_item(self, name):
def _encode_item(self, name):
return (self.offsets[name] << 1) | self.data[name].kind.value
def compile(self):
self.calc_offsets()
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].encode_length)
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items)
data = ()
data = [0] * (ordered[-1][0] + len(ordered[-1][1].items))
for off, entry in ordered:
data += (0,) * (off - len(data)) + entry.encode(self.encode_item)
for i, item in enumerate(entry.items, start=off):
if item is not None:
data[i] = self._encode_item(item)
stats = dict(Counter(entry.kind for entry in self.data.values()))
print("%d bytes" % (2*len(data)), stats)
return data, self.annotations, [self.offsets[k] for k in self.roots], self.descs
return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs
def bytes_to_table(data, notes):
strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data)
@@ -586,7 +583,7 @@ def encode_table(entries):
descs = ""
alt_index = 0
for mnem, variants in sorted(mnemonics.items()):
for mnem, variants in mnemonics.items():
dedup = []
for variant in variants:
if not any(x[:3] == variant[:3] for x in dedup):
@@ -601,8 +598,7 @@ def encode_table(entries):
for idx, alt, (enc, immsz, tys_i, opc_s) in zip(indices, alt_list, dedup):
descs += f"[{idx}] = {{ .enc = ENC_{enc}, .immsz = {immsz}, .tys = {tys_i:#x}, .opc = {opc_s}, .alt = {alt} }},\n"
mnemonics_list = sorted(mnemonics.keys())
mnemonics_lut = {mnem: mnemonics_list.index(mnem) for mnem in mnemonics_list}
mnemonics_lut = {mnem: i for i, mnem in enumerate(sorted(mnemonics.keys()))}
mnemonics_tab = "\n".join("FE_MNEMONIC(%s,%d)"%entry for entry in mnemonics_lut.items())
return mnemonics_tab, descs
@@ -624,9 +620,8 @@ if __name__ == "__main__":
entries.append((Opcode.parse(opcode_string), InstrDesc.parse(desc)))
mnemonics = sorted({desc.mnemonic for _, desc in entries})
mnemonics_lut = {name: mnemonics.index(name) for name in mnemonics}
decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e for e in mnemonics_lut.items()]
decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e[::-1] for e in enumerate(mnemonics)]
args.decode_mnems.write("\n".join(decode_mnems_lines))
modes = [32, 64]