Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

fix for #17 #265

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions compiler/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def resolve_name(self, writer, name):
if var:
if var.type == Var.TYPE_GLOBAL:
return self._resolve_global(writer, name)
if var.type == Var.TYPE_TUPLE_PARAM:
return expr.GeneratedLocalVar(name)
writer.write_checked_call1('πg.CheckLocal(πF, {}, {})',
util.adjust_local_name(name),
util.go_str(name))
Expand All @@ -263,6 +265,7 @@ class Var(object):
TYPE_LOCAL = 0
TYPE_PARAM = 1
TYPE_GLOBAL = 2
TYPE_TUPLE_PARAM = 3

def __init__(self, name, var_type, arg_index=None):
self.name = name
Expand All @@ -273,6 +276,9 @@ def __init__(self, name, var_type, arg_index=None):
elif var_type == Var.TYPE_PARAM:
assert arg_index is not None
self.init_expr = 'πArgs[{}]'.format(arg_index)
elif var_type == Var.TYPE_TUPLE_PARAM:
assert arg_index is None
self.init_expr = 'nil'
else:
assert arg_index is None
self.init_expr = None
Expand Down Expand Up @@ -364,16 +370,40 @@ def __init__(self, node):
BlockVisitor.__init__(self)
self.is_generator = False
node_args = node.args
args = [a.arg for a in node_args.args]
args = []
for arg in node_args.args:
if isinstance(arg, ast.Tuple):
args.append(arg.elts)
else:
args.append(arg.arg)
# args = [a.arg for a in node_args.args]

if node_args.vararg:
args.append(node_args.vararg.arg)
if node_args.kwarg:
args.append(node_args.kwarg.arg)
for i, name in enumerate(args):
if name in self.vars:
msg = "duplicate argument '{}' in function definition".format(name)
raise util.ParseError(node, msg)
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)
if isinstance(name, list):
arg_name = 'τ{}'.format(id(name))
for el in name:
self._parse_tuple(el, node)
self.vars[arg_name] = Var(arg_name, Var.TYPE_PARAM, i)
else:
self._check_duplicate_args(name, node)
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)

def _parse_tuple(self, el, node):
if isinstance(el, ast.Tuple):
for x in el.elts:
self._parse_tuple(x, node)
else:
self._check_duplicate_args(el.arg, node)
self.vars[el.arg] = Var(el.arg, Var.TYPE_TUPLE_PARAM)

def _check_duplicate_args(self, name, node):
if name in self.vars:
msg = "duplicate argument '{}' in function definition".format(name)
raise util.ParseError(node, msg)

def visit_Yield(self, unused_node): # pylint: disable=unused-argument
self.is_generator = True
8 changes: 8 additions & 0 deletions compiler/block_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from grumpy.compiler import util
from grumpy import pythonparser


class PackageTest(unittest.TestCase):

def testCreate(self):
Expand Down Expand Up @@ -224,6 +225,13 @@ def testYieldExpr(self):
self.assertEqual(sorted(visitor.vars.keys()), ['foo'])
self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')

def testTupleArgs(self):
func = _ParseStmt('def foo((bar, baz)): pass')
visitor = block.FunctionBlockVisitor(func)
self.assertEqual(len(visitor.vars), 3)
self.assertEqual(len([v for v in visitor.vars if visitor.vars[v].type == block.Var.TYPE_TUPLE_PARAM]), 2)
self.assertIn('bar', visitor.vars)
self.assertIn('baz', visitor.vars)

def _MakeModuleBlock():
importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False)
Expand Down
15 changes: 14 additions & 1 deletion compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ def visit_function_inline(self, node):
func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars,
func_visitor.is_generator)
visitor = StatementVisitor(func_block, self.future_node)

for arg in node.args.args:
if isinstance(arg, ast.Tuple):
arg_name = 'τ{}'.format(id(arg.elts))
with visitor.writer.indent_block():
visitor._tie_target(arg, util.adjust_local_name(arg_name)) # pylint: disable=protected-access

# Indent so that the function body is aligned with the goto labels.
with visitor.writer.indent_block():
visitor._visit_each(node.body) # pylint: disable=protected-access
Expand All @@ -519,9 +526,13 @@ def visit_function_inline(self, node):
defaults = [None] * (argc - len(args.defaults)) + args.defaults
for i, (a, d) in enumerate(zip(args.args, defaults)):
with self.visit_expr(d) if d else expr.nil_expr as default:
if isinstance(a, ast.Tuple):
name = util.go_str('τ{}'.format(id(a.elts)))
else:
name = util.go_str(a.arg)
tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}'
self.writer.write_tmpl(tmpl, args=func_args.expr, i=i,
name=util.go_str(a.arg), default=default.expr)
name=name, default=default.expr)
flags = []
if args.vararg:
flags.append('πg.CodeFlagVarArg')
Expand Down Expand Up @@ -583,6 +594,8 @@ def visit_function_inline(self, node):
def _assign_target(self, target, value):
if isinstance(target, ast.Name):
self.block.bind_var(self.writer, target.id, value)
elif isinstance(target, ast.arg):
self.block.bind_var(self.writer, target.arg, value)
elif isinstance(target, ast.Attribute):
with self.visit_expr(target.value) as obj:
self.writer.write_checked_call1(
Expand Down
24 changes: 24 additions & 0 deletions compiler/stmt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,30 @@ def foo(a, b):
print a, b
foo('bar', 'baz')""")))

def testFunctionDefWithTupleArgs(self):
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
def foo((a, b)):
print(a, b)
foo(('bar', 'baz'))""")))

def testFunctionDefWithNestedTupleArgs(self):
self.assertEqual((0, "('bar', 'baz', 'qux')\n"), _GrumpRun(textwrap.dedent("""\
def foo(((a, b), c)):
print(a, b, c)
foo((('bar', 'baz'), 'qux'))""")))

def testFunctionDefWithMultipleTupleArgs(self):
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
def foo(((a, ), (b, ))):
print(a, b)
foo((('bar',), ('baz', )))""")))

def testFunctionDefTupleArgsInLambda(self):
self.assertEqual((0, "[(3, 2), (4, 3), (12, 1)]\n"), _GrumpRun(textwrap.dedent("""\
c = {12: 1, 3: 2, 4: 3}
top = sorted(c.items(), key=lambda (k,v): v)
print (top)""")))

def testFunctionDefGenerator(self):
self.assertEqual((0, "['foo', 'bar']\n"), _GrumpRun(textwrap.dedent("""\
def gen():
Expand Down