Add more mypy annotations.

This commit is contained in:
Jakob Stoklund Olesen
2017-03-30 15:15:53 -07:00
parent 02051c4764
commit cfe2c7f46f
15 changed files with 101 additions and 32 deletions

View File

@@ -47,9 +47,11 @@ class Def(object):
self.expr = expr
def __repr__(self):
# type: () -> str
return "{} << {!r}".format(self.defs, self.expr)
def __str__(self):
# type: () -> str
if len(self.defs) == 1:
return "{!s} << {!s}".format(self.defs[0], self.expr)
else:
@@ -379,15 +381,18 @@ class Apply(Expr):
return Def(other, self)
def instname(self):
# type: () -> str
i = self.inst.name
for t in self.typevars:
i += '.{}'.format(t)
return i
def __repr__(self):
# type: () -> str
return "Apply({}, {})".format(self.instname(), self.args)
def __str__(self):
# type: () -> str
args = ', '.join(map(str, self.args))
return '{}({})'.format(self.instname(), args)

View File

@@ -193,6 +193,7 @@ class InstructionFormat(object):
@staticmethod
def extract_names(globs):
# type: (Dict[str, Any]) -> None
"""
Given a dict mapping name -> object as returned by `globals()`, find
all the InstructionFormat objects and set their name from the dict key.

View File

@@ -6,11 +6,13 @@ from .operands import Operand
from .formats import InstructionFormat
try:
from typing import Union, Sequence, List # noqa
from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa
if TYPE_CHECKING:
from .ast import Expr, Apply # noqa
from .typevar import TypeVar # noqa
# List of operands for ins/outs:
OpList = Union[Sequence[Operand], Operand]
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
from typing import Tuple, Any # noqa
except ImportError:
pass
@@ -122,6 +124,7 @@ class Instruction(object):
InstructionGroup.append(self)
def __str__(self):
# type: () -> str
prefix = ', '.join(o.name for o in self.outs)
if prefix:
prefix = prefix + ' = '
@@ -141,6 +144,7 @@ class Instruction(object):
return self.name
def blurb(self):
# type: () -> str
"""Get the first line of the doc comment"""
for line in self.__doc__.split('\n'):
line = line.strip()
@@ -149,6 +153,7 @@ class Instruction(object):
return ""
def _verify_polymorphic(self):
# type: () -> None
"""
Check if this instruction is polymorphic, and verify its use of type
variables.
@@ -193,6 +198,7 @@ class Instruction(object):
self.ctrl_typevar = tv
def _verify_ctrl_typevar(self, ctrl_typevar):
# type: (TypeVar) -> List[TypeVar]
"""
Verify that the use of TypeVars is consistent with `ctrl_typevar` as
the controlling type variable.
@@ -204,7 +210,7 @@ class Instruction(object):
Return list of other type variables used, or raise an error.
"""
other_tvs = []
other_tvs = [] # type: List[TypeVar]
# Check value inputs.
for opnum in self.value_opnums:
typ = self.ins[opnum].typevar
@@ -283,11 +289,12 @@ class Instruction(object):
return (self, ())
def __call__(self, *args):
# type: (*Expr) -> Apply
"""
Create an `ast.Apply` AST node representing the application of this
instruction to the arguments.
"""
from .ast import Apply
from .ast import Apply # noqa
return Apply(self, args)
@@ -303,6 +310,7 @@ class BoundInstruction(object):
assert len(typevars) <= 1 + len(inst.other_typevars)
def __str__(self):
# type: () -> str
return '.'.join([self.inst.name, ] + list(map(str, self.typevars)))
def bind(self, *args):
@@ -336,9 +344,10 @@ class BoundInstruction(object):
return (self.inst, self.typevars)
def __call__(self, *args):
# type: (*Expr) -> Apply
"""
Create an `ast.Apply` AST node representing the application of this
instruction to the arguments.
"""
from .ast import Apply
from .ast import Apply # noqa
return Apply(self, args)

View File

@@ -97,6 +97,7 @@ class TargetISA(object):
self.settings.number_predicate(enc.isap)
def _collect_regclasses(self):
# type: () -> None
"""
Collect and number register classes.
@@ -132,6 +133,7 @@ class CPUMode(object):
return self.name
def enc(self, *args, **kwargs):
# type: (*Any, **Any) -> None
"""
Add a new encoding to this CPU mode.
@@ -186,7 +188,7 @@ class EncRecipe(object):
return self.name
def _verify_constraints(self, seq):
# (ConstraintSeq) -> Sequence[OperandConstraint]
# type: (ConstraintSeq) -> Sequence[OperandConstraint]
if not isinstance(seq, tuple):
seq = (seq,)
for c in seq:
@@ -194,7 +196,7 @@ class EncRecipe(object):
# An integer constraint is bound to a value operand.
# Check that it is in range.
assert c >= 0
if not format.has_value_list:
if not self.format.has_value_list:
assert c < self.format.num_value_operands
else:
assert isinstance(c, RegClass) or isinstance(c, Register)

