Add more mypy annotations.
This commit is contained in:
@@ -47,9 +47,11 @@ class Def(object):
|
|||||||
self.expr = expr
|
self.expr = expr
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
# type: () -> str
|
||||||
return "{} << {!r}".format(self.defs, self.expr)
|
return "{} << {!r}".format(self.defs, self.expr)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
if len(self.defs) == 1:
|
if len(self.defs) == 1:
|
||||||
return "{!s} << {!s}".format(self.defs[0], self.expr)
|
return "{!s} << {!s}".format(self.defs[0], self.expr)
|
||||||
else:
|
else:
|
||||||
@@ -379,15 +381,18 @@ class Apply(Expr):
|
|||||||
return Def(other, self)
|
return Def(other, self)
|
||||||
|
|
||||||
def instname(self):
|
def instname(self):
|
||||||
|
# type: () -> str
|
||||||
i = self.inst.name
|
i = self.inst.name
|
||||||
for t in self.typevars:
|
for t in self.typevars:
|
||||||
i += '.{}'.format(t)
|
i += '.{}'.format(t)
|
||||||
return i
|
return i
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
# type: () -> str
|
||||||
return "Apply({}, {})".format(self.instname(), self.args)
|
return "Apply({}, {})".format(self.instname(), self.args)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
args = ', '.join(map(str, self.args))
|
args = ', '.join(map(str, self.args))
|
||||||
return '{}({})'.format(self.instname(), args)
|
return '{}({})'.format(self.instname(), args)
|
||||||
|
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ class InstructionFormat(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_names(globs):
|
def extract_names(globs):
|
||||||
|
# type: (Dict[str, Any]) -> None
|
||||||
"""
|
"""
|
||||||
Given a dict mapping name -> object as returned by `globals()`, find
|
Given a dict mapping name -> object as returned by `globals()`, find
|
||||||
all the InstructionFormat objects and set their name from the dict key.
|
all the InstructionFormat objects and set their name from the dict key.
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ from .operands import Operand
|
|||||||
from .formats import InstructionFormat
|
from .formats import InstructionFormat
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Union, Sequence, List # noqa
|
from typing import Union, Sequence, List, Tuple, Any, TYPE_CHECKING # noqa
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .ast import Expr, Apply # noqa
|
||||||
|
from .typevar import TypeVar # noqa
|
||||||
# List of operands for ins/outs:
|
# List of operands for ins/outs:
|
||||||
OpList = Union[Sequence[Operand], Operand]
|
OpList = Union[Sequence[Operand], Operand]
|
||||||
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
|
||||||
from typing import Tuple, Any # noqa
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -122,6 +124,7 @@ class Instruction(object):
|
|||||||
InstructionGroup.append(self)
|
InstructionGroup.append(self)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
prefix = ', '.join(o.name for o in self.outs)
|
prefix = ', '.join(o.name for o in self.outs)
|
||||||
if prefix:
|
if prefix:
|
||||||
prefix = prefix + ' = '
|
prefix = prefix + ' = '
|
||||||
@@ -141,6 +144,7 @@ class Instruction(object):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def blurb(self):
|
def blurb(self):
|
||||||
|
# type: () -> str
|
||||||
"""Get the first line of the doc comment"""
|
"""Get the first line of the doc comment"""
|
||||||
for line in self.__doc__.split('\n'):
|
for line in self.__doc__.split('\n'):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
@@ -149,6 +153,7 @@ class Instruction(object):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _verify_polymorphic(self):
|
def _verify_polymorphic(self):
|
||||||
|
# type: () -> None
|
||||||
"""
|
"""
|
||||||
Check if this instruction is polymorphic, and verify its use of type
|
Check if this instruction is polymorphic, and verify its use of type
|
||||||
variables.
|
variables.
|
||||||
@@ -193,6 +198,7 @@ class Instruction(object):
|
|||||||
self.ctrl_typevar = tv
|
self.ctrl_typevar = tv
|
||||||
|
|
||||||
def _verify_ctrl_typevar(self, ctrl_typevar):
|
def _verify_ctrl_typevar(self, ctrl_typevar):
|
||||||
|
# type: (TypeVar) -> List[TypeVar]
|
||||||
"""
|
"""
|
||||||
Verify that the use of TypeVars is consistent with `ctrl_typevar` as
|
Verify that the use of TypeVars is consistent with `ctrl_typevar` as
|
||||||
the controlling type variable.
|
the controlling type variable.
|
||||||
@@ -204,7 +210,7 @@ class Instruction(object):
|
|||||||
|
|
||||||
Return list of other type variables used, or raise an error.
|
Return list of other type variables used, or raise an error.
|
||||||
"""
|
"""
|
||||||
other_tvs = []
|
other_tvs = [] # type: List[TypeVar]
|
||||||
# Check value inputs.
|
# Check value inputs.
|
||||||
for opnum in self.value_opnums:
|
for opnum in self.value_opnums:
|
||||||
typ = self.ins[opnum].typevar
|
typ = self.ins[opnum].typevar
|
||||||
@@ -283,11 +289,12 @@ class Instruction(object):
|
|||||||
return (self, ())
|
return (self, ())
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
|
# type: (*Expr) -> Apply
|
||||||
"""
|
"""
|
||||||
Create an `ast.Apply` AST node representing the application of this
|
Create an `ast.Apply` AST node representing the application of this
|
||||||
instruction to the arguments.
|
instruction to the arguments.
|
||||||
"""
|
"""
|
||||||
from .ast import Apply
|
from .ast import Apply # noqa
|
||||||
return Apply(self, args)
|
return Apply(self, args)
|
||||||
|
|
||||||
|
|
||||||
@@ -303,6 +310,7 @@ class BoundInstruction(object):
|
|||||||
assert len(typevars) <= 1 + len(inst.other_typevars)
|
assert len(typevars) <= 1 + len(inst.other_typevars)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
return '.'.join([self.inst.name, ] + list(map(str, self.typevars)))
|
return '.'.join([self.inst.name, ] + list(map(str, self.typevars)))
|
||||||
|
|
||||||
def bind(self, *args):
|
def bind(self, *args):
|
||||||
@@ -336,9 +344,10 @@ class BoundInstruction(object):
|
|||||||
return (self.inst, self.typevars)
|
return (self.inst, self.typevars)
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
|
# type: (*Expr) -> Apply
|
||||||
"""
|
"""
|
||||||
Create an `ast.Apply` AST node representing the application of this
|
Create an `ast.Apply` AST node representing the application of this
|
||||||
instruction to the arguments.
|
instruction to the arguments.
|
||||||
"""
|
"""
|
||||||
from .ast import Apply
|
from .ast import Apply # noqa
|
||||||
return Apply(self, args)
|
return Apply(self, args)
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class TargetISA(object):
|
|||||||
self.settings.number_predicate(enc.isap)
|
self.settings.number_predicate(enc.isap)
|
||||||
|
|
||||||
def _collect_regclasses(self):
|
def _collect_regclasses(self):
|
||||||
|
# type: () -> None
|
||||||
"""
|
"""
|
||||||
Collect and number register classes.
|
Collect and number register classes.
|
||||||
|
|
||||||
@@ -132,6 +133,7 @@ class CPUMode(object):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def enc(self, *args, **kwargs):
|
def enc(self, *args, **kwargs):
|
||||||
|
# type: (*Any, **Any) -> None
|
||||||
"""
|
"""
|
||||||
Add a new encoding to this CPU mode.
|
Add a new encoding to this CPU mode.
|
||||||
|
|
||||||
@@ -186,7 +188,7 @@ class EncRecipe(object):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def _verify_constraints(self, seq):
|
def _verify_constraints(self, seq):
|
||||||
# (ConstraintSeq) -> Sequence[OperandConstraint]
|
# type: (ConstraintSeq) -> Sequence[OperandConstraint]
|
||||||
if not isinstance(seq, tuple):
|
if not isinstance(seq, tuple):
|
||||||
seq = (seq,)
|
seq = (seq,)
|
||||||
for c in seq:
|
for c in seq:
|
||||||
@@ -194,7 +196,7 @@ class EncRecipe(object):
|
|||||||
# An integer constraint is bound to a value operand.
|
# An integer constraint is bound to a value operand.
|
||||||
# Check that it is in range.
|
# Check that it is in range.
|
||||||
assert c >= 0
|
assert c >= 0
|
||||||
if not format.has_value_list:
|
if not self.format.has_value_list:
|
||||||
assert c < self.format.num_value_operands
|
assert c < self.format.num_value_operands
|
||||||
else:
|
else:
|
||||||
assert isinstance(c, RegClass) or isinstance(c, Register)
|
assert isinstance(c, RegClass) or isinstance(c, Register)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def _is_parent(a, b):
|
def _is_parent(a, b):
|
||||||
|
# type: (PredContext, PredContext) -> bool
|
||||||
"""
|
"""
|
||||||
Return true if a is a parent of b, or equal to it.
|
Return true if a is a parent of b, or equal to it.
|
||||||
"""
|
"""
|
||||||
@@ -46,6 +47,7 @@ def _is_parent(a, b):
|
|||||||
|
|
||||||
|
|
||||||
def _descendant(a, b):
|
def _descendant(a, b):
|
||||||
|
# type: (PredContext, PredContext) -> PredContext
|
||||||
"""
|
"""
|
||||||
If a is a parent of b or b is a parent of a, return the descendant of the
|
If a is a parent of b or b is a parent of a, return the descendant of the
|
||||||
two.
|
two.
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ from . import is_power_of_two, next_power_of_two
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Sequence, Tuple, List, Dict # noqa
|
from typing import Sequence, Tuple, List, Dict, Any, TYPE_CHECKING # noqa
|
||||||
|
if TYPE_CHECKING:
|
||||||
from .isa import TargetISA # noqa
|
from .isa import TargetISA # noqa
|
||||||
# A tuple uniquely identifying a register class inside a register bank.
|
# A tuple uniquely identifying a register class inside a register bank.
|
||||||
# (count, width, start)
|
# (count, width, start)
|
||||||
@@ -189,6 +190,7 @@ class RegClass(object):
|
|||||||
bank.classes.append(self)
|
bank.classes.append(self)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
# type: () -> str
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def rctup(self):
|
def rctup(self):
|
||||||
@@ -223,6 +225,7 @@ class RegClass(object):
|
|||||||
return (count, self.width, start)
|
return (count, self.width, start)
|
||||||
|
|
||||||
def __getitem__(self, sliced):
|
def __getitem__(self, sliced):
|
||||||
|
# type: (slice) -> RegClass
|
||||||
"""
|
"""
|
||||||
Create a sub-class of a register class using slice notation. The slice
|
Create a sub-class of a register class using slice notation. The slice
|
||||||
indexes refer to allocations in the parent register class, not register
|
indexes refer to allocations in the parent register class, not register
|
||||||
@@ -273,6 +276,7 @@ class RegClass(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_names(globs):
|
def extract_names(globs):
|
||||||
|
# type: (Dict[str, Any]) -> None
|
||||||
"""
|
"""
|
||||||
Given a dict mapping name -> object as returned by `globals()`, find
|
Given a dict mapping name -> object as returned by `globals()`, find
|
||||||
all the RegClass objects and set their name from the dict key.
|
all the RegClass objects and set their name from the dict key.
|
||||||
|
|||||||
@@ -21,11 +21,12 @@ class ValueType(object):
|
|||||||
_registry = dict() # type: Dict[str, ValueType]
|
_registry = dict() # type: Dict[str, ValueType]
|
||||||
|
|
||||||
# List of all the scalar types.
|
# List of all the scalar types.
|
||||||
all_scalars = list() # type: List[ValueType]
|
all_scalars = list() # type: List[ScalarType]
|
||||||
|
|
||||||
def __init__(self, name, membytes, doc):
|
def __init__(self, name, membytes, doc):
|
||||||
# type: (str, int, str) -> None
|
# type: (str, int, str) -> None
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.number = None # type: int
|
||||||
self.membytes = membytes
|
self.membytes = membytes
|
||||||
self.__doc__ = doc
|
self.__doc__ = doc
|
||||||
assert name not in ValueType._registry
|
assert name not in ValueType._registry
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ import math
|
|||||||
from . import types, is_power_of_two
|
from . import types, is_power_of_two
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Tuple, Union # noqa
|
from typing import Tuple, Union, TYPE_CHECKING # noqa
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from srcgen import Formatter # noqa
|
||||||
Interval = Tuple[int, int]
|
Interval = Tuple[int, int]
|
||||||
# An Interval where `True` means 'everything'
|
# An Interval where `True` means 'everything'
|
||||||
BoolInterval = Union[bool, Interval]
|
BoolInterval = Union[bool, Interval]
|
||||||
@@ -143,7 +145,11 @@ class TypeSet(object):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
# type: (object) -> bool
|
||||||
|
if isinstance(other, TypeSet):
|
||||||
return self.typeset_key() == other.typeset_key()
|
return self.typeset_key() == other.typeset_key()
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# type: () -> str
|
# type: () -> str
|
||||||
@@ -157,6 +163,7 @@ class TypeSet(object):
|
|||||||
return s + ')'
|
return s + ')'
|
||||||
|
|
||||||
def emit_fields(self, fmt):
|
def emit_fields(self, fmt):
|
||||||
|
# type: (Formatter) -> None
|
||||||
"""Emit field initializers for this typeset."""
|
"""Emit field initializers for this typeset."""
|
||||||
fmt.comment(repr(self))
|
fmt.comment(repr(self))
|
||||||
fields = ('lanes', 'int', 'float', 'bool')
|
fields = ('lanes', 'int', 'float', 'bool')
|
||||||
@@ -299,6 +306,9 @@ class TypeVar(object):
|
|||||||
.format(self.name, self.type_set))
|
.format(self.name, self.type_set))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
# type: (object) -> bool
|
||||||
|
if not isinstance(other, TypeVar):
|
||||||
|
return False
|
||||||
if self.is_derived and other.is_derived:
|
if self.is_derived and other.is_derived:
|
||||||
return (
|
return (
|
||||||
self.derived_func == other.derived_func and
|
self.derived_func == other.derived_func and
|
||||||
|
|||||||
@@ -40,10 +40,6 @@ class Rtl(object):
|
|||||||
# type: (*DefApply) -> None
|
# type: (*DefApply) -> None
|
||||||
self.rtl = tuple(map(canonicalize_defapply, args))
|
self.rtl = tuple(map(canonicalize_defapply, args))
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
# type: () -> Iterator[Def]
|
|
||||||
return iter(self.rtl)
|
|
||||||
|
|
||||||
|
|
||||||
class XForm(object):
|
class XForm(object):
|
||||||
"""
|
"""
|
||||||
@@ -105,10 +101,11 @@ class XForm(object):
|
|||||||
self._collect_typevars()
|
self._collect_typevars()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
# type: () -> str
|
||||||
s = "XForm(inputs={}, defs={},\n ".format(self.inputs, self.defs)
|
s = "XForm(inputs={}, defs={},\n ".format(self.inputs, self.defs)
|
||||||
s += '\n '.join(str(n) for n in self.src)
|
s += '\n '.join(str(n) for n in self.src.rtl)
|
||||||
s += '\n=>\n '
|
s += '\n=>\n '
|
||||||
s += '\n '.join(str(n) for n in self.dst)
|
s += '\n '.join(str(n) for n in self.dst.rtl)
|
||||||
s += '\n)'
|
s += '\n)'
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,14 @@ quadratically probed hash table.
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from cdsl import next_power_of_two
|
from cdsl import next_power_of_two
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Any, List, Sequence, Callable # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def simple_hash(s):
|
def simple_hash(s):
|
||||||
|
# type: (str) -> int
|
||||||
"""
|
"""
|
||||||
Compute a primitive hash of a string.
|
Compute a primitive hash of a string.
|
||||||
|
|
||||||
@@ -26,6 +32,7 @@ def simple_hash(s):
|
|||||||
|
|
||||||
|
|
||||||
def compute_quadratic(items, hash_function):
|
def compute_quadratic(items, hash_function):
|
||||||
|
# type: (Sequence[Any], Callable[[Any], int]) -> List[Any]
|
||||||
"""
|
"""
|
||||||
Compute an open addressed, quadratically probed hash table containing
|
Compute an open addressed, quadratically probed hash table containing
|
||||||
`items`. The returned table is a list containing the elements of the
|
`items`. The returned table is a list containing the elements of the
|
||||||
@@ -43,7 +50,7 @@ def compute_quadratic(items, hash_function):
|
|||||||
items = list(items)
|
items = list(items)
|
||||||
# Table size must be a power of two. Aim for >20% unused slots.
|
# Table size must be a power of two. Aim for >20% unused slots.
|
||||||
size = next_power_of_two(int(1.20*len(items)))
|
size = next_power_of_two(int(1.20*len(items)))
|
||||||
table = [None] * size
|
table = [None] * size # type: List[Any]
|
||||||
|
|
||||||
for i in items:
|
for i in items:
|
||||||
h = hash_function(i) % size
|
h = hash_function(i) % size
|
||||||
|
|||||||
@@ -16,8 +16,14 @@ from __future__ import absolute_import, print_function
|
|||||||
import os
|
import os
|
||||||
from os.path import dirname, abspath, join
|
from os.path import dirname, abspath, join
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Iterable # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def source_files(top):
|
def source_files(top):
|
||||||
|
# type: (str) -> Iterable[str]
|
||||||
"""
|
"""
|
||||||
Recursively find all interesting source files and directories in the
|
Recursively find all interesting source files and directories in the
|
||||||
directory tree starting at top. Yield a path to each file.
|
directory tree starting at top. Yield a path to each file.
|
||||||
@@ -30,6 +36,7 @@ def source_files(top):
|
|||||||
|
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
|
# type: () -> None
|
||||||
print("Dependencies from meta language directory:")
|
print("Dependencies from meta language directory:")
|
||||||
meta = dirname(abspath(__file__))
|
meta = dirname(abspath(__file__))
|
||||||
for path in source_files(meta):
|
for path in source_files(meta):
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from cdsl.ast import Var
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Sequence # noqa
|
from typing import Sequence # noqa
|
||||||
|
from cdsl.isa import TargetISA # noqa
|
||||||
from cdsl.ast import Def # noqa
|
from cdsl.ast import Def # noqa
|
||||||
from cdsl.xform import XForm, XFormGroup # noqa
|
from cdsl.xform import XForm, XFormGroup # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -240,6 +241,7 @@ def gen_xform_group(xgrp, fmt):
|
|||||||
|
|
||||||
|
|
||||||
def generate(isas, out_dir):
|
def generate(isas, out_dir):
|
||||||
|
# type: (Sequence[TargetISA], str) -> None
|
||||||
fmt = Formatter()
|
fmt = Formatter()
|
||||||
gen_xform_group(legalize.narrow, fmt)
|
gen_xform_group(legalize.narrow, fmt)
|
||||||
gen_xform_group(legalize.expand, fmt)
|
gen_xform_group(legalize.expand, fmt)
|
||||||
|
|||||||
@@ -12,8 +12,14 @@ import srcgen
|
|||||||
from cdsl.types import ValueType
|
from cdsl.types import ValueType
|
||||||
import base.types # noqa
|
import base.types # noqa
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Iterable # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def emit_type(ty, fmt):
|
def emit_type(ty, fmt):
|
||||||
|
# type: (ValueType, srcgen.Formatter) -> None
|
||||||
"""
|
"""
|
||||||
Emit a constant definition of a single value type.
|
Emit a constant definition of a single value type.
|
||||||
"""
|
"""
|
||||||
@@ -25,6 +31,7 @@ def emit_type(ty, fmt):
|
|||||||
|
|
||||||
|
|
||||||
def emit_vectors(bits, fmt):
|
def emit_vectors(bits, fmt):
|
||||||
|
# type: (int, srcgen.Formatter) -> None
|
||||||
"""
|
"""
|
||||||
Emit definition for all vector types with `bits` total size.
|
Emit definition for all vector types with `bits` total size.
|
||||||
"""
|
"""
|
||||||
@@ -37,6 +44,7 @@ def emit_vectors(bits, fmt):
|
|||||||
|
|
||||||
|
|
||||||
def emit_types(fmt):
|
def emit_types(fmt):
|
||||||
|
# type: (srcgen.Formatter) -> None
|
||||||
for ty in ValueType.all_scalars:
|
for ty in ValueType.all_scalars:
|
||||||
emit_type(ty, fmt)
|
emit_type(ty, fmt)
|
||||||
# Emit vector definitions for common SIMD sizes.
|
# Emit vector definitions for common SIMD sizes.
|
||||||
@@ -47,6 +55,7 @@ def emit_types(fmt):
|
|||||||
|
|
||||||
|
|
||||||
def generate(out_dir):
|
def generate(out_dir):
|
||||||
|
# type: (str) -> None
|
||||||
fmt = srcgen.Formatter()
|
fmt = srcgen.Formatter()
|
||||||
emit_types(fmt)
|
emit_types(fmt)
|
||||||
fmt.update_file('types.rs', out_dir)
|
fmt.update_file('types.rs', out_dir)
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class Formatter(object):
|
|||||||
self.fmt.indent_push()
|
self.fmt.indent_push()
|
||||||
|
|
||||||
def __exit__(self, t, v, tb):
|
def __exit__(self, t, v, tb):
|
||||||
|
# type: (object, object, object) -> None
|
||||||
self.fmt.indent_pop()
|
self.fmt.indent_pop()
|
||||||
if self.after:
|
if self.after:
|
||||||
self.fmt.line(self.after)
|
self.fmt.line(self.after)
|
||||||
@@ -126,13 +127,16 @@ class Formatter(object):
|
|||||||
return Formatter._IndentedScope(self, after)
|
return Formatter._IndentedScope(self, after)
|
||||||
|
|
||||||
def format(self, fmt, *args):
|
def format(self, fmt, *args):
|
||||||
|
# type: (str, *Any) -> None
|
||||||
self.line(fmt.format(*args))
|
self.line(fmt.format(*args))
|
||||||
|
|
||||||
def comment(self, s):
|
def comment(self, s):
|
||||||
|
# type: (str) -> None
|
||||||
"""Add a comment line."""
|
"""Add a comment line."""
|
||||||
self.line('// ' + s)
|
self.line('// ' + s)
|
||||||
|
|
||||||
def doc_comment(self, s):
|
def doc_comment(self, s):
|
||||||
|
# type: (str) -> None
|
||||||
"""Add a (multi-line) documentation comment."""
|
"""Add a (multi-line) documentation comment."""
|
||||||
s = re.sub('^', self.indent + '/// ', s, flags=re.M) + '\n'
|
s = re.sub('^', self.indent + '/// ', s, flags=re.M) + '\n'
|
||||||
self.lines.append(s)
|
self.lines.append(s)
|
||||||
|
|||||||
@@ -7,18 +7,25 @@ item is mapped to its offset in the final array.
|
|||||||
This is a compression technique for compile-time generated tables.
|
This is a compression technique for compile-time generated tables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import Any, List, Dict, Tuple, Sequence # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UniqueTable:
|
class UniqueTable:
|
||||||
"""
|
"""
|
||||||
Collect items into the `table` list, removing duplicates.
|
Collect items into the `table` list, removing duplicates.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# type: () -> None
|
||||||
# List of items added in order.
|
# List of items added in order.
|
||||||
self.table = list()
|
self.table = list() # type: List[Any]
|
||||||
# Map item -> index.
|
# Map item -> index.
|
||||||
self.index = dict()
|
self.index = dict() # type: Dict[Any, int]
|
||||||
|
|
||||||
def add(self, item):
|
def add(self, item):
|
||||||
|
# type: (Any) -> int
|
||||||
"""
|
"""
|
||||||
Add a single item to the table if it isn't already there.
|
Add a single item to the table if it isn't already there.
|
||||||
|
|
||||||
@@ -40,11 +47,13 @@ class UniqueSeqTable:
|
|||||||
Sequences don't have to be of the same length.
|
Sequences don't have to be of the same length.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.table = list()
|
# type: () -> None
|
||||||
|
self.table = list() # type: List[Any]
|
||||||
# Map seq -> index.
|
# Map seq -> index.
|
||||||
self.index = dict()
|
self.index = dict() # type: Dict[Tuple[Any, ...], int]
|
||||||
|
|
||||||
def add(self, seq):
|
def add(self, seq):
|
||||||
|
# type: (Sequence[Any]) -> int
|
||||||
"""
|
"""
|
||||||
Add a sequence of items to the table. If the table already contains the
|
Add a sequence of items to the table. If the table already contains the
|
||||||
items in `seq` in the same order, use those instead.
|
items in `seq` in the same order, use those instead.
|
||||||
|
|||||||
Reference in New Issue
Block a user