parseinstrs: Improve performance
This commit is contained in:
197
parseinstrs.py
197
parseinstrs.py
@@ -1,54 +1,45 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
import argparse
|
||||
from binascii import unhexlify
|
||||
from collections import OrderedDict, defaultdict, namedtuple, Counter
|
||||
from copy import copy
|
||||
from enum import Enum, IntEnum
|
||||
from itertools import accumulate, product
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
import struct
|
||||
from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString
|
||||
|
||||
def bitstruct(name, fields):
|
||||
names, sizes = zip(*(field.split(":") for field in fields))
|
||||
sizes = tuple(map(int, sizes))
|
||||
offsets = (0,) + tuple(accumulate(sizes))
|
||||
class __class:
|
||||
def __init__(self, **kwargs):
|
||||
for name in names:
|
||||
setattr(self, name, kwargs.get(name, 0))
|
||||
def _encode(self, length):
|
||||
enc = 0
|
||||
for name, size, offset in zip(names, sizes, offsets):
|
||||
enc += (getattr(self, name) & ((1 << size) - 1)) << offset
|
||||
return enc.to_bytes(length, "little")
|
||||
__class.__name__ = name
|
||||
return __class
|
||||
|
||||
InstrFlags = bitstruct("InstrFlags", [
|
||||
"modrm_idx:2",
|
||||
"modreg_idx:2",
|
||||
"vexreg_idx:2",
|
||||
"zeroreg_idx:2",
|
||||
"imm_idx:2",
|
||||
"zeroreg_val:1",
|
||||
"lock:1",
|
||||
"imm_control:3",
|
||||
"vsib:1",
|
||||
"op0_size:2",
|
||||
"op1_size:2",
|
||||
"op2_size:2",
|
||||
"op3_size:2",
|
||||
"opsize:2",
|
||||
"size_fix1:3",
|
||||
"size_fix2:2",
|
||||
"instr_width:1",
|
||||
"op0_regty:3",
|
||||
"op1_regty:3",
|
||||
"op2_regty:3",
|
||||
"_unused:6",
|
||||
"ign66:1",
|
||||
])
|
||||
INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[
|
||||
("modrm_idx", 2),
|
||||
("modreg_idx", 2),
|
||||
("vexreg_idx", 2),
|
||||
("zeroreg_idx", 2),
|
||||
("imm_idx", 2),
|
||||
("zeroreg_val", 1),
|
||||
("lock", 1),
|
||||
("imm_control", 3),
|
||||
("vsib", 1),
|
||||
("op0_size", 2),
|
||||
("op1_size", 2),
|
||||
("op2_size", 2),
|
||||
("op3_size", 2),
|
||||
("opsize", 2),
|
||||
("size_fix1", 3),
|
||||
("size_fix2", 2),
|
||||
("instr_width", 1),
|
||||
("op0_regty", 3),
|
||||
("op1_regty", 3),
|
||||
("op2_regty", 3),
|
||||
("unused", 6),
|
||||
("ign66", 1),
|
||||
][::-1])
|
||||
class InstrFlags(namedtuple("InstrFlags", INSTR_FLAGS_FIELDS)):
|
||||
def __new__(cls, **kwargs):
|
||||
init = {**{f: 0 for f in cls._fields}, **kwargs}
|
||||
return super(InstrFlags, cls).__new__(cls, **init)
|
||||
def _encode(self):
|
||||
enc = 0
|
||||
for value, size in zip(self, INSTR_FLAGS_SIZES):
|
||||
enc = enc << size | (value & ((1 << size) - 1))
|
||||
return enc
|
||||
|
||||
ENCODINGS = {
|
||||
"NP": InstrFlags(),
|
||||
@@ -75,11 +66,11 @@ ENCODINGS = {
|
||||
"TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2),
|
||||
|
||||
"RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3),
|
||||
"RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4, imm_byte=1),
|
||||
"RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3, imm_byte=1),
|
||||
"RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=4),
|
||||
"RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3, imm_idx=3^3, imm_control=3),
|
||||
"RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3),
|
||||
"VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3),
|
||||
"VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4, imm_byte=1),
|
||||
"VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3, imm_control=4),
|
||||
"MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3),
|
||||
}
|
||||
|
||||
@@ -157,7 +148,8 @@ class InstrDesc(NamedTuple):
|
||||
return cls(desc[5], desc[0], operands, frozenset(desc[6:]))
|
||||
|
||||
def encode(self, ign66):
|
||||
flags = copy(ENCODINGS[self.encoding])
|
||||
flags = ENCODINGS[self.encoding]
|
||||
extraflags = {}
|
||||
|
||||
opsz = set(self.OPKIND_SIZES[opkind.size] for opkind in self.operands)
|
||||
|
||||
@@ -166,37 +158,37 @@ class InstrDesc(NamedTuple):
|
||||
if len(fixed) > 2 or (len(fixed) == 2 and not (1 <= fixed[1] <= 4)):
|
||||
raise Exception("invalid fixed operand sizes: %r"%fixed)
|
||||
sizes = (fixed + [1, 1])[:2] + [-2, -3] # See operand_sizes in decode.c.
|
||||
flags.size_fix1 = sizes[0]
|
||||
flags.size_fix2 = sizes[1] - 1
|
||||
extraflags["size_fix1"] = sizes[0]
|
||||
extraflags["size_fix2"] = sizes[1] - 1
|
||||
|
||||
for i, opkind in enumerate(self.operands):
|
||||
sz = self.OPKIND_SIZES[opkind.size]
|
||||
reg_type = self.OPKIND_REGTYS.get(opkind.kind, 7)
|
||||
setattr(flags, "op%d_size"%i, sizes.index(sz))
|
||||
extraflags["op%d_size"%i] = sizes.index(sz)
|
||||
if i < 3:
|
||||
setattr(flags, "op%d_regty"%i, reg_type)
|
||||
extraflags["op%d_regty"%i] = reg_type
|
||||
elif reg_type not in (7, 2):
|
||||
raise Exception("invalid regty for op 3, must be VEC")
|
||||
|
||||
# Miscellaneous Flags
|
||||
if "SIZE_8" in self.flags: flags.opsize = 1
|
||||
if "DEF64" in self.flags: flags.opsize = 2
|
||||
if "FORCE64" in self.flags: flags.opsize = 3
|
||||
if "INSTR_WIDTH" in self.flags: flags.instr_width = 1
|
||||
if "LOCK" in self.flags: flags.lock = 1
|
||||
if "VSIB" in self.flags: flags.vsib = 1
|
||||
if "SIZE_8" in self.flags: extraflags["opsize"] = 1
|
||||
if "DEF64" in self.flags: extraflags["opsize"] = 2
|
||||
if "FORCE64" in self.flags: extraflags["opsize"] = 3
|
||||
if "INSTR_WIDTH" in self.flags: extraflags["instr_width"] = 1
|
||||
if "LOCK" in self.flags: extraflags["lock"] = 1
|
||||
if "VSIB" in self.flags: extraflags["vsib"] = 1
|
||||
|
||||
if "USE66" not in self.flags and (ign66 or "IGN66" in self.flags):
|
||||
flags.ign66 = 1
|
||||
extraflags["ign66"] = 1
|
||||
|
||||
if flags.imm_control >= 4:
|
||||
imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM)
|
||||
if ("IMM_8" in self.flags or imm_op.size == 1 or
|
||||
(imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)):
|
||||
flags.imm_control |= 1
|
||||
extraflags["imm_control"] = flags.imm_control | 1
|
||||
|
||||
enc = flags._encode(6)
|
||||
enc = tuple(int.from_bytes(enc[i:i+2], "little") for i in range(0, 6, 2))
|
||||
enc = flags._replace(**extraflags)._encode()
|
||||
enc = tuple((enc >> i) & 0xffff for i in range(0, 48, 16))
|
||||
# First 2 bytes are the mnemonic, last 6 bytes are the encoding.
|
||||
return ("FDI_"+self.mnemonic,) + enc
|
||||
|
||||
@@ -230,19 +222,6 @@ class TrieEntry(NamedTuple):
|
||||
def instr(cls, descidx):
|
||||
return cls(EntryKind.INSTR, (), descidx)
|
||||
|
||||
@property
|
||||
def encode_length(self):
|
||||
return len(self.items)
|
||||
def encode(self, encode_item) -> Tuple[Union[int, str]]:
|
||||
enc_items = (encode_item(item) if item else 0 for item in self.items)
|
||||
return tuple(enc_items)
|
||||
|
||||
def map(self, map_func):
|
||||
mapped_items = (map_func(i, v) for i, v in enumerate(self.items))
|
||||
return TrieEntry(self.kind, tuple(mapped_items), self.descidx)
|
||||
def update(self, idx, new_val):
|
||||
return self.map(lambda i, v: new_val if i == idx else v)
|
||||
|
||||
import re
|
||||
opcode_regex = re.compile(
|
||||
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
|
||||
@@ -348,16 +327,19 @@ class Table:
|
||||
for i in range(root_count):
|
||||
self.data["root%d"%i] = TrieEntry.table(EntryKind.TABLE_ROOT)
|
||||
self.descs = []
|
||||
self.descs_map = {}
|
||||
self.offsets = {}
|
||||
self.annotations = {}
|
||||
|
||||
def _update_table(self, name, idx, entry_name, entry_val):
|
||||
old = self.data[name]
|
||||
# Don't override existing entries. This only happens on invalid input,
|
||||
# e.g. when an opcode is specified twice.
|
||||
if self.data[name].items[idx]:
|
||||
if old.items[idx]:
|
||||
raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name))
|
||||
self.data[entry_name] = entry_val
|
||||
self.data[name] = self.data[name].update(idx, entry_name)
|
||||
new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:]
|
||||
self.data[name] = TrieEntry(old.kind, new_items, None)
|
||||
|
||||
def add_opcode(self, opcode, instr_encoding, root_idx=0):
|
||||
name = "t{},{}".format(root_idx, format_opcode(opcode))
|
||||
@@ -376,25 +358,38 @@ class Table:
|
||||
raise Exception("{}, have {}, want {}".format(
|
||||
name, self.data[tn].kind, kind))
|
||||
|
||||
if instr_encoding not in self.descs:
|
||||
desc_idx = self.descs_map.get(instr_encoding)
|
||||
if desc_idx is None:
|
||||
desc_idx = self.descs_map[instr_encoding] = len(self.descs)
|
||||
self.descs.append(instr_encoding)
|
||||
desc_idx = self.descs.index(instr_encoding)
|
||||
self._update_table(tn, opcode[-1][1], name, TrieEntry.instr(desc_idx))
|
||||
|
||||
def deduplicate(self):
|
||||
synonyms = True
|
||||
while synonyms:
|
||||
entries = {} # Mapping from entry to name
|
||||
parents = defaultdict(set)
|
||||
for name, entry in self.data.items():
|
||||
for child in entry.items:
|
||||
parents[child].add(name)
|
||||
|
||||
queue = list(self.data.keys())
|
||||
entries = {} # Mapping from entry to name
|
||||
while queue:
|
||||
# First find new synonyms
|
||||
synonyms = {} # Mapping from name to unique name
|
||||
for name, entry in self.data.items():
|
||||
if entry in entries:
|
||||
synonyms[name] = entries[entry]
|
||||
for name in queue:
|
||||
if self.data[name] in entries:
|
||||
synonyms[name] = entries[self.data[name]]
|
||||
del self.data[name]
|
||||
else:
|
||||
entries[entry] = name
|
||||
for name, entry in self.data.items():
|
||||
self.data[name] = entry.map(lambda _, v: synonyms.get(v, v))
|
||||
for key in synonyms:
|
||||
del self.data[key]
|
||||
entries[self.data[name]] = name
|
||||
queue = set.union(set(), *(parents[n] for n in synonyms))
|
||||
# Update parents of found synonyms; parents will need to be checked
|
||||
# again for synonyms in the next iteration.
|
||||
for name in queue:
|
||||
entry = self.data[name]
|
||||
items = tuple(synonyms.get(v, v) for v in entry.items)
|
||||
self.data[name] = entry._replace(items=items)
|
||||
for child in items:
|
||||
parents[child].add(name)
|
||||
|
||||
def calc_offsets(self):
|
||||
current = 0
|
||||
@@ -404,24 +399,26 @@ class Table:
|
||||
else:
|
||||
self.annotations[current] = "%s(%d)" % (name, entry.kind.value)
|
||||
self.offsets[name] = current
|
||||
current += (entry.encode_length + 3) & ~3
|
||||
current += (len(entry.items) + 3) & ~3
|
||||
if current >= 0x8000:
|
||||
raise Exception("maximum table size exceeded: {:x}".format(current))
|
||||
|
||||
def encode_item(self, name):
|
||||
def _encode_item(self, name):
|
||||
return (self.offsets[name] << 1) | self.data[name].kind.value
|
||||
|
||||
def compile(self):
|
||||
self.calc_offsets()
|
||||
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].encode_length)
|
||||
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items)
|
||||
|
||||
data = ()
|
||||
data = [0] * (ordered[-1][0] + len(ordered[-1][1].items))
|
||||
for off, entry in ordered:
|
||||
data += (0,) * (off - len(data)) + entry.encode(self.encode_item)
|
||||
for i, item in enumerate(entry.items, start=off):
|
||||
if item is not None:
|
||||
data[i] = self._encode_item(item)
|
||||
|
||||
stats = dict(Counter(entry.kind for entry in self.data.values()))
|
||||
print("%d bytes" % (2*len(data)), stats)
|
||||
return data, self.annotations, [self.offsets[k] for k in self.roots], self.descs
|
||||
return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs
|
||||
|
||||
def bytes_to_table(data, notes):
|
||||
strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data)
|
||||
@@ -586,7 +583,7 @@ def encode_table(entries):
|
||||
|
||||
descs = ""
|
||||
alt_index = 0
|
||||
for mnem, variants in sorted(mnemonics.items()):
|
||||
for mnem, variants in mnemonics.items():
|
||||
dedup = []
|
||||
for variant in variants:
|
||||
if not any(x[:3] == variant[:3] for x in dedup):
|
||||
@@ -601,8 +598,7 @@ def encode_table(entries):
|
||||
for idx, alt, (enc, immsz, tys_i, opc_s) in zip(indices, alt_list, dedup):
|
||||
descs += f"[{idx}] = {{ .enc = ENC_{enc}, .immsz = {immsz}, .tys = {tys_i:#x}, .opc = {opc_s}, .alt = {alt} }},\n"
|
||||
|
||||
mnemonics_list = sorted(mnemonics.keys())
|
||||
mnemonics_lut = {mnem: mnemonics_list.index(mnem) for mnem in mnemonics_list}
|
||||
mnemonics_lut = {mnem: i for i, mnem in enumerate(sorted(mnemonics.keys()))}
|
||||
mnemonics_tab = "\n".join("FE_MNEMONIC(%s,%d)"%entry for entry in mnemonics_lut.items())
|
||||
return mnemonics_tab, descs
|
||||
|
||||
@@ -624,9 +620,8 @@ if __name__ == "__main__":
|
||||
entries.append((Opcode.parse(opcode_string), InstrDesc.parse(desc)))
|
||||
|
||||
mnemonics = sorted({desc.mnemonic for _, desc in entries})
|
||||
mnemonics_lut = {name: mnemonics.index(name) for name in mnemonics}
|
||||
|
||||
decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e for e in mnemonics_lut.items()]
|
||||
decode_mnems_lines = ["FD_MNEMONIC(%s,%d)"%e[::-1] for e in enumerate(mnemonics)]
|
||||
args.decode_mnems.write("\n".join(decode_mnems_lines))
|
||||
|
||||
modes = [32, 64]
|
||||
|
||||
Reference in New Issue
Block a user