View File

@@ -37,6 +37,7 @@ except ImportError:
def _is_parent(a, b):
# type: (PredContext, PredContext) -> bool
"""
Return true if a is a parent of b, or equal to it.
"""
@@ -46,6 +47,7 @@ def _is_parent(a, b):
def _descendant(a, b):
# type: (PredContext, PredContext) -> PredContext
"""
If a is a parent of b or b is a parent of a, return the descendant of the
two.

View File

@@ -26,7 +26,8 @@ from . import is_power_of_two, next_power_of_two
try:
from typing import Sequence, Tuple, List, Dict # noqa
from typing import Sequence, Tuple, List, Dict, Any, TYPE_CHECKING # noqa
if TYPE_CHECKING:
from .isa import TargetISA # noqa
# A tuple uniquely identifying a register class inside a register bank.
# (count, width, start)
@@ -189,6 +190,7 @@ class RegClass(object):
bank.classes.append(self)
def __str__(self):
# type: () -> str
return self.name
def rctup(self):
@@ -223,6 +225,7 @@ class RegClass(object):
return (count, self.width, start)
def __getitem__(self, sliced):
# type: (slice) -> RegClass
"""
Create a sub-class of a register class using slice notation. The slice
indexes refer to allocations in the parent register class, not register
@@ -273,6 +276,7 @@ class RegClass(object):
@staticmethod
def extract_names(globs):
# type: (Dict[str, Any]) -> None
"""
Given a dict mapping name -> object as returned by `globals()`, find
all the RegClass objects and set their name from the dict key.

View File

@@ -21,11 +21,12 @@ class ValueType(object):
_registry = dict() # type: Dict[str, ValueType]
# List of all the scalar types.
all_scalars = list() # type: List[ValueType]
all_scalars = list() # type: List[ScalarType]
def __init__(self, name, membytes, doc):
# type: (str, int, str) -> None
self.name = name
self.number = None # type: int
self.membytes = membytes
self.__doc__ = doc
assert name not in ValueType._registry

View File

@@ -9,7 +9,9 @@ import math
from . import types, is_power_of_two
try:
from typing import Tuple, Union # noqa
from typing import Tuple, Union, TYPE_CHECKING # noqa
if TYPE_CHECKING:
from srcgen import Formatter # noqa
Interval = Tuple[int, int]
# An Interval where `True` means 'everything'
BoolInterval = Union[bool, Interval]
@@ -143,7 +145,11 @@ class TypeSet(object):
return h
def __eq__(self, other):
# type: (object) -> bool
if isinstance(other, TypeSet):
return self.typeset_key() == other.typeset_key()
else:
return False
def __repr__(self):
# type: () -> str
@@ -157,6 +163,7 @@ class TypeSet(object):
return s + ')'
def emit_fields(self, fmt):
# type: (Formatter) -> None
"""Emit field initializers for this typeset."""
fmt.comment(repr(self))
fields = ('lanes', 'int', 'float', 'bool')
@@ -299,6 +306,9 @@ class TypeVar(object):
.format(self.name, self.type_set))
def __eq__(self, other):
# type: (object) -> bool
if not isinstance(other, TypeVar):
return False
if self.is_derived and other.is_derived:
return (
self.derived_func == other.derived_func and

View File

@@ -40,10 +40,6 @@ class Rtl(object):
# type: (*DefApply) -> None
self.rtl = tuple(map(canonicalize_defapply, args))
def __iter__(self):
# type: () -> Iterator[Def]
return iter(self.rtl)
class XForm(object):
"""
@@ -105,10 +101,11 @@ class XForm(object):
self._collect_typevars()
def __repr__(self):
# type: () -> str
s = "XForm(inputs={}, defs={},\n ".format(self.inputs, self.defs)
s += '\n '.join(str(n) for n in self.src)
s += '\n '.join(str(n) for n in self.src.rtl)
s += '\n=>\n '
s += '\n '.join(str(n) for n in self.dst)
s += '\n '.join(str(n) for n in self.dst.rtl)
s += '\n)'
return s

View File

