Skip to content

Commit e353540

Browse files
authored
Merge pull request google#104 from grumpyhome/jamdagni86-fix-17
Jamdagni86 fix for google#17 (Support tuple args in lambda)
2 parents 42009a5 + 4f1c2bc commit e353540

File tree

4 files changed

+81
-6
lines changed

4 files changed

+81
-6
lines changed

grumpy-tools-src/grumpy_tools/compiler/block.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def resolve_name(self, writer, name):
249249
if var:
250250
if var.type == Var.TYPE_GLOBAL:
251251
return self._resolve_global(writer, name)
252+
if var.type == Var.TYPE_TUPLE_PARAM:
253+
return expr.GeneratedLocalVar(name)
252254
writer.write_checked_call1('πg.CheckLocal(πF, {}, {})',
253255
util.adjust_local_name(name),
254256
util.go_str(name))
@@ -263,6 +265,7 @@ class Var(object):
263265
TYPE_LOCAL = 0
264266
TYPE_PARAM = 1
265267
TYPE_GLOBAL = 2
268+
TYPE_TUPLE_PARAM = 3
266269

267270
def __init__(self, name, var_type, arg_index=None):
268271
self.name = name
@@ -273,6 +276,9 @@ def __init__(self, name, var_type, arg_index=None):
273276
elif var_type == Var.TYPE_PARAM:
274277
assert arg_index is not None
275278
self.init_expr = 'πArgs[{}]'.format(arg_index)
279+
elif var_type == Var.TYPE_TUPLE_PARAM:
280+
assert arg_index is None
281+
self.init_expr = 'nil'
276282
else:
277283
assert arg_index is None
278284
self.init_expr = None
@@ -364,16 +370,40 @@ def __init__(self, node):
364370
BlockVisitor.__init__(self)
365371
self.is_generator = False
366372
node_args = node.args
367-
args = [a.arg for a in node_args.args]
373+
args = []
374+
for arg in node_args.args:
375+
if isinstance(arg, ast.Tuple):
376+
args.append(arg.elts)
377+
else:
378+
args.append(arg.arg)
379+
# args = [a.arg for a in node_args.args]
380+
368381
if node_args.vararg:
369382
args.append(node_args.vararg.arg)
370383
if node_args.kwarg:
371384
args.append(node_args.kwarg.arg)
372385
for i, name in enumerate(args):
373-
if name in self.vars:
374-
msg = "duplicate argument '{}' in function definition".format(name)
375-
raise util.ParseError(node, msg)
376-
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)
386+
if isinstance(name, list):
387+
arg_name = 'τ{}'.format(id(name))
388+
for el in name:
389+
self._parse_tuple(el, node)
390+
self.vars[arg_name] = Var(arg_name, Var.TYPE_PARAM, i)
391+
else:
392+
self._check_duplicate_args(name, node)
393+
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)
394+
395+
def _parse_tuple(self, el, node):
396+
if isinstance(el, ast.Tuple):
397+
for x in el.elts:
398+
self._parse_tuple(x, node)
399+
else:
400+
self._check_duplicate_args(el.arg, node)
401+
self.vars[el.arg] = Var(el.arg, Var.TYPE_TUPLE_PARAM)
402+
403+
def _check_duplicate_args(self, name, node):
404+
if name in self.vars:
405+
msg = "duplicate argument '{}' in function definition".format(name)
406+
raise util.ParseError(node, msg)
377407

378408
def visit_Yield(self, unused_node): # pylint: disable=unused-argument
379409
self.is_generator = True

grumpy-tools-src/grumpy_tools/compiler/block_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from grumpy_tools.compiler import util
2727
from grumpy_tools.vendor import pythonparser
2828

29+
2930
class PackageTest(unittest.TestCase):
3031

3132
def testCreate(self):
@@ -224,6 +225,13 @@ def testYieldExpr(self):
224225
self.assertEqual(sorted(visitor.vars.keys()), ['foo'])
225226
self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')
226227

