parseinstrs: Simplify trie implementation

This commit is contained in:
Alexis Engelke
2021-01-23 12:32:22 +01:00
parent 43910a6227
commit 13a2456458

View File

@@ -210,32 +210,16 @@ class InstrDesc(NamedTuple):
class EntryKind(Enum): class EntryKind(Enum):
NONE = 0 NONE = 0
INSTR = 1 INSTR = 1
WEAKINSTR = 9
TABLE256 = 2 TABLE256 = 2
TABLE16 = 3 TABLE16 = 3
TABLE8E = 4 TABLE8E = 4
TABLE_PREFIX = 5 TABLE_PREFIX = 5
TABLE_VEX = 6 TABLE_VEX = 6
TABLE_ROOT = -1 TABLE_ROOT = -1
@property
class TrieEntry(NamedTuple): def is_instr(self):
kind: EntryKind return self == EntryKind.INSTR or self == EntryKind.WEAKINSTR
items: Tuple[Optional[str]]
descidx: Optional[int]
TABLE_LENGTH = {
EntryKind.TABLE256: 256,
EntryKind.TABLE16: 16,
EntryKind.TABLE8E: 8,
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE_VEX: 4,
EntryKind.TABLE_ROOT: 8,
}
@classmethod
def table(cls, kind):
return cls(kind, (None,) * cls.TABLE_LENGTH[kind], ())
@classmethod
def instr(cls, descidx):
return cls(EntryKind.INSTR, (), descidx)
opcode_regex = re.compile( opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." + r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
@@ -307,127 +291,118 @@ class Opcode(NamedTuple):
vexl = {"0": [0], "1": [1<<1], "IG": [0, 1<<1]}[self.vexl or "IG"] vexl = {"0": [0], "1": [1<<1], "IG": [0, 1<<1]}[self.vexl or "IG"]
entries = list(map(sum, product(rexw, vexl))) entries = list(map(sum, product(rexw, vexl)))
opcode.append((EntryKind.TABLE_VEX, entries)) opcode.append((EntryKind.TABLE_VEX, entries))
return opcode
kinds, values = zip(*opcode) class Trie:
return [tuple(zip(kinds, prod)) for prod in product(*values)] KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.TABLE256,
EntryKind.TABLE_PREFIX, EntryKind.TABLE16,
EntryKind.TABLE8E, EntryKind.TABLE_VEX)
TABLE_LENGTH = {
EntryKind.TABLE_ROOT: 8,
EntryKind.TABLE256: 256,
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE16: 16,
EntryKind.TABLE8E: 8,
EntryKind.TABLE_VEX: 4,
}
class Table: def __init__(self, root_count):
def __init__(self, root_count=1): self.trie = []
self.data = OrderedDict() self.trie.append([None] * root_count)
self.roots = [(i,) for i in range(root_count)] self.kindmap = defaultdict(list)
for i in range(root_count):
self.data[i,] = TrieEntry.table(EntryKind.TABLE_ROOT)
self.descs = []
self.descs_map = {}
self.offsets = {}
self.annotations = {}
def _update_table(self, name, idx, entry_name, entry_val): def _add_table(self, kind):
old = self.data[name] self.trie.append([None] * self.TABLE_LENGTH[kind])
# Don't override existing entries. This only happens on invalid input, self.kindmap[kind].append(len(self.trie) - 1)
# e.g. when an opcode is specified twice. return len(self.trie) - 1
if old.items[idx]:
raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name))
self.data[entry_name] = entry_val
new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:]
self.data[name] = TrieEntry(old.kind, new_items, None)
def _walk_opcode(self, opcode, root_idx): def _transform_opcode(self, opcode):
tn = root_idx, vals = {k: v for k, v in opcode.for_trie()}
for i in range(len(opcode) - 1): return [vals.get(kind) for kind in self.KIND_ORDER]
# kind is the table kind that we want to point to in the _next_.
kind, byte = opcode[i+1][0], opcode[i][1]
# Retain prev_tn name so that we can update it.
prev_tn, tn = tn, self.data[tn].items[byte]
if tn is None:
tn = prev_tn + (byte,)
self._update_table(prev_tn, byte, tn, TrieEntry.table(kind))
if self.data[tn].kind != kind: def _clone(self, elem):
raise Exception("{}, have {}, want {}".format( if not elem or elem[0].is_instr:
opcode, self.data[tn].kind, kind)) return elem
return tn new_num = self._add_table(elem[0])
self.trie[new_num] = [self._clone(e) for e in self.trie[elem[1]]]
return elem[0], new_num
def _add_encoding(self, instr_encoding): def add_opcode(self, opcode, descidx, root_idx, weak=False):
desc_idx = self.descs_map.get(instr_encoding) opcode = self._transform_opcode(opcode)
if desc_idx is None: frontier = [(0, root_idx)]
desc_idx = self.descs_map[instr_encoding] = len(self.descs) for elem_kind, elem in zip(self.KIND_ORDER, opcode):
self.descs.append(instr_encoding) new_frontier = []
return TrieEntry.instr(desc_idx) for entry_num, entry_idx in frontier:
entry = self.trie[entry_num]
def add_opcode(self, opcode, instr_encoding, root_idx=0): if elem is None:
tn = self._walk_opcode(opcode, root_idx) if entry[entry_idx] is None or entry[entry_idx][0] != elem_kind:
desc_entry = self._add_encoding(instr_encoding) new_frontier.append((entry_num, entry_idx))
self._update_table(tn, opcode[-1][1], desc_entry.descidx, desc_entry) continue
elem = list(range(self.TABLE_LENGTH[elem_kind]))
def fill_free(self, opcode, instr_encoding, root_idx=0): if entry[entry_idx] is None:
desc_entry = self._add_encoding(instr_encoding) new_num = self._add_table(elem_kind)
tn = self._walk_opcode(opcode, root_idx) entry[entry_idx] = elem_kind, new_num
queue = [(tn, opcode[-1][1])] elif entry[entry_idx][0] != elem_kind:
while queue: # Need to add a new node here and copy entry one level below
tn, idx = queue.pop() new_num = self._add_table(elem_kind)
entry = self.data[tn].items[idx] # Keep original entry, but clone others recursively
if not entry: self.trie[new_num][0] = entry[entry_idx]
self._update_table(tn, idx, desc_entry.descidx, desc_entry) for i in range(1, len(self.trie[new_num])):
else: self.trie[new_num][i] = self._clone(entry[entry_idx])
for i in range(len(self.data[entry].items)): entry[entry_idx] = elem_kind, new_num
queue.append((entry, i)) for elem_idx in elem:
new_frontier.append((entry[entry_idx][1], elem_idx))
frontier = new_frontier
for entry_num, entry_idx in frontier:
entry = self.trie[entry_num]
if not entry[entry_idx] or entry[entry_idx][0] == EntryKind.WEAKINSTR:
kind = EntryKind.INSTR if not weak else EntryKind.WEAKINSTR
entry[entry_idx] = kind, descidx
elif not weak:
raise Exception(f"redundant non-weak {opcode}")
def deduplicate(self): def deduplicate(self):
parents = defaultdict(set) synonyms = {}
for name, entry in self.data.items(): for kind in self.KIND_ORDER[::-1]:
for child in entry.items: entries = {}
parents[child].add(name) for num in self.kindmap[kind]:
# Replace previous synonyms
entry = self.trie[num]
for i, elem in enumerate(entry):
if elem and not elem[0].is_instr and elem[1] in synonyms:
entry[i] = elem[0], synonyms[elem[1]]
queue = list(self.data.keys()) # And deduplicate all entries of this kind
entries = {} # Mapping from entry to name unique_entry = tuple(entry)
while queue: if unique_entry in entries:
# First find new synonyms synonyms[num] = entries[unique_entry]
synonyms = {} # Mapping from name to unique name self.trie[num] = None
for name in queue:
if self.data[name] in entries:
synonyms[name] = entries[self.data[name]]
del self.data[name]
else: else:
entries[self.data[name]] = name entries[unique_entry] = num
queue = set.union(set(), *(parents[n] for n in synonyms))
# Update parents of found synonyms; parents will need to be checked
# again for synonyms in the next iteration.
for name in queue:
entry = self.data[name]
items = tuple(synonyms.get(v, v) for v in entry.items)
self.data[name] = entry._replace(items=items)
for child in items:
parents[child].add(name)
def calc_offsets(self):
current = 0
for name, entry in self.data.items():
if entry.kind == EntryKind.INSTR:
self.offsets[name] = entry.descidx << 2
else:
self.annotations[current] = "%s(%d)" % (name, entry.kind.value)
self.offsets[name] = current
current += (len(entry.items) + 3) & ~3
if current >= 0x8000:
raise Exception("maximum table size exceeded: {:x}".format(current))
def _encode_item(self, name):
return (self.offsets[name] << 1) | self.data[name].kind.value
def compile(self): def compile(self):
self.calc_offsets() offsets = [None] * len(self.trie)
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items) last_off = 0
for num, entry in enumerate(self.trie[1:], start=1):
if not entry:
continue
offsets[num] = last_off
last_off += (len(entry) + 3) & ~3
if last_off >= 0x8000:
raise Exception(f"maximum table size exceeded: {last_off:#x}")
data = [0] * (ordered[-1][0] + len(ordered[-1][1].items)) data = [0] * last_off
for off, entry in ordered: for off, entry in zip(offsets, self.trie):
for i, item in enumerate(entry.items, start=off): if off is None:
if item is not None: continue
data[i] = self._encode_item(item) for i, elem in enumerate(entry, start=off):
if elem is not None:
value = elem[1] << 2 if elem[0].is_instr else offsets[elem[1]]
data[i] = (value << 1) | (elem[0].value & 7)
stats = dict(Counter(entry.kind for entry in self.data.values())) stats = {k: len(v) for k, v in self.kindmap.items()}
print("%d bytes" % (2*len(data)), stats) print("%d bytes" % (2*len(data)), stats)
return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs return tuple(data), [offsets[v] for _, v in self.trie[0]]
def bytes_to_table(data, notes): def bytes_to_table(data, notes):
strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data) strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data)
@@ -611,24 +586,22 @@ if __name__ == "__main__":
decode_mnems_lines = [f"FD_MNEMONIC({m},{i})\n" for i, m in enumerate(mnemonics)] decode_mnems_lines = [f"FD_MNEMONIC({m},{i})\n" for i, m in enumerate(mnemonics)]
args.decode_mnems.write("".join(decode_mnems_lines)) args.decode_mnems.write("".join(decode_mnems_lines))
table = Table(root_count=len(args.modes)) trie = Trie(root_count=len(args.modes))
weak_opcodes = [] descs, desc_map = [], {}
for weak, opcode, desc in entries: for weak, opcode, desc in entries:
ign66 = opcode.prefix in ("NP", "66", "F2", "F3") ign66 = opcode.prefix in ("NP", "66", "F2", "F3")
modrm = opcode.modreg or opcode.opcext modrm = opcode.modreg or opcode.opcext
descenc = desc.encode(ign66, modrm) descenc = desc.encode(ign66, modrm)
desc_idx = desc_map.get(descenc)
if desc_idx is None:
desc_idx = desc_map[descenc] = len(descs)
descs.append(descenc)
for i, mode in enumerate(args.modes): for i, mode in enumerate(args.modes):
if "ONLY%d"%(96-mode) not in desc.flags: if "ONLY%d"%(96-mode) not in desc.flags:
for opcode_path in opcode.for_trie(): trie.add_opcode(opcode, desc_idx, i, weak)
if weak:
weak_opcodes.append((opcode_path, descenc, i))
else:
table.add_opcode(opcode_path, descenc, i)
for k in weak_opcodes:
table.fill_free(*k)
table.deduplicate() trie.deduplicate()
table_data, annotations, root_offsets, descs = table.compile() table_data, root_offsets = trie.compile()
mnemonics_intel = [m.replace("SSE_", "").replace("MMX_", "") mnemonics_intel = [m.replace("SSE_", "").replace("MMX_", "")
.replace("MOVABS", "MOV").replace("RESERVED_", "") .replace("MOVABS", "MOV").replace("RESERVED_", "")
@@ -640,7 +613,7 @@ if __name__ == "__main__":
defines = ["FD_TABLE_OFFSET_%d %d"%k for k in zip(args.modes, root_offsets)] defines = ["FD_TABLE_OFFSET_%d %d"%k for k in zip(args.modes, root_offsets)]
decode_table = template.format( decode_table = template.format(
hex_table=bytes_to_table(table_data, annotations), hex_table=bytes_to_table(table_data, {}),
descs="\n".join("{{{0},{1},{2},{3}}},".format(*desc) for desc in descs), descs="\n".join("{{{0},{1},{2},{3}}},".format(*desc) for desc in descs),
mnemonics=parse_mnemonics(mnemonics_intel), mnemonics=parse_mnemonics(mnemonics_intel),
defines="\n".join("#define " + line for line in defines), defines="\n".join("#define " + line for line in defines),