parseinstrs: Simplify trie implementation

This commit is contained in:
Alexis Engelke
2021-01-23 12:32:22 +01:00
parent 43910a6227
commit 13a2456458

View File

@@ -210,32 +210,16 @@ class InstrDesc(NamedTuple):
class EntryKind(Enum):
NONE = 0
INSTR = 1
WEAKINSTR = 9
TABLE256 = 2
TABLE16 = 3
TABLE8E = 4
TABLE_PREFIX = 5
TABLE_VEX = 6
TABLE_ROOT = -1
class TrieEntry(NamedTuple):
kind: EntryKind
items: Tuple[Optional[str]]
descidx: Optional[int]
TABLE_LENGTH = {
EntryKind.TABLE256: 256,
EntryKind.TABLE16: 16,
EntryKind.TABLE8E: 8,
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE_VEX: 4,
EntryKind.TABLE_ROOT: 8,
}
@classmethod
def table(cls, kind):
return cls(kind, (None,) * cls.TABLE_LENGTH[kind], ())
@classmethod
def instr(cls, descidx):
return cls(EntryKind.INSTR, (), descidx)
@property
def is_instr(self):
return self == EntryKind.INSTR or self == EntryKind.WEAKINSTR
opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
@@ -307,127 +291,118 @@ class Opcode(NamedTuple):
vexl = {"0": [0], "1": [1<<1], "IG": [0, 1<<1]}[self.vexl or "IG"]
entries = list(map(sum, product(rexw, vexl)))
opcode.append((EntryKind.TABLE_VEX, entries))
return opcode
kinds, values = zip(*opcode)
return [tuple(zip(kinds, prod)) for prod in product(*values)]
class Trie:
KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.TABLE256,
EntryKind.TABLE_PREFIX, EntryKind.TABLE16,
EntryKind.TABLE8E, EntryKind.TABLE_VEX)
TABLE_LENGTH = {
EntryKind.TABLE_ROOT: 8,
EntryKind.TABLE256: 256,
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE16: 16,
EntryKind.TABLE8E: 8,
EntryKind.TABLE_VEX: 4,
}
class Table:
def __init__(self, root_count=1):
self.data = OrderedDict()
self.roots = [(i,) for i in range(root_count)]
for i in range(root_count):
self.data[i,] = TrieEntry.table(EntryKind.TABLE_ROOT)
self.descs = []
self.descs_map = {}
self.offsets = {}
self.annotations = {}
def __init__(self, root_count):
self.trie = []
self.trie.append([None] * root_count)
self.kindmap = defaultdict(list)
def _update_table(self, name, idx, entry_name, entry_val):
old = self.data[name]
# Don't override existing entries. This only happens on invalid input,
# e.g. when an opcode is specified twice.
if old.items[idx]:
raise Exception("{}/{} set, not overriding to {}".format(name, idx, entry_name))
self.data[entry_name] = entry_val
new_items = old.items[:idx] + (entry_name,) + old.items[idx+1:]
self.data[name] = TrieEntry(old.kind, new_items, None)
def _add_table(self, kind):
self.trie.append([None] * self.TABLE_LENGTH[kind])
self.kindmap[kind].append(len(self.trie) - 1)
return len(self.trie) - 1
def _walk_opcode(self, opcode, root_idx):
tn = root_idx,
for i in range(len(opcode) - 1):
# kind is the table kind that we want to point to in the _next_.
kind, byte = opcode[i+1][0], opcode[i][1]
# Retain prev_tn name so that we can update it.
prev_tn, tn = tn, self.data[tn].items[byte]
if tn is None:
tn = prev_tn + (byte,)
self._update_table(prev_tn, byte, tn, TrieEntry.table(kind))
def _transform_opcode(self, opcode):
vals = {k: v for k, v in opcode.for_trie()}
return [vals.get(kind) for kind in self.KIND_ORDER]
if self.data[tn].kind != kind:
raise Exception("{}, have {}, want {}".format(
opcode, self.data[tn].kind, kind))
return tn
def _clone(self, elem):
if not elem or elem[0].is_instr:
return elem
new_num = self._add_table(elem[0])
self.trie[new_num] = [self._clone(e) for e in self.trie[elem[1]]]
return elem[0], new_num
def _add_encoding(self, instr_encoding):
desc_idx = self.descs_map.get(instr_encoding)
if desc_idx is None:
desc_idx = self.descs_map[instr_encoding] = len(self.descs)
self.descs.append(instr_encoding)
return TrieEntry.instr(desc_idx)
def add_opcode(self, opcode, instr_encoding, root_idx=0):
tn = self._walk_opcode(opcode, root_idx)
desc_entry = self._add_encoding(instr_encoding)
self._update_table(tn, opcode[-1][1], desc_entry.descidx, desc_entry)
def fill_free(self, opcode, instr_encoding, root_idx=0):
desc_entry = self._add_encoding(instr_encoding)
tn = self._walk_opcode(opcode, root_idx)
queue = [(tn, opcode[-1][1])]
while queue:
tn, idx = queue.pop()
entry = self.data[tn].items[idx]
if not entry:
self._update_table(tn, idx, desc_entry.descidx, desc_entry)
else:
for i in range(len(self.data[entry].items)):
queue.append((entry, i))
def add_opcode(self, opcode, descidx, root_idx, weak=False):
opcode = self._transform_opcode(opcode)
frontier = [(0, root_idx)]
for elem_kind, elem in zip(self.KIND_ORDER, opcode):
new_frontier = []
for entry_num, entry_idx in frontier:
entry = self.trie[entry_num]
if elem is None:
if entry[entry_idx] is None or entry[entry_idx][0] != elem_kind:
new_frontier.append((entry_num, entry_idx))
continue
elem = list(range(self.TABLE_LENGTH[elem_kind]))
if entry[entry_idx] is None:
new_num = self._add_table(elem_kind)
entry[entry_idx] = elem_kind, new_num
elif entry[entry_idx][0] != elem_kind:
# Need to add a new node here and copy entry one level below
new_num = self._add_table(elem_kind)
# Keep original entry, but clone others recursively
self.trie[new_num][0] = entry[entry_idx]
for i in range(1, len(self.trie[new_num])):
self.trie[new_num][i] = self._clone(entry[entry_idx])
entry[entry_idx] = elem_kind, new_num
for elem_idx in elem:
new_frontier.append((entry[entry_idx][1], elem_idx))
frontier = new_frontier
for entry_num, entry_idx in frontier:
entry = self.trie[entry_num]
if not entry[entry_idx] or entry[entry_idx][0] == EntryKind.WEAKINSTR:
kind = EntryKind.INSTR if not weak else EntryKind.WEAKINSTR
entry[entry_idx] = kind, descidx
elif not weak:
raise Exception(f"redundant non-weak {opcode}")
def deduplicate(self):
parents = defaultdict(set)
for name, entry in self.data.items():
for child in entry.items:
parents[child].add(name)
synonyms = {}
for kind in self.KIND_ORDER[::-1]:
entries = {}
for num in self.kindmap[kind]:
# Replace previous synonyms
entry = self.trie[num]
for i, elem in enumerate(entry):
if elem and not elem[0].is_instr and elem[1] in synonyms:
entry[i] = elem[0], synonyms[elem[1]]
queue = list(self.data.keys())
entries = {} # Mapping from entry to name
while queue:
# First find new synonyms
synonyms = {} # Mapping from name to unique name
for name in queue:
if self.data[name] in entries:
synonyms[name] = entries[self.data[name]]
del self.data[name]
# And deduplicate all entries of this kind
unique_entry = tuple(entry)
if unique_entry in entries:
synonyms[num] = entries[unique_entry]
self.trie[num] = None
else:
entries[self.data[name]] = name
queue = set.union(set(), *(parents[n] for n in synonyms))
# Update parents of found synonyms; parents will need to be checked
# again for synonyms in the next iteration.
for name in queue:
entry = self.data[name]
items = tuple(synonyms.get(v, v) for v in entry.items)
self.data[name] = entry._replace(items=items)
for child in items:
parents[child].add(name)
def calc_offsets(self):
current = 0
for name, entry in self.data.items():
if entry.kind == EntryKind.INSTR:
self.offsets[name] = entry.descidx << 2
else:
self.annotations[current] = "%s(%d)" % (name, entry.kind.value)
self.offsets[name] = current
current += (len(entry.items) + 3) & ~3
if current >= 0x8000:
raise Exception("maximum table size exceeded: {:x}".format(current))
def _encode_item(self, name):
return (self.offsets[name] << 1) | self.data[name].kind.value
entries[unique_entry] = num
def compile(self):
self.calc_offsets()
ordered = sorted((off, self.data[k]) for k, off in self.offsets.items() if self.data[k].items)
offsets = [None] * len(self.trie)
last_off = 0
for num, entry in enumerate(self.trie[1:], start=1):
if not entry:
continue
offsets[num] = last_off
last_off += (len(entry) + 3) & ~3
if last_off >= 0x8000:
raise Exception(f"maximum table size exceeded: {last_off:#x}")
data = [0] * (ordered[-1][0] + len(ordered[-1][1].items))
for off, entry in ordered:
for i, item in enumerate(entry.items, start=off):
if item is not None:
data[i] = self._encode_item(item)
data = [0] * last_off
for off, entry in zip(offsets, self.trie):
if off is None:
continue
for i, elem in enumerate(entry, start=off):
if elem is not None:
value = elem[1] << 2 if elem[0].is_instr else offsets[elem[1]]
data[i] = (value << 1) | (elem[0].value & 7)
stats = dict(Counter(entry.kind for entry in self.data.values()))
stats = {k: len(v) for k, v in self.kindmap.items()}
print("%d bytes" % (2*len(data)), stats)
return tuple(data), self.annotations, [self.offsets[k] for k in self.roots], self.descs
return tuple(data), [offsets[v] for _, v in self.trie[0]]
def bytes_to_table(data, notes):
strdata = tuple(d+"," if type(d) == str else "%#04x,"%d for d in data)
@@ -611,24 +586,22 @@ if __name__ == "__main__":
decode_mnems_lines = [f"FD_MNEMONIC({m},{i})\n" for i, m in enumerate(mnemonics)]
args.decode_mnems.write("".join(decode_mnems_lines))
table = Table(root_count=len(args.modes))
weak_opcodes = []
trie = Trie(root_count=len(args.modes))
descs, desc_map = [], {}
for weak, opcode, desc in entries:
ign66 = opcode.prefix in ("NP", "66", "F2", "F3")
modrm = opcode.modreg or opcode.opcext
descenc = desc.encode(ign66, modrm)
desc_idx = desc_map.get(descenc)
if desc_idx is None:
desc_idx = desc_map[descenc] = len(descs)
descs.append(descenc)
for i, mode in enumerate(args.modes):
if "ONLY%d"%(96-mode) not in desc.flags:
for opcode_path in opcode.for_trie():
if weak:
weak_opcodes.append((opcode_path, descenc, i))
else:
table.add_opcode(opcode_path, descenc, i)
for k in weak_opcodes:
table.fill_free(*k)
trie.add_opcode(opcode, desc_idx, i, weak)
table.deduplicate()
table_data, annotations, root_offsets, descs = table.compile()
trie.deduplicate()
table_data, root_offsets = trie.compile()
mnemonics_intel = [m.replace("SSE_", "").replace("MMX_", "")
.replace("MOVABS", "MOV").replace("RESERVED_", "")
@@ -640,7 +613,7 @@ if __name__ == "__main__":
defines = ["FD_TABLE_OFFSET_%d %d"%k for k in zip(args.modes, root_offsets)]
decode_table = template.format(
hex_table=bytes_to_table(table_data, annotations),
hex_table=bytes_to_table(table_data, {}),
descs="\n".join("{{{0},{1},{2},{3}}},".format(*desc) for desc in descs),
mnemonics=parse_mnemonics(mnemonics_intel),
defines="\n".join("#define " + line for line in defines),