Skip to content

Commit a138ec7

Browse files
authored
feat(go/parser): add visited map for getNamedTypes (#59)
* feat(go/parser): add visited map for getNamedTypes * update
1 parent b071e0b commit a138ec7

File tree

3 files changed

+221
-17
lines changed

3 files changed

+221
-17
lines changed

lang/golang/parser/ctx.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,8 @@ func (p *GoParser) collectTypes(ctx *fileContext, typ ast.Expr, st *Type, inline
446446

447447
// get type id and tells if it is std or builtin
448448
func (ctx *fileContext) getTypeinfo(typ types.Type) (ti typeInfo) {
449-
tobjs, isPointer, isNamed := getNamedTypes(typ)
449+
visited := make(map[types.Type]bool)
450+
tobjs, isPointer, isNamed := getNamedTypes(typ, visited)
450451
ti.IsPointer = isPointer
451452
ti.Ty = typ
452453
ti.IsNamed = isNamed

lang/golang/parser/utils.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -183,63 +183,69 @@ func getTypeKind(n ast.Expr) TypeKind {
183183
}
184184
}
185185

186-
func getNamedTypes(typ types.Type) (tys []types.Object, isPointer bool, isNamed bool) {
186+
func getNamedTypes(typ types.Type, visited map[types.Type]bool) (tys []types.Object, isPointer bool, isNamed bool) {
187+
if visited[typ] {
188+
return nil, false, false
189+
}
190+
191+
visited[typ] = true
192+
187193
switch t := typ.(type) {
188194
case *types.Pointer:
189195
isPointer = true
190-
typs, _, isNamed2 := getNamedTypes(t.Elem())
196+
var typs []types.Object
197+
typs, _, isNamed = getNamedTypes(t.Elem(), visited)
191198
tys = append(tys, typs...)
192-
isNamed = isNamed2
193199
case *types.Slice:
194-
typs, _, _ := getNamedTypes(t.Elem())
200+
typs, _, _ := getNamedTypes(t.Elem(), visited)
195201
tys = append(tys, typs...)
196202
case *types.Array:
197-
typs, _, _ := getNamedTypes(t.Elem())
203+
typs, _, _ := getNamedTypes(t.Elem(), visited)
198204
tys = append(tys, typs...)
199205
case *types.Chan:
200-
typs, _, _ := getNamedTypes(t.Elem())
206+
typs, _, _ := getNamedTypes(t.Elem(), visited)
201207
tys = append(tys, typs...)
202208
case *types.Tuple:
203209
for i := 0; i < t.Len(); i++ {
204-
typs, _, _ := getNamedTypes(t.At(i).Type())
210+
typs, _, _ := getNamedTypes(t.At(i).Type(), visited)
205211
tys = append(tys, typs...)
206212
}
207213
case *types.Map:
208-
typs2, _, _ := getNamedTypes(t.Elem())
209-
typs1, _, _ := getNamedTypes(t.Key())
214+
typs2, _, _ := getNamedTypes(t.Elem(), visited)
215+
typs1, _, _ := getNamedTypes(t.Key(), visited)
210216
tys = append(tys, typs1...)
211217
tys = append(tys, typs2...)
212218
case *types.Named:
213219
tys = append(tys, t.Obj())
214220
isNamed = true
215221
case *types.Struct:
216222
for i := 0; i < t.NumFields(); i++ {
217-
typs, _, _ := getNamedTypes(t.Field(i).Type())
223+
typs, _, _ := getNamedTypes(t.Field(i).Type(), visited)
218224
tys = append(tys, typs...)
219225
}
220226
case *types.Interface:
221227
for i := 0; i < t.NumEmbeddeds(); i++ {
222-
typs, _, _ := getNamedTypes(t.EmbeddedType(i))
228+
typs, _, _ := getNamedTypes(t.EmbeddedType(i), visited)
223229
tys = append(tys, typs...)
224230
}
225231
for i := 0; i < t.NumExplicitMethods(); i++ {
226-
typs, _, _ := getNamedTypes(t.ExplicitMethod(i).Type())
232+
typs, _, _ := getNamedTypes(t.ExplicitMethod(i).Type(), visited)
227233
tys = append(tys, typs...)
228234
}
229235
case *types.TypeParam:
230-
typs, _, _ := getNamedTypes(t.Constraint())
236+
typs, _, _ := getNamedTypes(t.Constraint(), visited)
231237
tys = append(tys, typs...)
232238
case *types.Alias:
233239
var typs []types.Object
234-
typs, isPointer, isNamed = getNamedTypes(t.Rhs())
240+
typs, isPointer, isNamed = getNamedTypes(t.Rhs(), visited)
235241
tys = append(tys, typs...)
236242
case *types.Signature:
237243
for i := 0; i < t.Params().Len(); i++ {
238-
typs, _, _ := getNamedTypes(t.Params().At(i).Type())
244+
typs, _, _ := getNamedTypes(t.Params().At(i).Type(), visited)
239245
tys = append(tys, typs...)
240246
}
241247
for i := 0; i < t.Results().Len(); i++ {
242-
typs, _, _ := getNamedTypes(t.Results().At(i).Type())
248+
typs, _, _ := getNamedTypes(t.Results().At(i).Type(), visited)
243249
tys = append(tys, typs...)
244250
}
245251
}

lang/golang/parser/utils_test.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Copyright 2025 CloudWeGo Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package parser
16+
17+
import (
18+
"go/ast"
19+
"go/importer"
20+
"go/parser"
21+
"go/token"
22+
"go/types"
23+
"slices"
24+
"testing"
25+
26+
"github.com/stretchr/testify/require"
27+
)
28+
29+
func getTypeForTest(t *testing.T, src, name string) types.Type {
30+
fset := token.NewFileSet()
31+
f, err := parser.ParseFile(fset, "test.go", src, 0)
32+
require.NoError(t, err, "Failed to parse source code for %s", name)
33+
34+
conf := types.Config{Importer: importer.Default()}
35+
pkg, err := conf.Check("test", fset, []*ast.File{f}, nil)
36+
require.NoError(t, err, "Failed to type-check source code for %s", name)
37+
38+
obj := pkg.Scope().Lookup(name)
39+
require.NotNil(t, obj, "Object '%s' not found in source", name)
40+
41+
return obj.Type()
42+
}
43+
44+
func objectsToNames(objs []types.Object) []string {
45+
names := make([]string, len(objs))
46+
for i, obj := range objs {
47+
if obj.Pkg() != nil {
48+
names[i] = obj.Pkg().Path() + "." + obj.Name()
49+
} else {
50+
names[i] = obj.Name()
51+
}
52+
}
53+
slices.Sort(names)
54+
return names
55+
}
56+
57+
func Test_getNamedTypes(t *testing.T) {
58+
testCases := []struct {
59+
name string
60+
source string
61+
targetVar string
62+
expectedNames []string
63+
expectedIsPointer bool
64+
expectedIsNamed bool
65+
}{
66+
{
67+
name: "Simple Named Type",
68+
source: `package main
69+
type MyInt int`,
70+
targetVar: "MyInt",
71+
expectedNames: []string{"test.MyInt"},
72+
expectedIsPointer: false,
73+
expectedIsNamed: true,
74+
},
75+
{
76+
name: "Pointer to Named Type",
77+
source: `package main
78+
type MyInt int
79+
var p *MyInt`,
80+
targetVar: "p",
81+
expectedNames: []string{"test.MyInt"},
82+
expectedIsPointer: true,
83+
expectedIsNamed: true,
84+
},
85+
{
86+
name: "Slice of Named Type",
87+
source: `package main
88+
type MyStruct struct{}; var s []*MyStruct`,
89+
targetVar: "s",
90+
expectedNames: []string{"test.MyStruct"},
91+
expectedIsPointer: false,
92+
expectedIsNamed: false,
93+
},
94+
{
95+
name: "Array of Named Type",
96+
source: `package main
97+
type MyInt int; var a [5]MyInt`,
98+
targetVar: "a",
99+
expectedNames: []string{"test.MyInt"},
100+
expectedIsPointer: false,
101+
expectedIsNamed: false,
102+
},
103+
{
104+
name: "Map with Named Types",
105+
source: `package main
106+
type KeyType int; type ValueType string; var m map[*KeyType]ValueType`,
107+
targetVar: "m",
108+
expectedNames: []string{"test.KeyType", "test.ValueType"},
109+
expectedIsPointer: false,
110+
expectedIsNamed: false,
111+
},
112+
{
113+
name: "Struct with Named Fields",
114+
source: `package main
115+
type MyInt int
116+
type MyString string
117+
var s struct {
118+
Field1 MyInt
119+
Field2 *MyString
120+
}`,
121+
targetVar: "s",
122+
expectedNames: []string{"test.MyInt", "test.MyString"},
123+
expectedIsPointer: false,
124+
expectedIsNamed: false,
125+
},
126+
{
127+
name: "Interface with Embedded and Explicit Methods",
128+
source: `package main
129+
import "io"
130+
type MyInterface interface{
131+
io.Reader
132+
MyMethod(arg io.Writer)
133+
}`,
134+
targetVar: "MyInterface",
135+
expectedNames: []string{"test.MyInterface"},
136+
expectedIsPointer: false,
137+
expectedIsNamed: true,
138+
},
139+
{
140+
name: "Function Signature",
141+
source: `package main; import "bytes"; type MyInt int; var fn func(a MyInt) *bytes.Buffer`,
142+
targetVar: "fn",
143+
expectedNames: []string{"bytes.Buffer", "test.MyInt"},
144+
expectedIsPointer: false,
145+
expectedIsNamed: false,
146+
},
147+
{
148+
name: "Type Alias",
149+
source: `package main; type MyInt int; type IntAlias = MyInt`,
150+
targetVar: "IntAlias",
151+
expectedNames: []string{"test.MyInt"},
152+
expectedIsPointer: false,
153+
expectedIsNamed: true,
154+
},
155+
{
156+
name: "Recursive Struct (Cycle)",
157+
source: `package main; type Node struct{ Next *Node }`,
158+
targetVar: "Node",
159+
expectedNames: []string{"test.Node"},
160+
expectedIsPointer: false,
161+
expectedIsNamed: true,
162+
},
163+
{
164+
name: "No Named Types",
165+
source: `package main; var i int`,
166+
targetVar: "i",
167+
expectedNames: []string{},
168+
expectedIsPointer: false,
169+
expectedIsNamed: false,
170+
},
171+
{
172+
name: "Tuple from function return",
173+
source: `package main
174+
import "net/http"
175+
var f func() (*http.Request, error)`,
176+
targetVar: "f",
177+
expectedNames: []string{"error", "net/http.Request"}, // error is a builtin interface, not considered a named type object here
178+
expectedIsPointer: false,
179+
expectedIsNamed: false,
180+
},
181+
}
182+
183+
for _, tc := range testCases {
184+
t.Run(tc.name, func(t *testing.T) {
185+
typ := getTypeForTest(t, tc.source, tc.targetVar)
186+
visited := make(map[types.Type]bool)
187+
188+
tys, isPointer, isNamed := getNamedTypes(typ, visited)
189+
190+
actualNames := objectsToNames(tys)
191+
192+
require.Equal(t, tc.expectedNames, actualNames, "Named types mismatch")
193+
require.Equal(t, tc.expectedIsPointer, isPointer, "isPointer mismatch")
194+
require.Equal(t, tc.expectedIsNamed, isNamed, "isNamed mismatch")
195+
})
196+
}
197+
}

0 commit comments

Comments
 (0)