Add more mypy annotations.
This commit is contained in:
@@ -47,9 +47,11 @@ class Def(object):
|
||||
self.expr = expr
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return "{} << {!r}".format(self.defs, self.expr)
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
if len(self.defs) == 1:
|
||||
return "{!s} << {!s}".format(self.defs[0], self.expr)
|
||||
else:
|
||||
@@ -379,15 +381,18 @@ class Apply(Expr):
|
||||
return Def(other, self)
|
||||
|
||||
def instname(self):
|
||||
# type: () -> str
|
||||
i = self.inst.name
|
||||
for t in self.typevars:
|
||||
i += '.{}'.format(t)
|
||||
return i
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
return "Apply({}, {})".format(self.instname(), self.args)
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
args = ', '.join(map(str, self.args))
|
||||
return '{}({})'.format(self.instname(), args)
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ class InstructionFormat(object):
|
||||
|
||||
@staticmethod
|
||||
def extract_names(globs):
|
||||
# type: (Dict[str, Any]) -> None
|
||||
"""
|
||||
Given a dict mapping name -> object as returned by `globals()`, find
|
||||
all the InstructionFormat objects and set their name from the dict key.
|
||||
|
||||
@@ -6,11 +6,13 @@ from .operands import Operand
|
||||
from .formats import InstructionFormat
|
||||
|
||||
try:
|
||||
from typing import Union, Sequence, List # noqa
|
||||
# List of operands for ins/outs:
|
||||
OpList = Union[Sequence[Operand], Operand]
|
||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||
from typing import Tuple, Any # noqa
|
||||
from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .ast import Expr, Apply # noqa
|
||||
from .typevar import TypeVar # noqa
|
||||
# List of operands for ins/outs:
|
||||
OpList = Union[Sequence[Operand], Operand]
|
||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -122,6 +124,7 @@ class Instruction(object):
|
||||
InstructionGroup.append(self)
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
prefix = ', '.join(o.name for o in self.outs)
|
||||
if prefix:
|
||||
prefix = prefix + ' = '
|
||||
@@ -141,6 +144,7 @@ class Instruction(object):
|
||||
return self.name
|
||||
|
||||
def blurb(self):
|
||||
# type: () -> str
|
||||
"""Get the first line of the doc comment"""
|
||||
for line in self.__doc__.split('\n'):
|
||||
line = line.strip()
|
||||
@@ -149,6 +153,7 @@ class Instruction(object):
|
||||
return ""
|
||||
|
||||
def _verify_polymorphic(self):
|
||||
# type: () -> None
|
||||
"""
|
||||
Check if this instruction is polymorphic, and verify its use of type
|
||||
variables.
|
||||
@@ -193,6 +198,7 @@ class Instruction(object):
|
||||
self.ctrl_typevar = tv
|
||||
|
||||
def _verify_ctrl_typevar(self, ctrl_typevar):
|
||||
# type: (TypeVar) -> List[TypeVar]
|
||||
"""
|
||||
Verify that the use of TypeVars is consistent with `ctrl_typevar` as
|
||||
the controlling type variable.
|
||||
@@ -204,7 +210,7 @@ class Instruction(object):
|
||||
|
||||
Return list of other type variables used, or raise an error.
|
||||
"""
|
||||
other_tvs = []
|
||||
other_tvs = [] # type: List[TypeVar]
|
||||
# Check value inputs.
|
||||
for opnum in self.value_opnums:
|
||||
typ = self.ins[opnum].typevar
|
||||
@@ -283,11 +289,12 @@ class Instruction(object):
|
||||
return (self, ())
|
||||
|
||||
def __call__(self, *args):
|
||||
# type: (*Expr) -> Apply
|
||||
"""
|
||||
Create an `ast.Apply` AST node representing the application of this
|
||||
instruction to the arguments.
|
||||
"""
|
||||
from .ast import Apply
|
||||
from .ast import Apply # noqa
|
||||
return Apply(self, args)
|
||||
|
||||
|
||||
@@ -303,6 +310,7 @@ class BoundInstruction(object):
|
||||
assert len(typevars) <= 1 + len(inst.other_typevars)
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
return '.'.join([self.inst.name, ] + list(map(str, self.typevars)))
|
||||
|
||||
def bind(self, *args):
|
||||
@@ -336,9 +344,10 @@ class BoundInstruction(object):
|
||||
return (self.inst, self.typevars)
|
||||
|
||||
def __call__(self, *args):
|
||||
# type: (*Expr) -> Apply
|
||||
"""
|
||||
Create an `ast.Apply` AST node representing the application of this
|
||||
instruction to the arguments.
|
||||
"""
|
||||
from .ast import Apply
|
||||
from .ast import Apply # noqa
|
||||
return Apply(self, args)
|
||||
|
||||
@@ -97,6 +97,7 @@ class TargetISA(object):
|
||||
self.settings.number_predicate(enc.isap)
|
||||
|
||||
def _collect_regclasses(self):
|
||||
# type: () -> None
|
||||
"""
|
||||
Collect and number register classes.
|
||||
|
||||
@@ -132,6 +133,7 @@ class CPUMode(object):
|
||||
return self.name
|
||||
|
||||
def enc(self, *args, **kwargs):
|
||||
# type: (*Any, **Any) -> None
|
||||
"""
|
||||
Add a new encoding to this CPU mode.
|
||||
|
||||
@@ -186,7 +188,7 @@ class EncRecipe(object):
|
||||
return self.name
|
||||
|
||||
def _verify_constraints(self, seq):
|
||||
# (ConstraintSeq) -> Sequence[OperandConstraint]
|
||||
# type: (ConstraintSeq) -> Sequence[OperandConstraint]
|
||||
if not isinstance(seq, tuple):
|
||||
seq = (seq,)
|
||||
for c in seq:
|
||||
@@ -194,7 +196,7 @@ class EncRecipe(object):
|
||||
# An integer constraint is bound to a value operand.
|
||||
# Check that it is in range.
|
||||
assert c >= 0
|
||||
if not format.has_value_list:
|
||||
if not self.format.has_value_list:
|
||||
assert c < self.format.num_value_operands
|
||||
else:
|
||||
assert isinstance(c, RegClass) or isinstance(c, Register)
|
||||
|
||||
@@ -37,6 +37,7 @@ except ImportError:
|
||||
|
||||
|
||||
def _is_parent(a, b):
|
||||
# type: (PredContext, PredContext) -> bool
|
||||
"""
|
||||
Return true if a is a parent of b, or equal to it.
|
||||
"""
|
||||
@@ -46,6 +47,7 @@ def _is_parent(a, b):
|
||||
|
||||
|
||||
def _descendant(a, b):
|
||||
# type: (PredContext, PredContext) -> PredContext
|
||||
"""
|
||||
If a is a parent of b or b is a parent of a, return the descendant of the
|
||||
two.
|
||||
|
||||
@@ -26,11 +26,12 @@ from . import is_power_of_two, next_power_of_two
|
||||
|
||||
|
||||
try:
|
||||
from typing import Sequence, Tuple, List, Dict # noqa
|
||||
from .isa import TargetISA # noqa
|
||||
# A tuple uniquely identifying a register class inside a register bank.
|
||||
# (count, width, start)
|
||||
RCTup = Tuple[int, int, int]
|
||||
from typing import Sequence, Tuple, List, Dict, Any, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from .isa import TargetISA # noqa
|
||||
# A tuple uniquely identifying a register class inside a register bank.
|
||||
# (count, width, start)
|
||||
RCTup = Tuple[int, int, int]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -189,6 +190,7 @@ class RegClass(object):
|
||||
bank.classes.append(self)
|
||||
|
||||
def __str__(self):
|
||||
# type: () -> str
|
||||
return self.name
|
||||
|
||||
def rctup(self):
|
||||
@@ -223,6 +225,7 @@ class RegClass(object):
|
||||
return (count, self.width, start)
|
||||
|
||||
def __getitem__(self, sliced):
|
||||
# type: (slice) -> RegClass
|
||||
"""
|
||||
Create a sub-class of a register class using slice notation. The slice
|
||||
indexes refer to allocations in the parent register class, not register
|
||||
@@ -273,6 +276,7 @@ class RegClass(object):
|
||||
|
||||
@staticmethod
|
||||
def extract_names(globs):
|
||||
# type: (Dict[str, Any]) -> None
|
||||
"""
|
||||
Given a dict mapping name -> object as returned by `globals()`, find
|
||||
all the RegClass objects and set their name from the dict key.
|
||||
|
||||
@@ -21,11 +21,12 @@ class ValueType(object):
|
||||
_registry = dict() # type: Dict[str, ValueType]
|
||||
|
||||
# List of all the scalar types.
|
||||
all_scalars = list() # type: List[ValueType]
|
||||
all_scalars = list() # type: List[ScalarType]
|
||||
|
||||
def __init__(self, name, membytes, doc):
|
||||
# type: (str, int, str) -> None
|
||||
self.name = name
|
||||
self.number = None # type: int
|
||||
self.membytes = membytes
|
||||
self.__doc__ = doc
|
||||
assert name not in ValueType._registry
|
||||
|
||||
@@ -9,10 +9,12 @@ import math
|
||||
from . import types, is_power_of_two
|
||||
|
||||
try:
|
||||
from typing import Tuple, Union # noqa
|
||||
Interval = Tuple[int, int]
|
||||
# An Interval where `True` means 'everything'
|
||||
BoolInterval = Union[bool, Interval]
|
||||
from typing import Tuple, Union, TYPE_CHECKING # noqa
|
||||
if TYPE_CHECKING:
|
||||
from srcgen import Formatter # noqa
|
||||
Interval = Tuple[int, int]
|
||||
# An Interval where `True` means 'everything'
|
||||
BoolInterval = Union[bool, Interval]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -143,7 +145,11 @@ class TypeSet(object):
|
||||
return h
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.typeset_key() == other.typeset_key()
|
||||
# type: (object) -> bool
|
||||
if isinstance(other, TypeSet):
|
||||
return self.typeset_key() == other.typeset_key()
|
||||
else:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
@@ -157,6 +163,7 @@ class TypeSet(object):
|
||||
return s + ')'
|
||||
|
||||
def emit_fields(self, fmt):
|
||||
# type: (Formatter) -> None
|
||||
"""Emit field initializers for this typeset."""
|
||||
fmt.comment(repr(self))
|
||||
fields = ('lanes', 'int', 'float', 'bool')
|
||||
@@ -299,6 +306,9 @@ class TypeVar(object):
|
||||
.format(self.name, self.type_set))
|
||||
|
||||
def __eq__(self, other):
|
||||
# type: (object) -> bool
|
||||
if not isinstance(other, TypeVar):
|
||||
return False
|
||||
if self.is_derived and other.is_derived:
|
||||
return (
|
||||
self.derived_func == other.derived_func and
|
||||
|
||||
@@ -40,10 +40,6 @@ class Rtl(object):
|
||||
# type: (*DefApply) -> None
|
||||
self.rtl = tuple(map(canonicalize_defapply, args))
|
||||
|
||||
def __iter__(self):
|
||||
# type: () -> Iterator[Def]
|
||||
return iter(self.rtl)
|
||||
|
||||
|
||||
class XForm(object):
|
||||
"""
|
||||
@@ -105,10 +101,11 @@ class XForm(object):
|
||||
self._collect_typevars()
|
||||
|
||||
def __repr__(self):
|
||||
# type: () -> str
|
||||
s = "XForm(inputs={}, defs={},\n ".format(self.inputs, self.defs)
|
||||
s += '\n '.join(str(n) for n in self.src)
|
||||
s += '\n '.join(str(n) for n in self.src.rtl)
|
||||
s += '\n=>\n '
|
||||
s += '\n '.join(str(n) for n in self.dst)
|
||||
s += '\n '.join(str(n) for n in self.dst.rtl)
|
||||
s += '\n)'
|
||||
return s
|
||||
|
||||
|
||||
@@ -8,8 +8,14 @@ quadratically probed hash table.
|
||||
from __future__ import absolute_import
|
||||
from cdsl import next_power_of_two
|
||||
|
||||
try:
|
||||
from typing import Any, List, Sequence, Callable # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def simple_hash(s):
|
||||
# type: (str) -> int
|
||||
"""
|
||||
Compute a primitive hash of a string.
|
||||
|
||||
@@ -26,6 +32,7 @@ def simple_hash(s):
|
||||
|
||||
|
||||
def compute_quadratic(items, hash_function):
|
||||
# type: (Sequence[Any], Callable[[Any], int]) -> List[Any]
|
||||
"""
|
||||
Compute an open addressed, quadratically probed hash table containing
|
||||
`items`. The returned table is a list containing the elements of the
|
||||
@@ -43,7 +50,7 @@ def compute_quadratic(items, hash_function):
|
||||
items = list(items)
|
||||
# Table size must be a power of two. Aim for >20% unused slots.
|
||||
size = next_power_of_two(int(1.20*len(items)))
|
||||
table = [None] * size
|
||||
table = [None] * size # type: List[Any]
|
||||
|
||||
for i in items:
|
||||
h = hash_function(i) % size
|
||||
|
||||
@@ -16,8 +16,14 @@ from __future__ import absolute_import, print_function
|
||||
import os
|
||||
from os.path import dirname, abspath, join
|
||||
|
||||
try:
|
||||
from typing import Iterable # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def source_files(top):
|
||||
# type: (str) -> Iterable[str]
|
||||
"""
|
||||
Recursively find all interesting source files and directories in the
|
||||
directory tree starting at top. Yield a path to each file.
|
||||
@@ -30,6 +36,7 @@ def source_files(top):
|
||||
|
||||
|
||||
def generate():
|
||||
# type: () -> None
|
||||
print("Dependencies from meta language directory:")
|
||||
meta = dirname(abspath(__file__))
|
||||
for path in source_files(meta):
|
||||
|
||||
@@ -14,6 +14,7 @@ from cdsl.ast import Var
|
||||
|
||||
try:
|
||||
from typing import Sequence # noqa
|
||||
from cdsl.isa import TargetISA # noqa
|
||||
from cdsl.ast import Def # noqa
|
||||
from cdsl.xform import XForm, XFormGroup # noqa
|
||||
except ImportError:
|
||||
@@ -240,6 +241,7 @@ def gen_xform_group(xgrp, fmt):
|
||||
|
||||
|
||||
def generate(isas, out_dir):
|
||||
# type: (Sequence[TargetISA], str) -> None
|
||||
fmt = Formatter()
|
||||
gen_xform_group(legalize.narrow, fmt)
|
||||
gen_xform_group(legalize.expand, fmt)
|
||||
|
||||
@@ -12,8 +12,14 @@ import srcgen
|
||||
from cdsl.types import ValueType
|
||||
import base.types # noqa
|
||||
|
||||
try:
|
||||
from typing import Iterable # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def emit_type(ty, fmt):
|
||||
# type: (ValueType, srcgen.Formatter) -> None
|
||||
"""
|
||||
Emit a constant definition of a single value type.
|
||||
"""
|
||||
@@ -25,6 +31,7 @@ def emit_type(ty, fmt):
|
||||
|
||||
|
||||
def emit_vectors(bits, fmt):
|
||||
# type: (int, srcgen.Formatter) -> None
|
||||
"""
|
||||
Emit definition for all vector types with `bits` total size.
|
||||
"""
|
||||
@@ -37,6 +44,7 @@ def emit_vectors(bits, fmt):
|
||||
|
||||
|
||||
def emit_types(fmt):
|
||||
# type: (srcgen.Formatter) -> None
|
||||
for ty in ValueType.all_scalars:
|
||||
emit_type(ty, fmt)
|
||||
# Emit vector definitions for common SIMD sizes.
|
||||
@@ -47,6 +55,7 @@ def emit_types(fmt):
|
||||
|
||||
|
||||
def generate(out_dir):
|
||||
# type: (str) -> None
|
||||
fmt = srcgen.Formatter()
|
||||
emit_types(fmt)
|
||||
fmt.update_file('types.rs', out_dir)
|
||||
|
||||
@@ -101,6 +101,7 @@ class Formatter(object):
|
||||
self.fmt.indent_push()
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
# type: (object, object, object) -> None
|
||||
self.fmt.indent_pop()
|
||||
if self.after:
|
||||
self.fmt.line(self.after)
|
||||
@@ -126,13 +127,16 @@ class Formatter(object):
|
||||
return Formatter._IndentedScope(self, after)
|
||||
|
||||
def format(self, fmt, *args):
|
||||
# type: (str, *Any) -> None
|
||||
self.line(fmt.format(*args))
|
||||
|
||||
def comment(self, s):
|
||||
# type: (str) -> None
|
||||
"""Add a comment line."""
|
||||
self.line('// ' + s)
|
||||
|
||||
def doc_comment(self, s):
|
||||
# type: (str) -> None
|
||||
"""Add a (multi-line) documentation comment."""
|
||||
s = re.sub('^', self.indent + '/// ', s, flags=re.M) + '\n'
|
||||
self.lines.append(s)
|
||||
|
||||
@@ -7,18 +7,25 @@ item is mapped to its offset in the final array.
|
||||
This is a compression technique for compile-time generated tables.
|
||||
"""
|
||||
|
||||
try:
|
||||
from typing import Any, List, Dict, Tuple, Sequence # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class UniqueTable:
|
||||
"""
|
||||
Collect items into the `table` list, removing duplicates.
|
||||
"""
|
||||
def __init__(self):
|
||||
# type: () -> None
|
||||
# List of items added in order.
|
||||
self.table = list()
|
||||
self.table = list() # type: List[Any]
|
||||
# Map item -> index.
|
||||
self.index = dict()
|
||||
self.index = dict() # type: Dict[Any, int]
|
||||
|
||||
def add(self, item):
|
||||
# type: (Any) -> int
|
||||
"""
|
||||
Add a single item to the table if it isn't already there.
|
||||
|
||||
@@ -40,11 +47,13 @@ class UniqueSeqTable:
|
||||
Sequences don't have to be of the same length.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.table = list()
|
||||
# type: () -> None
|
||||
self.table = list() # type: List[Any]
|
||||
# Map seq -> index.
|
||||
self.index = dict()
|
||||
self.index = dict() # type: Dict[Tuple[Any, ...], int]
|
||||
|
||||
def add(self, seq):
|
||||
# type: (Sequence[Any]) -> int
|
||||
"""
|
||||
Add a sequence of items to the table. If the table already contains the
|
||||
items in `seq` in the same order, use those instead.
|
||||
|
||||
Reference in New Issue
Block a user