Add PEP 484 type annotations to a bunch of Python code.

Along with the mypy tool, this helps find bugs in the Python code
handling the instruction definition data structures.
This commit is contained in:
Jakob Stoklund Olesen
2016-10-26 12:19:55 -07:00
parent b6ff2621f9
commit 6748817985
8 changed files with 140 additions and 42 deletions

View File

@@ -15,7 +15,7 @@ parser = argparse.ArgumentParser(description='Generate sources for Cretonne.')
parser.add_argument('--out-dir', help='set output directory') parser.add_argument('--out-dir', help='set output directory')
args = parser.parse_args() args = parser.parse_args()
out_dir = args.out_dir out_dir = args.out_dir # type: ignore
isas = isa.all_isas() isas = isa.all_isas()

View File

@@ -9,14 +9,23 @@ import re
import math import math
import importlib import importlib
from collections import OrderedDict from collections import OrderedDict
from .predicates import And from .predicates import And, Predicate, FieldPredicate # noqa
from .ast import Apply
# The typing module is only required by mypy, and we don't use these imports
# outside type comments.
try:
from typing import Tuple, Union, Any, Iterable, Sequence # noqa
MaybeBoundInst = Union['Instruction', 'BoundInstruction']
AnyPredicate = Union['Predicate', 'FieldPredicate']
except ImportError:
pass
camel_re = re.compile('(^|_)([a-z])') camel_re = re.compile('(^|_)([a-z])')
def camel_case(s): def camel_case(s):
# type: (str) -> str
"""Convert the string s to CamelCase""" """Convert the string s to CamelCase"""
return camel_re.sub(lambda m: m.group(2).upper(), s) return camel_re.sub(lambda m: m.group(2).upper(), s)
@@ -133,7 +142,7 @@ class SettingGroup(object):
""" """
# The currently open setting group. # The currently open setting group.
_current = None _current = None # type: SettingGroup
def __init__(self, name, parent=None): def __init__(self, name, parent=None):
self.name = name self.name = name
@@ -175,7 +184,6 @@ class SettingGroup(object):
.format(self, SettingGroup._current)) .format(self, SettingGroup._current))
SettingGroup._current = None SettingGroup._current = None
if globs: if globs:
from .predicates import Predicate
for name, obj in globs.iteritems(): for name, obj in globs.iteritems():
if isinstance(obj, Setting): if isinstance(obj, Setting):
assert obj.name is None, obj.name assert obj.name is None, obj.name
@@ -381,10 +389,10 @@ class ValueType(object):
""" """
# Map name -> ValueType. # Map name -> ValueType.
_registry = dict() _registry = dict() # type: Dict[str, ValueType]
# List of all the scalar types. # List of all the scalar types.
all_scalars = list() all_scalars = list() # type: List[ValueType]
def __init__(self, name, membytes, doc): def __init__(self, name, membytes, doc):
self.name = name self.name = name
@@ -534,7 +542,7 @@ class InstructionGroup(object):
""" """
# The currently open instruction group. # The currently open instruction group.
_current = None _current = None # type: InstructionGroup
def open(self): def open(self):
""" """
@@ -644,15 +652,17 @@ class InstructionFormat(object):
""" """
# Map (multiple_results, kind, kind, ...) -> InstructionFormat # Map (multiple_results, kind, kind, ...) -> InstructionFormat
_registry = dict() _registry = dict() # type: Dict[Tuple, InstructionFormat]
# All existing formats. # All existing formats.
all_formats = list() all_formats = list() # type: List[InstructionFormat]
def __init__(self, *kinds, **kwargs): def __init__(self, *kinds, **kwargs):
self.name = kwargs.get('name', None) # type: (*Union[OperandKind, Tuple[str, OperandKind]], **Any) -> None # noqa
self.name = kwargs.get('name', None) # type: str
self.multiple_results = kwargs.get('multiple_results', False) self.multiple_results = kwargs.get('multiple_results', False)
self.boxed_storage = kwargs.get('boxed_storage', False) self.boxed_storage = kwargs.get('boxed_storage', False)
self.members = list() # type: List[str]
self.kinds = tuple(self._process_member_names(kinds)) self.kinds = tuple(self._process_member_names(kinds))
# Which of self.kinds are `value`? # Which of self.kinds are `value`?
@@ -660,7 +670,7 @@ class InstructionFormat(object):
i for i, k in enumerate(self.kinds) if k is value) i for i, k in enumerate(self.kinds) if k is value)
# The typevar_operand argument must point to a 'value' operand. # The typevar_operand argument must point to a 'value' operand.
self.typevar_operand = kwargs.get('typevar_operand', None) self.typevar_operand = kwargs.get('typevar_operand', None) # type: int
if self.typevar_operand is not None: if self.typevar_operand is not None:
assert self.kinds[self.typevar_operand] is value, \ assert self.kinds[self.typevar_operand] is value, \
"typevar_operand must indicate a 'value' operand" "typevar_operand must indicate a 'value' operand"
@@ -678,6 +688,7 @@ class InstructionFormat(object):
InstructionFormat.all_formats.append(self) InstructionFormat.all_formats.append(self)
def _process_member_names(self, kinds): def _process_member_names(self, kinds):
# type: (Sequence[Union[OperandKind, Tuple[str, OperandKind]]]) -> Iterable[OperandKind] # noqa
""" """
Extract names of all the immediate operands in the kinds tuple. Extract names of all the immediate operands in the kinds tuple.
@@ -687,14 +698,14 @@ class InstructionFormat(object):
Yields the operand kinds. Yields the operand kinds.
""" """
self.members = list() for arg in kinds:
for i, k in enumerate(kinds): if isinstance(arg, OperandKind):
if isinstance(k, tuple): member = arg.default_member
member, k = k k = arg
else: else:
member = k.default_member member, k = arg
yield k
self.members.append(member) self.members.append(member)
yield k
# Create `FormatField` instances for the immediates. # Create `FormatField` instances for the immediates.
if isinstance(k, ImmediateKind): if isinstance(k, ImmediateKind):
@@ -704,6 +715,7 @@ class InstructionFormat(object):
@staticmethod @staticmethod
def lookup(ins, outs): def lookup(ins, outs):
# type: (Sequence[Operand], Sequence[Operand]) -> InstructionFormat
""" """
Find an existing instruction format that matches the given lists of Find an existing instruction format that matches the given lists of
instruction inputs and outputs. instruction inputs and outputs.
@@ -750,6 +762,7 @@ class FormatField(object):
""" """
def __init__(self, format, operand, name): def __init__(self, format, operand, name):
# type: (InstructionFormat, int, str) -> None
self.format = format self.format = format
self.operand = operand self.operand = operand
self.name = name self.name = name
@@ -758,6 +771,7 @@ class FormatField(object):
return '{}.{}'.format(self.format.name, self.name) return '{}.{}'.format(self.format.name, self.name)
def rust_name(self): def rust_name(self):
# type: () -> str
if self.format.boxed_storage: if self.format.boxed_storage:
return 'data.' + self.name return 'data.' + self.name
else: else:
@@ -782,6 +796,7 @@ class Instruction(object):
""" """
def __init__(self, name, doc, ins=(), outs=(), **kwargs): def __init__(self, name, doc, ins=(), outs=(), **kwargs):
# type: (str, str, Union[Sequence[Operand], Operand], Union[Sequence[Operand], Operand], **Any) -> None # noqa
self.name = name self.name = name
self.camel_name = camel_case(name) self.camel_name = camel_case(name)
self.__doc__ = doc self.__doc__ = doc
@@ -898,6 +913,7 @@ class Instruction(object):
@staticmethod @staticmethod
def _to_operand_tuple(x): def _to_operand_tuple(x):
# type: (Union[Sequence[Operand], Operand]) -> Tuple[Operand, ...]
# Allow a single Operand instance instead of the awkward singleton # Allow a single Operand instance instead of the awkward singleton
# tuple syntax. # tuple syntax.
if isinstance(x, Operand): if isinstance(x, Operand):
@@ -909,6 +925,7 @@ class Instruction(object):
return x return x
def bind(self, *args): def bind(self, *args):
# type: (*ValueType) -> BoundInstruction
""" """
Bind a polymorphic instruction to a concrete list of type variable Bind a polymorphic instruction to a concrete list of type variable
values. values.
@@ -917,6 +934,7 @@ class Instruction(object):
return BoundInstruction(self, args) return BoundInstruction(self, args)
def __getattr__(self, name): def __getattr__(self, name):
# type: (str) -> BoundInstruction
""" """
Bind a polymorphic instruction to a single type variable with dot Bind a polymorphic instruction to a single type variable with dot
syntax: syntax:
@@ -926,6 +944,7 @@ class Instruction(object):
return self.bind(ValueType.by_name(name)) return self.bind(ValueType.by_name(name))
def fully_bound(self): def fully_bound(self):
# type: () -> Tuple[Instruction, Tuple[ValueType, ...]]
""" """
Verify that all typevars have been bound, and return a Verify that all typevars have been bound, and return a
`(inst, typevars)` pair. `(inst, typevars)` pair.
@@ -941,6 +960,7 @@ class Instruction(object):
Create an `ast.Apply` AST node representing the application of this Create an `ast.Apply` AST node representing the application of this
instruction to the arguments. instruction to the arguments.
""" """
from .ast import Apply
return Apply(self, args) return Apply(self, args)
@@ -950,6 +970,7 @@ class BoundInstruction(object):
""" """
def __init__(self, inst, typevars): def __init__(self, inst, typevars):
# type: (Instruction, Tuple[ValueType, ...]) -> None
self.inst = inst self.inst = inst
self.typevars = typevars self.typevars = typevars
assert len(typevars) <= 1 + len(inst.other_typevars) assert len(typevars) <= 1 + len(inst.other_typevars)
@@ -958,12 +979,14 @@ class BoundInstruction(object):
return '.'.join([self.inst.name, ] + list(map(str, self.typevars))) return '.'.join([self.inst.name, ] + list(map(str, self.typevars)))
def bind(self, *args): def bind(self, *args):
# type: (*ValueType) -> BoundInstruction
""" """
Bind additional typevars. Bind additional typevars.
""" """
return BoundInstruction(self.inst, self.typevars + args) return BoundInstruction(self.inst, self.typevars + args)
def __getattr__(self, name): def __getattr__(self, name):
# type: (str) -> BoundInstruction
""" """
Bind an additional typevar dot syntax: Bind an additional typevar dot syntax:
@@ -972,6 +995,7 @@ class BoundInstruction(object):
return self.bind(ValueType.by_name(name)) return self.bind(ValueType.by_name(name))
def fully_bound(self): def fully_bound(self):
# type: () -> Tuple[Instruction, Tuple[ValueType, ...]]
""" """
Verify that all typevars have been bound, and return a Verify that all typevars have been bound, and return a
`(inst, typevars)` pair. `(inst, typevars)` pair.
@@ -989,6 +1013,7 @@ class BoundInstruction(object):
Create an `ast.Apply` AST node representing the application of this Create an `ast.Apply` AST node representing the application of this
instruction to the arguments. instruction to the arguments.
""" """
from .ast import Apply
return Apply(self, args) return Apply(self, args)
@@ -1139,6 +1164,7 @@ class Encoding(object):
""" """
def __init__(self, cpumode, inst, recipe, encbits, instp=None, isap=None): def __init__(self, cpumode, inst, recipe, encbits, instp=None, isap=None):
# type: (CPUMode, MaybeBoundInst, EncRecipe, int, AnyPredicate, AnyPredicate) -> None # noqa
assert isinstance(cpumode, CPUMode) assert isinstance(cpumode, CPUMode)
assert isinstance(recipe, EncRecipe) assert isinstance(recipe, EncRecipe)
self.inst, self.typevars = inst.fully_bound() self.inst, self.typevars = inst.fully_bound()

