Skip to content

Commit c44b182

Browse files
authored
Adding a provider which can tell what accessor to use to go from the parent to that child node (Instagram#807)
1 parent bd4f541 commit c44b182

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

libcst/metadata/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
from libcst._position import CodePosition, CodeRange
8+
from libcst.metadata.accessor_provider import AccessorProvider
89
from libcst.metadata.base_provider import (
910
BaseMetadataProvider,
1011
BatchableMetadataProvider,
@@ -86,6 +87,7 @@
8687
"Accesses",
8788
"TypeInferenceProvider",
8889
"FullRepoManager",
90+
"AccessorProvider",
8991
# Experimental APIs:
9092
"ExperimentalReentrantCodegenProvider",
9193
"CodegenPartial",

libcst/metadata/accessor_provider.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import dataclasses
8+
9+
import libcst as cst
10+
11+
from libcst.metadata.base_provider import VisitorMetadataProvider
12+
13+
14+
class AccessorProvider(VisitorMetadataProvider[str]):
15+
def on_visit(self, node: cst.CSTNode) -> bool:
16+
for f in dataclasses.fields(node):
17+
child = getattr(node, f.name)
18+
self.set_metadata(child, f.name)
19+
return True
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import dataclasses
7+
8+
from textwrap import dedent
9+
10+
import libcst as cst
11+
from libcst.metadata import AccessorProvider, MetadataWrapper
12+
from libcst.testing.utils import data_provider, UnitTest
13+
14+
15+
class DependentVisitor(cst.CSTVisitor):
16+
METADATA_DEPENDENCIES = (AccessorProvider,)
17+
18+
def __init__(self, *, test: UnitTest) -> None:
19+
self.test = test
20+
21+
def on_visit(self, node: cst.CSTNode) -> bool:
22+
for f in dataclasses.fields(node):
23+
child = getattr(node, f.name)
24+
if type(child) is cst.CSTNode:
25+
accessor = self.get_metadata(AccessorProvider, child)
26+
self.test.assertEqual(accessor, f.name)
27+
28+
return True
29+
30+
31+
class AccessorProviderTest(UnitTest):
32+
@data_provider(
33+
(
34+
(
35+
"""
36+
foo = 'toplevel'
37+
fn1(foo)
38+
fn2(foo)
39+
def fn_def():
40+
foo = 'shadow'
41+
fn3(foo)
42+
""",
43+
),
44+
(
45+
"""
46+
global_var = None
47+
@cls_attr
48+
class Cls(cls_attr, kwarg=cls_attr):
49+
cls_attr = 5
50+
def f():
51+
pass
52+
""",
53+
),
54+
(
55+
"""
56+
iterator = None
57+
condition = None
58+
[elt for target in iterator if condition]
59+
{elt for target in iterator if condition}
60+
{elt: target for target in iterator if condition}
61+
(elt for target in iterator if condition)
62+
""",
63+
),
64+
)
65+
)
66+
def test_accessor_provier(self, code: str) -> None:
67+
wrapper = MetadataWrapper(cst.parse_module(dedent(code)))
68+
wrapper.visit(DependentVisitor(test=self))

0 commit comments

Comments
 (0)