@@ -8,8 +8,14 @@ quadratically probed hash table.
from __future__ import absolute_import
from cdsl import next_power_of_two
try:
from typing import Any, List, Sequence, Callable # noqa
except ImportError:
pass
def simple_hash(s):
# type: (str) -> int
"""
Compute a primitive hash of a string.
@@ -26,6 +32,7 @@ def simple_hash(s):
def compute_quadratic(items, hash_function):
# type: (Sequence[Any], Callable[[Any], int]) -> List[Any]
"""
Compute an open addressed, quadratically probed hash table containing
`items`. The returned table is a list containing the elements of the
@@ -43,7 +50,7 @@ def compute_quadratic(items, hash_function):
items = list(items)
# Table size must be a power of two. Aim for >20% unused slots.
size = next_power_of_two(int(1.20*len(items)))
table = [None] * size
table = [None] * size # type: List[Any]
for i in items:
h = hash_function(i) % size

View File

@@ -16,8 +16,14 @@ from __future__ import absolute_import, print_function
import os
from os.path import dirname, abspath, join
try:
from typing import Iterable # noqa
except ImportError:
pass
def source_files(top):
# type: (str) -> Iterable[str]
"""
Recursively find all interesting source files and directories in the
directory tree starting at top. Yield a path to each file.
@@ -30,6 +36,7 @@ def source_files(top):
def generate():
# type: () -> None
print("Dependencies from meta language directory:")
meta = dirname(abspath(__file__))
for path in source_files(meta):

View File

@@ -14,6 +14,7 @@ from cdsl.ast import Var
try:
from typing import Sequence # noqa
from cdsl.isa import TargetISA # noqa
from cdsl.ast import Def # noqa
from cdsl.xform import XForm, XFormGroup # noqa
except ImportError:
@@ -240,6 +241,7 @@ def gen_xform_group(xgrp, fmt):
def generate(isas, out_dir):
# type: (Sequence[TargetISA], str) -> None
fmt = Formatter()
gen_xform_group(legalize.narrow, fmt)
gen_xform_group(legalize.expand, fmt)

View File

@@ -12,8 +12,14 @@ import srcgen
from cdsl.types import ValueType
import base.types # noqa
try:
from typing import Iterable # noqa
except ImportError:
pass
def emit_type(ty, fmt):
# type: (ValueType, srcgen.Formatter) -> None
"""
Emit a constant definition of a single value type.
"""
@@ -25,6 +31,7 @@ def emit_type(ty, fmt):
def emit_vectors(bits, fmt):
# type: (int, srcgen.Formatter) -> None
"""
Emit definition for all vector types with `bits` total size.
"""
@@ -37,6 +44,7 @@ def emit_vectors(bits, fmt):
def emit_types(fmt):
# type: (srcgen.Formatter) -> None
for ty in ValueType.all_scalars:
emit_type(ty, fmt)
# Emit vector definitions for common SIMD sizes.
@@ -47,6 +55,7 @@ def emit_types(fmt):
def generate(out_dir):
# type: (str) -> None
fmt = srcgen.Formatter()
emit_types(fmt)
fmt.update_file('types.rs', out_dir)

View File

@@ -101,6 +101,7 @@ class Formatter(object):
self.fmt.indent_push()
def __exit__(self, t, v, tb):
# type: (object, object, object) -> None
self.fmt.indent_pop()
if self.after:
self.fmt.line(self.after)
@@ -126,13 +127,16 @@ class Formatter(object):
return Formatter._IndentedScope(self, after)
def format(self, fmt, *args):
# type: (str, *Any) -> None
self.line(fmt.format(*args))
def comment(self, s):
# type: (str) -> None
"""Add a comment line."""
self.line('// ' + s)
def doc_comment(self, s):
# type: (str) -> None
"""Add a (multi-line) documentation comment."""
s = re.sub('^', self.indent + '/// ', s, flags=re.M) + '\n'
self.lines.append(s)

View File

@@ -7,18 +7,25 @@ item is mapped to its offset in the final array.
This is a compression technique for compile-time generated tables.
"""
try:
from typing import Any, List, Dict, Tuple, Sequence # noqa
except ImportError:
pass
class UniqueTable:
"""
Collect items into the `table` list, removing duplicates.
"""
def __init__(self):
# type: () -> None
# List of items added in order.
self.table = list()
self.table = list() # type: List[Any]
# Map item -> index.
self.index = dict()
self.index = dict() # type: Dict[Any, int]
def add(self, item):
# type: (Any) -> int
"""
Add a single item to the table if it isn't already there.
@@ -40,11 +47,13 @@ class UniqueSeqTable:
Sequences don't have to be of the same length.
"""
def __init__(self):
self.table = list()
# type: () -> None
self.table = list() # type: List[Any]
# Map seq -> index.
self.index = dict()
self.index = dict() # type: Dict[Tuple[Any, ...], int]
def add(self, seq):
# type: (Sequence[Any]) -> int
"""
Add a sequence of items to the table. If the table already contains the
items in `seq` in the same order, use those instead.