View File

@@ -5,6 +5,12 @@ This module defines classes that can be used to create abstract syntax trees
for patern matching an rewriting of cretonne instructions. for patern matching an rewriting of cretonne instructions.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from . import Instruction, BoundInstruction
try:
from typing import Union, Tuple # noqa
except ImportError:
pass
class Def(object): class Def(object):
@@ -29,10 +35,12 @@ class Def(object):
""" """
def __init__(self, defs, expr): def __init__(self, defs, expr):
# type: (Union[Var, Tuple[Var, ...]], Apply) -> None
if not isinstance(defs, tuple): if not isinstance(defs, tuple):
defs = (defs,) self.defs = (defs,) # type: Tuple[Var, ...]
assert isinstance(expr, Expr) else:
self.defs = defs self.defs = defs
assert isinstance(expr, Apply)
self.expr = expr self.expr = expr
def __repr__(self): def __repr__(self):
@@ -42,7 +50,18 @@ class Def(object):
if len(self.defs) == 1: if len(self.defs) == 1:
return "{!s} << {!s}".format(self.defs[0], self.expr) return "{!s} << {!s}".format(self.defs[0], self.expr)
else: else:
return "({}) << {!s}".format(", ".join(self.defs), self.expr) return "({}) << {!s}".format(
', '.join(map(str, self.defs)), self.expr)
def root_inst(self):
# type: () -> Instruction
"""Get the instruction at the root of this tree."""
return self.expr.root_inst()
def defs_expr(self):
# type: () -> Tuple[Tuple[Var, ...], Apply]
"""Split into a defs tuple and an Apply expr."""
return (self.defs, self.expr)
class Expr(object): class Expr(object):
@@ -50,12 +69,6 @@ class Expr(object):
An AST expression. An AST expression.
""" """
def __rlshift__(self, other):
"""
Define variables using `var << expr` or `(v1, v2) << expr`.
"""
return Def(other, self)
class Var(Expr): class Var(Expr):
""" """
@@ -63,6 +76,7 @@ class Var(Expr):
""" """
def __init__(self, name): def __init__(self, name):
# type: (str) -> None
self.name = name self.name = name
# Bitmask of contexts where this variable is defined. # Bitmask of contexts where this variable is defined.
# See XForm._rewrite_defs(). # See XForm._rewrite_defs().
@@ -98,16 +112,24 @@ class Apply(Expr):
""" """
def __init__(self, inst, args): def __init__(self, inst, args):
from . import BoundInstruction # type: (Union[Instruction, BoundInstruction], Tuple[Expr, ...]) -> None # noqa
if isinstance(inst, BoundInstruction): if isinstance(inst, BoundInstruction):
self.inst = inst.inst self.inst = inst.inst
self.typevars = inst.typevars self.typevars = inst.typevars
else: else:
assert isinstance(inst, Instruction)
self.inst = inst self.inst = inst
self.typevars = () self.typevars = ()
self.args = args self.args = args
assert len(self.inst.ins) == len(args) assert len(self.inst.ins) == len(args)
def __rlshift__(self, other):
# type: (Union[Var, Tuple[Var, ...]]) -> Def
"""
Define variables using `var << expr` or `(v1, v2) << expr`.
"""
return Def(other, self)
def instname(self): def instname(self):
i = self.inst.name i = self.inst.name
for t in self.typevars: for t in self.typevars:
@@ -120,3 +142,13 @@ class Apply(Expr):
def __str__(self): def __str__(self):
args = ', '.join(map(str, self.args)) args = ', '.join(map(str, self.args))
return '{}({})'.format(self.instname(), args) return '{}({})'.format(self.instname(), args)
def root_inst(self):
# type: () -> Instruction
"""Get the instruction at the root of this tree."""
return self.inst
def defs_expr(self):
# type: () -> Tuple[Tuple[Var, ...], Apply]
"""Split into a defs tuple and an Apply expr."""
return ((), self)

