Skip to content

Commit 8cfde9f

Browse files
authored
feat(go): support collect type parameters (cloudwego#71)
1 parent b106c84 commit 8cfde9f

File tree

3 files changed

+87
-80
lines changed

3 files changed

+87
-80
lines changed

lang/golang/parser/file.go

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool
120120
v = p.newVar(ctx.module.Name, ctx.pkgPath, name.Name, isConst)
121121
v.FileLine = ctx.FileLine(vspec)
122122

123-
// always collect value's dependencies
123+
// collect func value dependencies, in case of var a = func() {...}
124124
if val != nil && !isConst {
125125
collects := collectInfos{}
126126
ast.Inspect(*val, func(n ast.Node) bool {
@@ -159,38 +159,13 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool
159159
if isConst && v.Type == nil {
160160
v.Type = lastType
161161
}
162-
var varType string
163-
if v.Type != nil {
164-
if v.Type.PkgPath == ctx.pkgPath {
165-
varType = v.Type.Name
166-
} else {
167-
varType = v.Type.CallName()
168-
}
169-
if v.IsPointer {
170-
varType = "*" + varType
171-
}
172-
}
173162

174163
if !isConst {
175-
v.Content = fmt.Sprintf("var %s %s", name.Name, varType)
164+
v.Content = "var " + string(ctx.GetRawContent(vspec))
176165
} else {
177-
if varType != "" {
178-
v.Content = fmt.Sprintf("const %s %s", name.Name, varType)
179-
} else {
180-
v.Content = fmt.Sprintf("const %s", name.Name)
181-
}
166+
v.Content = "const " + string(ctx.GetRawContent(vspec))
182167
}
183168

184-
var comment string
185-
if ctx.collectComment && doc != nil {
186-
comment += string(ctx.GetRawContent(doc)) + "\n"
187-
}
188-
if ctx.collectComment && vspec.Doc != nil {
189-
comment += string(ctx.GetRawContent(vspec.Doc)) + "\n"
190-
v.FileLine.StartOffset = ctx.fset.Position(vspec.Pos()).Offset
191-
}
192-
v.Content = comment + v.Content
193-
194169
var finalVal string
195170
if val != nil {
196171
// refer codes
@@ -229,11 +204,20 @@ func (p *GoParser) parseVar(ctx *fileContext, vspec *ast.ValueSpec, isConst bool
229204
lastValue = &tmp
230205
finalVal = strconv.FormatFloat(tmp, 'f', -1, 64)
231206
}
232-
233-
if finalVal != "" {
207+
if finalVal != "" && !strings.Contains(v.Content, " = ") {
234208
v.Content += " = " + finalVal
235209
}
236210

211+
var comment string
212+
if ctx.collectComment && doc != nil {
213+
comment += string(ctx.GetRawContent(doc)) + "\n"
214+
}
215+
if ctx.collectComment && vspec.Doc != nil {
216+
comment += string(ctx.GetRawContent(vspec.Doc)) + "\n"
217+
v.FileLine.StartOffset = ctx.fset.Position(vspec.Pos()).Offset
218+
}
219+
v.Content = comment + v.Content
220+
237221
typ = v.Type
238222
}
239223
return typ, v, lastValue
@@ -441,22 +425,23 @@ func (p *GoParser) parseASTNode(ctx *fileContext, node ast.Node, collect *collec
441425
func (p *GoParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Function, bool) {
442426
// method receiver
443427
var receiver *Receiver
444-
isMethod := funcDecl.Recv != nil
445-
if strings.HasSuffix(ctx.filePath, "cmds/life_stat/main.go") && funcDecl.Name.Name == "init" {
446-
447-
}
428+
var tparams []Dependency
429+
isMethod := funcDecl.Recv != nil && len(funcDecl.Recv.List) > 0
448430
if isMethod {
449-
// TODO: reserve the pointer message?
450-
ti := ctx.GetTypeInfo(funcDecl.Recv.List[0].Type)
451-
// name := "self"
452-
// if len(funcDecl.Recv.List[0].Names) > 0 {
453-
// name = funcDecl.Recv.List[0].Names[0].Name
454-
// }
431+
rt := funcDecl.Recv.List[0].Type
432+
ti := ctx.GetTypeInfo(rt)
455433
receiver = &Receiver{
456434
Type: ti.Id,
457435
IsPointer: ti.IsPointer,
458436
// Name: name,
459437
}
438+
// collect receiver's type params
439+
for _, d := range ti.Deps {
440+
tparams = append(tparams, Dependency{
441+
Identity: d,
442+
FileLine: ctx.FileLine(rt), // FIXME: location is not accurate, try parse Index AST to get it.
443+
})
444+
}
460445
}
461446

462447
fname := funcDecl.Name.Name
@@ -474,6 +459,10 @@ func (p *GoParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio
474459
if funcDecl.Type.Results != nil {
475460
ctx.collectFields(funcDecl.Type.Results.List, &results)
476461
}
462+
// collect type params
463+
if funcDecl.Type.TypeParams != nil {
464+
ctx.collectFields(funcDecl.Type.TypeParams.List, &tparams)
465+
}
477466

478467
// collect signature
479468
sig := ctx.GetRawContent(funcDecl.Type)
@@ -510,6 +499,9 @@ set_func:
510499
f.Results = results
511500
f.GlobalVars = collects.globalVars
512501
f.Types = collects.tys
502+
for _, t := range tparams {
503+
f.Types = InsertDependency(f.Types, t)
504+
}
513505
f.Signature = string(sig)
514506
return f, false
515507
}
@@ -534,6 +526,10 @@ func (p *GoParser) parseType(ctx *fileContext, typDecl *ast.TypeSpec, doc *ast.C
534526
}
535527
}
536528

529+
if typDecl.TypeParams != nil {
530+
ctx.collectFields(typDecl.TypeParams.List, &st.SubStruct)
531+
}
532+
537533
st.FileLine = ctx.FileLine(typDecl)
538534
st.Content = string(ctx.GetRawContent(typDecl))
539535
if ctx.collectComment && doc != nil {

lang/golang/parser/utils.go

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ import (
1919
"bytes"
2020
"fmt"
2121
"go/ast"
22-
"go/parser"
23-
"go/token"
2422
"go/types"
2523
"io"
2624
"os"
@@ -51,35 +49,12 @@ func (c cache) Visited(val interface{}) bool {
5149
return ok
5250
}
5351

54-
func hasMain(file []byte) bool {
55-
if !bytes.Contains(file, []byte("package main")) || !bytes.Contains(file, []byte("func main()")) {
56-
return false
57-
}
58-
fset := token.NewFileSet()
59-
f, err := parser.ParseFile(fset, "any.go", file, parser.SkipObjectResolution)
60-
if err != nil {
61-
return false
62-
}
63-
if f.Name.Name != "main" {
64-
return false
65-
}
66-
for _, decl := range f.Decls {
67-
if funcDecl, ok := decl.(*ast.FuncDecl); ok {
68-
if funcDecl.Name.Name == "main" {
69-
return true
70-
}
71-
}
72-
}
73-
return false
74-
}
75-
7652
func isSysPkg(importPath string) bool {
7753
return !strings.Contains(strings.Split(importPath, "/")[0], ".")
7854
}
7955

8056
var (
8157
verReg = regexp.MustCompile(`/v\d+$`)
82-
litReg = regexp.MustCompile(`[^a-zA-Z0-9_]`)
8358
)
8459

8560
func getPackageAlias(importPath string) string {
@@ -98,14 +73,6 @@ func getPackageAlias(importPath string) string {
9873
return alias
9974
}
10075

101-
func splitVersion(module string) (string, string) {
102-
if strings.Contains(module, "@") {
103-
parts := strings.Split(module, "@")
104-
return parts[0], parts[1]
105-
}
106-
return module, ""
107-
}
108-
10976
func getModuleName(modFilePath string) (string, []byte, error) {
11077
file, err := os.Open(modFilePath)
11178
if err != nil {
@@ -218,6 +185,18 @@ func getNamedTypes(typ types.Type, visited map[types.Type]bool) (tys []types.Obj
218185
case *types.Named:
219186
tys = append(tys, t.Obj())
220187
isNamed = true
188+
if targs := t.TypeArgs(); targs != nil {
189+
for i := 0; i < targs.Len(); i++ {
190+
typs, _, _ := getNamedTypes(targs.At(i), visited)
191+
tys = append(tys, typs...)
192+
}
193+
}
194+
if tparams := t.TypeParams(); tparams != nil {
195+
for i := 0; i < tparams.Len(); i++ {
196+
typs, _, _ := getNamedTypes(tparams.At(i), visited)
197+
tys = append(tys, typs...)
198+
}
199+
}
221200
case *types.Struct:
222201
for i := 0; i < t.NumFields(); i++ {
223202
typs, _, _ := getNamedTypes(t.Field(i).Type(), visited)
@@ -252,13 +231,6 @@ func getNamedTypes(typ types.Type, visited map[types.Type]bool) (tys []types.Obj
252231
return
253232
}
254233

255-
func extractName(typ string) string {
256-
if strings.Contains(typ, ".") {
257-
return strings.Split(typ, ".")[1]
258-
}
259-
return typ
260-
}
261-
262234
func parseExpr(expr string) (interface{}, error) {
263235
// Create a map of parameters to pass to the expression evaluator.
264236
parameters := map[string]interface{}{
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/**
2+
* Copyright 2025 ByteDance Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package pkg
18+
19+
import (
20+
"fmt"
21+
22+
"a.b/c/pkg/entity"
23+
)
24+
25+
type CaseGenericStruct[T entity.InterfaceB, U InterfaceA, V any] struct {
26+
Prefix T
27+
Subfix U
28+
Data V
29+
}
30+
31+
func (s *CaseGenericStruct[_, _, _]) String() string {
32+
return s.Prefix.String() + fmt.Sprintf("%v", s.Data) + s.Subfix.String()
33+
}
34+
35+
func CaseGenericFunc[U InterfaceA, T entity.InterfaceB, V any](a T, b U, c V) string {
36+
return a.String() + fmt.Sprintf("%v", c) + b.String()
37+
}
38+
39+
var CaseGenericVar CaseGenericStruct[entity.InterfaceB, InterfaceA, int]

0 commit comments

Comments
 (0)