@@ -268,8 +268,16 @@ def VisitClassType(self, node):
268268 return node
269269
270270
271+ class ClearClassTypePointers (object ):
272+ """For ClassType nodes: Set their cls pointer to None."""
273+
274+ def EnterClassType (self , node ):
275+ node .cls = None
276+
277+
271278class NamedTypeToClassType (object ):
272- """Change all NamedType objects to ClassType objects."""
279+ """Change all NamedType objects to ClassType objects.
280+ """
273281
274282 def VisitNamedType (self , node ):
275283 """Converts a named type to a class type, to be filled in later.
@@ -286,11 +294,18 @@ def VisitNamedType(self, node):
286294def FillInClasses (target , global_module = None ):
287295 """Fill in class pointers in ClassType nodes for a PyTD object.
288296
297+ This will adjust the "cls" pointer for existing ClassType nodes so that they
298+ point to their named class. It will only do this for cls pointers that are
299+ None, otherwise it will keep the old value. Use the NamedTypeToClassType
300+ visitor to create the ClassType nodes in the first place. Use the
301+ ClearClassTypePointers visitor to set the "cls" pointers for already existing
302+ ClassType nodes back to None.
303+
289304 Args:
290305 target: The PyTD object to operate on. Changes will happen in-place. If this
291- is a TypeDeclUnit it will also be used for lookups.
306+ is a TypeDeclUnit it will also be used for lookups.
292307 global_module: Global symbols. Tried if a name doesn't exist locally. This
293- is required if target is not a TypeDeclUnit.
308+ is required if target is not a TypeDeclUnit.
294309 """
295310 if global_module is None :
296311 global_module = target
@@ -308,12 +323,15 @@ def FillInClasses(target, global_module=None):
308323 target .Visit (_FillInClasses (global_module , global_module ))
309324
310325
311- def LookupClasses (module , global_module = None ):
326+ def LookupClasses (module , global_module = None , overwrite = False ):
312327 """Converts a module from one using NamedType to ClassType.
313328
314329 Args:
315330 module: The module to process.
316331 global_module: The global (builtins) module for name lookup. Can be None.
332+ overwrite: If we should overwrite the "cls" pointer of existing ClassType
333+ nodes. Otherwise, "cls" pointers of existing ClassType nodes will only
334+ be written if they are None.
317335
318336 Returns:
319337 A new module that only uses ClassType. All ClassType instances will point
@@ -323,6 +341,9 @@ def LookupClasses(module, global_module=None):
323341 KeyError: If we can't find a class.
324342 """
325343 module = module .Visit (NamedTypeToClassType ())
344+ if overwrite :
345+ # Set cls pointers to None so that FillInClasses is allowed to set them.
346+ module = module .Visit (ClearClassTypePointers ())
326347 FillInClasses (module , global_module )
327348 module .Visit (VerifyLookup ())
328349 return module
0 commit comments