View File

@@ -15,7 +15,7 @@ from .ast import Var
from .xform import Rtl, XFormGroup from .xform import Rtl, XFormGroup
narrow = XFormGroup(""" narrow = XFormGroup('narrow', """
Legalize instructions by narrowing. Legalize instructions by narrowing.
The transformations in the 'narrow' group work by expressing The transformations in the 'narrow' group work by expressing
@@ -24,7 +24,7 @@ narrow = XFormGroup("""
operations are expressed in terms of smaller integer types. operations are expressed in terms of smaller integer types.
""") """)
expand = XFormGroup(""" expand = XFormGroup('expand', """
Legalize instructions by expansion. Legalize instructions by expansion.
Rewrite instructions in terms of other instructions, generally Rewrite instructions in terms of other instructions, generally
@@ -114,5 +114,5 @@ expand.legalize(
Rtl( Rtl(
(a1, b1) << isub_bout(x, y), (a1, b1) << isub_bout(x, y),
(a, b2) << isub_bout(a1, b_in), (a, b2) << isub_bout(a1, b_in),
c << bor(c1, c2) c << bor(b1, b2)
)) ))

View File

@@ -2,7 +2,13 @@
Instruction transformations. Instruction transformations.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from .ast import Def, Var, Apply from .ast import Def, Var, Apply, Expr # noqa
try:
from typing import Union, Iterator, Sequence, Iterable # noqa
DefApply = Union[Def, Apply]
except ImportError:
pass
SRCCTX = 1 SRCCTX = 1
@@ -21,9 +27,11 @@ class Rtl(object):
""" """
def __init__(self, *args): def __init__(self, *args):
# type: (*DefApply) -> None
self.rtl = args self.rtl = args
def __iter__(self): def __iter__(self):
# type: () -> Iterator[DefApply]
return iter(self.rtl) return iter(self.rtl)
@@ -55,16 +63,17 @@ class XForm(object):
""" """
def __init__(self, src, dst): def __init__(self, src, dst):
# type: (Rtl, Rtl) -> None
self.src = src self.src = src
self.dst = dst self.dst = dst
# Variables that are inputs to the source pattern. # Variables that are inputs to the source pattern.
self.inputs = list() self.inputs = list() # type: List[Var]
# Variables defined in either src or dst. # Variables defined in either src or dst.
self.defs = list() self.defs = list() # type: List[Var]
# Rewrite variables in src and dst RTL lists to our own copies. # Rewrite variables in src and dst RTL lists to our own copies.
# Map name -> private Var. # Map name -> private Var.
symtab = dict() symtab = dict() # type: Dict[str, Var]
self._rewrite_rtl(src, symtab, SRCCTX) self._rewrite_rtl(src, symtab, SRCCTX)
num_src_inputs = len(self.inputs) num_src_inputs = len(self.inputs)
self._rewrite_rtl(dst, symtab, DSTCTX) self._rewrite_rtl(dst, symtab, DSTCTX)
@@ -90,7 +99,8 @@ class XForm(object):
return s return s
def _rewrite_rtl(self, rtl, symtab, context): def _rewrite_rtl(self, rtl, symtab, context):
for line in rtl: # type: (Rtl, Dict[str, Var], int) -> None
for line in rtl.rtl:
if isinstance(line, Def): if isinstance(line, Def):
line.defs = tuple( line.defs = tuple(
self._rewrite_defs(line.defs, symtab, context)) self._rewrite_defs(line.defs, symtab, context))
@@ -100,6 +110,7 @@ class XForm(object):
self._rewrite_expr(expr, symtab, context) self._rewrite_expr(expr, symtab, context)
def _rewrite_expr(self, expr, symtab, context): def _rewrite_expr(self, expr, symtab, context):
# type: (Apply, Dict[str, Var], int) -> None
""" """
Find all uses of variables in `expr` and replace them with our own Find all uses of variables in `expr` and replace them with our own
local symbols. local symbols.
@@ -113,6 +124,7 @@ class XForm(object):
self._rewrite_uses(expr, stack, symtab, context)) self._rewrite_uses(expr, stack, symtab, context))
def _rewrite_defs(self, defs, symtab, context): def _rewrite_defs(self, defs, symtab, context):
# type: (Sequence[Var], Dict[str, Var], int) -> Iterable[Var]
""" """
Given a tuple of symbols defined in a Def, rewrite them to local Given a tuple of symbols defined in a Def, rewrite them to local
symbols. Yield the new locals. symbols. Yield the new locals.
@@ -131,6 +143,7 @@ class XForm(object):
yield var yield var
def _rewrite_uses(self, expr, stack, symtab, context): def _rewrite_uses(self, expr, stack, symtab, context):
# type: (Apply, List[Apply], Dict[str, Var], int) -> Iterable[Expr]
""" """
Given an `Apply` expr, rewrite all uses in its arguments to local Given an `Apply` expr, rewrite all uses in its arguments to local
variables. Yield a sequence of new arguments. variables. Yield a sequence of new arguments.
@@ -140,7 +153,7 @@ class XForm(object):
for arg, operand in zip(expr.args, expr.inst.ins): for arg, operand in zip(expr.args, expr.inst.ins):
# Nested instructions are allowed. Visit recursively. # Nested instructions are allowed. Visit recursively.
if isinstance(arg, Apply): if isinstance(arg, Apply):
stack.push(arg) stack.append(arg)
yield arg yield arg
continue continue
if not isinstance(arg, Var): if not isinstance(arg, Var):
@@ -169,11 +182,14 @@ class XFormGroup(object):
A group of related transformations. A group of related transformations.
""" """
def __init__(self, doc): def __init__(self, name, doc):
self.xforms = list() # type: (str, str) -> None
self.xforms = list() # type: List[XForm]
self.name = name
self.__doc__ = doc self.__doc__ = doc
def legalize(self, src, dst): def legalize(self, src, dst):
# type: (Union[Def, Apply], Rtl) -> None
""" """
Add a legalization pattern to this group. Add a legalization pattern to this group.

View File

@@ -7,9 +7,11 @@ architecture supported by Cretonne.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from . import riscv from . import riscv
from cretonne import TargetISA # noqa
def all_isas(): def all_isas():
# type: () -> List[TargetISA]
""" """
Get a list of all the supported target ISAs. Each target ISA is represented Get a list of all the supported target ISAs. Each target ISA is represented
as a :py:class:`cretonne.TargetISA` instance. as a :py:class:`cretonne.TargetISA` instance.

View File

@@ -22,37 +22,44 @@ from cretonne.predicates import IsSignedInt
def LOAD(funct3): def LOAD(funct3):
# type: (int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
return 0b00000 | (funct3 << 5) return 0b00000 | (funct3 << 5)
def STORE(funct3): def STORE(funct3):
# type: (int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
return 0b01000 | (funct3 << 5) return 0b01000 | (funct3 << 5)
def BRANCH(funct3): def BRANCH(funct3):
# type: (int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
return 0b11000 | (funct3 << 5) return 0b11000 | (funct3 << 5)
def OPIMM(funct3, funct7=0): def OPIMM(funct3, funct7=0):
# type: (int, int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
return 0b00100 | (funct3 << 5) | (funct7 << 8) return 0b00100 | (funct3 << 5) | (funct7 << 8)
def OPIMM32(funct3, funct7=0): def OPIMM32(funct3, funct7=0):
# type: (int, int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
return 0b00110 | (funct3 << 5) | (funct7 << 8) return 0b00110 | (funct3 << 5) | (funct7 << 8)
def OP(funct3, funct7): def OP(funct3, funct7):
# type: (int, int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
assert funct7 <= 0b1111111 assert funct7 <= 0b1111111
return 0b01100 | (funct3 << 5) | (funct7 << 8) return 0b01100 | (funct3 << 5) | (funct7 << 8)
def OP32(funct3, funct7): def OP32(funct3, funct7):
# type: (int, int) -> int
assert funct3 <= 0b111 assert funct3 <= 0b111
assert funct7 <= 0b1111111 assert funct7 <= 0b1111111
return 0b01110 | (funct3 << 5) | (funct7 << 8) return 0b01110 | (funct3 << 5) | (funct7 << 8)

View File

@@ -10,6 +10,11 @@ import sys
import os import os
import re import re
try:
from typing import Any # noqa
except ImportError:
pass
class Formatter(object): class Formatter(object):
""" """
@@ -38,19 +43,23 @@ class Formatter(object):
shiftwidth = 4 shiftwidth = 4
def __init__(self): def __init__(self):
# type: () -> None
self.indent = '' self.indent = ''
self.lines = [] self.lines = [] # type: List[str]
def indent_push(self): def indent_push(self):
# type: () -> None
"""Increase current indentation level by one.""" """Increase current indentation level by one."""
self.indent += ' ' * self.shiftwidth self.indent += ' ' * self.shiftwidth
def indent_pop(self): def indent_pop(self):
# type: () -> None
"""Decrease indentation by one level.""" """Decrease indentation by one level."""
assert self.indent != '', 'Already at top level indentation' assert self.indent != '', 'Already at top level indentation'
self.indent = self.indent[0:-self.shiftwidth] self.indent = self.indent[0:-self.shiftwidth]
def line(self, s=None): def line(self, s=None):
# type: (str) -> None
"""Add an indented line.""" """Add an indented line."""
if s: if s:
self.lines.append('{}{}\n'.format(self.indent, s)) self.lines.append('{}{}\n'.format(self.indent, s))
@@ -58,6 +67,7 @@ class Formatter(object):
self.lines.append('\n') self.lines.append('\n')
def outdented_line(self, s): def outdented_line(self, s):
# type: (str) -> None
""" """
Emit a line outdented one level. Emit a line outdented one level.
@@ -67,12 +77,14 @@ class Formatter(object):
self.lines.append('{}{}\n'.format(self.indent[0:-self.shiftwidth], s)) self.lines.append('{}{}\n'.format(self.indent[0:-self.shiftwidth], s))
def writelines(self, f=None): def writelines(self, f=None):
# type: (Any) -> None
"""Write all lines to `f`.""" """Write all lines to `f`."""
if not f: if not f:
f = sys.stdout f = sys.stdout
f.writelines(self.lines) f.writelines(self.lines)
def update_file(self, filename, directory): def update_file(self, filename, directory):
# type: (str, str) -> None
if directory is not None: if directory is not None:
filename = os.path.join(directory, filename) filename = os.path.join(directory, filename)
with open(filename, 'w') as f: with open(filename, 'w') as f:
@@ -80,10 +92,12 @@ class Formatter(object):
class _IndentedScope(object): class _IndentedScope(object):
def __init__(self, fmt, after): def __init__(self, fmt, after):
# type: (Formatter, str) -> None
self.fmt = fmt self.fmt = fmt
self.after = after self.after = after
def __enter__(self): def __enter__(self):
# type: () -> None
self.fmt.indent_push() self.fmt.indent_push()
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
@@ -92,6 +106,7 @@ class Formatter(object):
self.fmt.line(self.after) self.fmt.line(self.after)
def indented(self, before=None, after=None): def indented(self, before=None, after=None):
# type: (str, str) -> Formatter._IndentedScope
""" """
Return a scope object for use with a `with` statement: Return a scope object for use with a `with` statement:
@@ -108,7 +123,7 @@ class Formatter(object):
""" """
if before: if before:
self.line(before) self.line(before)
return self._IndentedScope(self, after) return Formatter._IndentedScope(self, after)
def format(self, fmt, *args): def format(self, fmt, *args):
self.line(fmt.format(*args)) self.line(fmt.format(*args))