Skip to content
This repository was archived by the owner on May 11, 2022. It is now read-only.

Commit ec497a4

Browse files
committed
Add "overwrite" argument to LookupClasses().
1 parent bc24654 commit ec497a4

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

parse/visitors.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,16 @@ def VisitClassType(self, node):
268268
return node
269269

270270

271+
class ClearClassTypePointers(object):
272+
"""For ClassType nodes: Set their cls pointer to None."""
273+
274+
def EnterClassType(self, node):
275+
node.cls = None
276+
277+
271278
class NamedTypeToClassType(object):
272-
"""Change all NamedType objects to ClassType objects."""
279+
"""Change all NamedType objects to ClassType objects.
280+
"""
273281

274282
def VisitNamedType(self, node):
275283
"""Converts a named type to a class type, to be filled in later.
@@ -286,11 +294,18 @@ def VisitNamedType(self, node):
286294
def FillInClasses(target, global_module=None):
287295
"""Fill in class pointers in ClassType nodes for a PyTD object.
288296
297+
This will adjust the "cls" pointer for existing ClassType nodes so that they
298+
point to their named class. It will only do this for cls pointers that are
299+
None, otherwise it will keep the old value. Use the NamedTypeToClassType
300+
visitor to create the ClassType nodes in the first place. Use the
301+
ClearClassTypePointers visitor to set the "cls" pointers for already existing
302+
ClassType nodes back to None.
303+
289304
Args:
290305
target: The PyTD object to operate on. Changes will happen in-place. If this
291-
is a TypeDeclUnit it will also be used for lookups.
306+
is a TypeDeclUnit it will also be used for lookups.
292307
global_module: Global symbols. Tried if a name doesn't exist locally. This
293-
is required if target is not a TypeDeclUnit.
308+
is required if target is not a TypeDeclUnit.
294309
"""
295310
if global_module is None:
296311
global_module = target
@@ -308,12 +323,15 @@ def FillInClasses(target, global_module=None):
308323
target.Visit(_FillInClasses(global_module, global_module))
309324

310325

311-
def LookupClasses(module, global_module=None):
326+
def LookupClasses(module, global_module=None, overwrite=False):
312327
"""Converts a module from one using NamedType to ClassType.
313328
314329
Args:
315330
module: The module to process.
316331
global_module: The global (builtins) module for name lookup. Can be None.
332+
overwrite: If we should overwrite the "cls" pointer of existing ClassType
333+
nodes. Otherwise, "cls" pointers of existing ClassType nodes will only
334+
be written if they are None.
317335
318336
Returns:
319337
A new module that only uses ClassType. All ClassType instances will point
@@ -323,6 +341,9 @@ def LookupClasses(module, global_module=None):
323341
KeyError: If we can't find a class.
324342
"""
325343
module = module.Visit(NamedTypeToClassType())
344+
if overwrite:
345+
# Set cls pointers to None so that FillInClasses is allowed to set them.
346+
module = module.Visit(ClearClassTypePointers())
326347
FillInClasses(module, global_module)
327348
module.Visit(VerifyLookup())
328349
return module

0 commit comments

Comments
 (0)