228+
def testTupleArgs(self):
229+
func = _ParseStmt('def foo((bar, baz)): pass')
230+
visitor = block.FunctionBlockVisitor(func)
231+
self.assertEqual(len(visitor.vars), 3)
232+
self.assertEqual(len([v for v in visitor.vars if visitor.vars[v].type == block.Var.TYPE_TUPLE_PARAM]), 2)
233+
self.assertIn('bar', visitor.vars)
234+
self.assertIn('baz', visitor.vars)
227235

228236
def _MakeModuleBlock():
229237
importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False)

grumpy-tools-src/grumpy_tools/compiler/stmt.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ def visit_function_inline(self, node):
504504
func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars,
505505
func_visitor.is_generator)
506506
visitor = StatementVisitor(func_block, self.future_node)
507+
508+
for arg in node.args.args:
509+
if isinstance(arg, ast.Tuple):
510+
arg_name = 'τ{}'.format(id(arg.elts))
511+
with visitor.writer.indent_block():
512+
visitor._tie_target(arg, util.adjust_local_name(arg_name)) # pylint: disable=protected-access
513+
507514
# Indent so that the function body is aligned with the goto labels.
508515
with visitor.writer.indent_block():
509516
visitor._visit_each(node.body) # pylint: disable=protected-access
@@ -519,9 +526,13 @@ def visit_function_inline(self, node):
519526
defaults = [None] * (argc - len(args.defaults)) + args.defaults
520527
for i, (a, d) in enumerate(zip(args.args, defaults)):
521528
with self.visit_expr(d) if d else expr.nil_expr as default:
529+
if isinstance(a, ast.Tuple):
530+
name = util.go_str('τ{}'.format(id(a.elts)))
531+
else:
532+
name = util.go_str(a.arg)
522533
tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}'
523534
self.writer.write_tmpl(tmpl, args=func_args.expr, i=i,
524-
name=util.go_str(a.arg), default=default.expr)
535+
name=name, default=default.expr)
525536
flags = []
526537
if args.vararg:
527538
flags.append('πg.CodeFlagVarArg')
@@ -583,6 +594,8 @@ def visit_function_inline(self, node):
583594
def _assign_target(self, target, value):
584595
if isinstance(target, ast.Name):
585596
self.block.bind_var(self.writer, target.id, value)
597+
elif isinstance(target, ast.arg):
598+
self.block.bind_var(self.writer, target.arg, value)
586599
elif isinstance(target, ast.Attribute):
587600
with self.visit_expr(target.value) as obj:
588601
self.writer.write_checked_call1(

grumpy-tools-src/grumpy_tools/compiler/stmt_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,30 @@ def foo(a, b):
243243
print a, b
244244
foo('bar', 'baz')""")))
245245

246+
def testFunctionDefWithTupleArgs(self):
247+
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
248+
def foo((a, b)):
249+
print(a, b)
250+
foo(('bar', 'baz'))""")))
251+
252+
def testFunctionDefWithNestedTupleArgs(self):
253+
self.assertEqual((0, "('bar', 'baz', 'qux')\n"), _GrumpRun(textwrap.dedent("""\
254+
def foo(((a, b), c)):
255+
print(a, b, c)
256+
foo((('bar', 'baz'), 'qux'))""")))
257+
258+
def testFunctionDefWithMultipleTupleArgs(self):
259+
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
260+
def foo(((a, ), (b, ))):
261+
print(a, b)
262+
foo((('bar',), ('baz', )))""")))
263+
264+
def testFunctionDefTupleArgsInLambda(self):
265+
self.assertEqual((0, "[(3, 2), (4, 3), (12, 1)]\n"), _GrumpRun(textwrap.dedent("""\
266+
c = {12: 1, 3: 2, 4: 3}
267+
top = sorted(c.items(), key=lambda (k,v): v)
268+
print (top)""")))
269+
246270
def testFunctionDefGenerator(self):
247271
self.assertEqual((0, "['foo', 'bar']\n"), _GrumpRun(textwrap.dedent("""\
248272
def gen():

0 commit comments

Comments
 (0)