class EntryKind(Enum):
    """Kind tag stored alongside every entry of the decode trie.

    Only the low three bits of the numeric value are written into the
    compiled table (see ``Trie.compile``), so ``WEAKINSTR`` (9) is emitted
    exactly like ``INSTR`` (1); the distinction only matters while the trie
    is being built.
    """
    NONE = 0
    INSTR = 1
    WEAKINSTR = 9
    TABLE256 = 2
    TABLE16 = 3
    TABLE8E = 4
    TABLE_PREFIX = 5
    TABLE_VEX = 6
    TABLE_ROOT = -1

    @property
    def is_instr(self):
        """Whether this kind denotes a leaf (a strong or weak instruction)."""
        return self in (EntryKind.INSTR, EntryKind.WEAKINSTR)


class Trie:
    """Decode trie stored as a flat list of fixed-size tables.

    ``self.trie[0]`` is the list of root slots (one per decode mode); every
    other element is either a table (a list whose length is given by
    ``TABLE_LENGTH``) or ``None`` once ``deduplicate`` has discarded it.
    A slot holds ``None``, ``(table_kind, table_number)`` or
    ``(INSTR | WEAKINSTR, descriptor_index)``.

    ``self.kindmap`` maps each table kind to the numbers of all tables that
    were ever created with that kind.

    Usage, as done by the script's main section: create the trie with one
    root per mode, call ``add_opcode(opcode, desc_idx, mode_idx, weak)`` for
    every entry, then ``deduplicate()`` and finally ``compile()``.
    """

    # Order in which the levels of an opcode are walked from the root down.
    KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.TABLE256,
                  EntryKind.TABLE_PREFIX, EntryKind.TABLE16,
                  EntryKind.TABLE8E, EntryKind.TABLE_VEX)
    # Number of slots in a table of each kind.
    TABLE_LENGTH = {
        EntryKind.TABLE_ROOT: 8,
        EntryKind.TABLE256: 256,
        EntryKind.TABLE_PREFIX: 4,
        EntryKind.TABLE16: 16,
        EntryKind.TABLE8E: 8,
        EntryKind.TABLE_VEX: 4,
    }

    def __init__(self, root_count):
        """Create an empty trie with ``root_count`` unpopulated root slots."""
        self.trie = [[None] * root_count]
        self.kindmap = defaultdict(list)

    def _add_table(self, kind):
        """Append an empty table of ``kind`` and return its table number."""
        self.trie.append([None] * self.TABLE_LENGTH[kind])
        table_num = len(self.trie) - 1
        self.kindmap[kind].append(table_num)
        return table_num

    def _transform_opcode(self, opcode):
        """Return the opcode's slot indices per level, aligned to KIND_ORDER.

        ``opcode.for_trie()`` must yield ``(kind, indices)`` pairs; levels
        the opcode does not mention are represented by ``None``.
        """
        by_kind = dict(opcode.for_trie())
        return [by_kind.get(kind) for kind in self.KIND_ORDER]

    def _clone(self, elem):
        """Return a deep copy of the sub-trie referenced by slot ``elem``.

        Empty slots and instruction leaves are returned unchanged; tables
        are duplicated recursively into freshly allocated table numbers.
        """
        if not elem or elem[0].is_instr:
            return elem
        kind, src_num = elem
        new_num = self._add_table(kind)
        self.trie[new_num] = [self._clone(e) for e in self.trie[src_num]]
        return kind, new_num

    def add_opcode(self, opcode, descidx, root_idx, weak=False):
        """Insert ``opcode`` below root ``root_idx``, mapping to ``descidx``.

        Levels the opcode does not specify are skipped for slots that have
        no table of that kind, and fanned out over every index for slots
        that already do.  A weak entry only fills slots that are empty or
        hold another weak entry and never replaces a strong instruction.

        Raises:
            Exception: if a non-weak opcode reaches a slot that already
                holds a non-weak instruction or a table.
        """
        path = self._transform_opcode(opcode)
        frontier = [(0, root_idx)]
        for kind, wanted in zip(self.KIND_ORDER, path):
            next_frontier = []
            for table_num, slot in frontier:
                table = self.trie[table_num]
                child = table[slot]
                # Decide the fan-out per slot.  Rebinding the shared loop
                # variable here would leak the expansion into all following
                # slots of this level and create tables nobody asked for.
                indices = wanted
                if indices is None:
                    if child is None or child[0] != kind:
                        # Level not used by this opcode nor by this slot.
                        next_frontier.append((table_num, slot))
                        continue
                    indices = range(self.TABLE_LENGTH[kind])
                if child is None:
                    table[slot] = kind, self._add_table(kind)
                elif child[0] != kind:
                    # Need to add a new node here and copy the existing
                    # entry one level below.
                    new_num = self._add_table(kind)
                    new_table = self.trie[new_num]
                    # Keep original entry, but clone others recursively.
                    new_table[0] = child
                    for i in range(1, len(new_table)):
                        new_table[i] = self._clone(child)
                    table[slot] = kind, new_num
                child_num = table[slot][1]
                next_frontier.extend((child_num, idx) for idx in indices)
            frontier = next_frontier

        leaf_kind = EntryKind.WEAKINSTR if weak else EntryKind.INSTR
        for table_num, slot in frontier:
            table = self.trie[table_num]
            if not table[slot] or table[slot][0] == EntryKind.WEAKINSTR:
                table[slot] = leaf_kind, descidx
            elif not weak:
                raise Exception(f"redundant non-weak {path}")

    def deduplicate(self):
        """Merge identical tables, processing kinds from the leaves upwards.

        Discarded tables are replaced by ``None`` in ``self.trie`` and all
        references to them, including those in the root list, are redirected
        to the surviving table.
        """
        synonyms = {}  # discarded table number -> surviving table number

        def redirect(table):
            for i, elem in enumerate(table):
                if elem and not elem[0].is_instr and elem[1] in synonyms:
                    table[i] = elem[0], synonyms[elem[1]]

        for kind in self.KIND_ORDER[::-1]:
            seen = {}  # table contents -> first table number with them
            for num in self.kindmap[kind]:
                table = self.trie[num]
                if table is None:
                    # Already discarded by an earlier deduplicate() call.
                    continue
                # Replace previous synonyms
                redirect(table)
                # And deduplicate all entries of this kind
                key = tuple(table)
                if key in seen:
                    synonyms[num] = seen[key]
                    self.trie[num] = None
                else:
                    seen[key] = num
        # The root list is not registered in kindmap, so it must be fixed up
        # separately; otherwise a root whose table was merged away would
        # keep pointing at a discarded table.
        redirect(self.trie[0])

    def compile(self):
        """Lay out all live tables and encode them into one flat table.

        Each table is placed at an offset rounded up to a multiple of four
        slots.  A slot is encoded as ``(value << 1) | (kind.value & 7)``,
        where ``value`` is ``descriptor_index << 2`` for instructions and
        the target table's offset otherwise; empty slots are 0.

        Every root slot must be populated, since each one is unpacked as a
        ``(kind, table_number)`` pair.

        Returns:
            ``(data, root_offsets)``: the encoded table as a tuple of ints
            and the offset of the table referenced by each root slot.

        Raises:
            Exception: if the table needs 0x8000 or more slots.
        """
        offsets = [None] * len(self.trie)
        last_off = 0
        for num, table in enumerate(self.trie[1:], start=1):
            if not table:
                continue
            offsets[num] = last_off
            last_off += (len(table) + 3) & ~3
        if last_off >= 0x8000:
            raise Exception(f"maximum table size exceeded: {last_off:#x}")

        data = [0] * last_off
        for off, table in zip(offsets, self.trie):
            if off is None:
                continue
            for i, elem in enumerate(table, start=off):
                if elem is None:
                    continue
                kind, target = elem
                value = target << 2 if kind.is_instr else offsets[target]
                data[i] = (value << 1) | (kind.value & 7)

        stats = {k: len(v) for k, v in self.kindmap.items()}
        print("%d bytes" % (2*len(data)), stats)
        return tuple(data), [offsets[v] for _, v in self.trie[0]]