Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion lang/golang/parser/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@ package parser
import (
"bufio"
"bytes"
"container/list"
"fmt"
"go/ast"
"go/build"
"go/types"
"io"
"os"
"path"
"regexp"
"strings"
"sync"

"github.com/Knetic/govaluate"
. "github.com/cloudwego/abcoder/lang/uniast"
Expand All @@ -49,8 +52,84 @@ func (c cache) Visited(val interface{}) bool {
return ok
}

// cacheEntry is the payload stored in every LRU list element. The key is kept
// next to the value so that an evicted element can also be removed from the
// lookup map.
type cacheEntry struct {
	key   string
	value bool
}

// PackageCache caches, per import path, whether that path is a system
// (standard library) package. Entries are kept in least-recently-used order
// and the oldest ones are dropped once lruCapacity is reached. Its methods
// take lock before touching the map or the list.
type PackageCache struct {
	lock        sync.Mutex
	cache       map[string]*list.Element // import path -> its element in lru
	lru         *list.List               // front is the most recently used entry
	lruCapacity int                      // entry count at which eviction starts
}

// NewPackageCache returns an empty PackageCache that starts evicting once it
// holds lruCapacity entries.
func NewPackageCache(lruCapacity int) *PackageCache {
	return &PackageCache{
		cache:       make(map[string]*list.Element),
		lru:         list.New(),
		lruCapacity: lruCapacity,
	}
}

// get retrieves a value from the cache and, on a hit, marks the entry as the
// most recently used one. The first result is the cached value; the second
// reports whether key was present.
func (pc *PackageCache) get(key string) (bool, bool) {
	pc.lock.Lock()
	defer pc.lock.Unlock()

	elem, ok := pc.cache[key]
	if !ok {
		return false, false
	}
	pc.lru.MoveToFront(elem)
	return elem.Value.(*cacheEntry).value, true
}

// set adds a value to the cache, or overwrites the value of an existing key.
// Either way the entry becomes the most recently used one. Before a new key is
// inserted, least recently used entries are evicted until there is room.
func (pc *PackageCache) set(key string, value bool) {
	pc.lock.Lock()
	defer pc.lock.Unlock()

	if elem, ok := pc.cache[key]; ok {
		pc.lru.MoveToFront(elem)
		elem.Value.(*cacheEntry).value = value
		return
	}

	// Evict in a loop rather than once: lruCapacity may have been lowered
	// after entries were stored, and a single eviction would leave the cache
	// above its capacity indefinitely.
	for pc.lru.Len() >= pc.lruCapacity {
		oldest := pc.lru.Back()
		if oldest == nil {
			// Nothing left to evict; reachable only when lruCapacity <= 0.
			break
		}
		pc.lru.Remove(oldest)
		delete(pc.cache, oldest.Value.(*cacheEntry).key)
	}

	pc.cache[key] = pc.lru.PushFront(&cacheEntry{key: key, value: value})
}

// IsStandardPackage reports whether path is a standard library package,
// consulting the internal cache first. On a miss the package is looked up with
// build.Import and the result, including a negative one, is cached.
func (pc *PackageCache) IsStandardPackage(path string) bool {
	if isStd, found := pc.get(path); found {
		return isStd
	}

	// A package that cannot be found is assumed not to be a standard package.
	isStd := false
	if pkg, err := build.Import(path, "", build.FindOnly); err == nil {
		isStd = pkg.Goroot
	}
	pc.set(path, isStd)
	return isStd
}

// stdlibCache caches, per import path, whether the path is a system package.
// It is created with a capacity of 10000 entries.
var stdlibCache = NewPackageCache(10000)

// isSysPkg reports whether importPath is a system (standard library) package.
// The answer comes from stdlibCache, which resolves unknown paths on first use
// and remembers the result.
func isSysPkg(importPath string) bool {
	return stdlibCache.IsStandardPackage(importPath)
}

var (
Expand Down
90 changes: 90 additions & 0 deletions lang/golang/parser/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ import (
"go/token"
"go/types"
"slices"
"sync"
"testing"

"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -195,3 +198,90 @@ var f func() (*http.Request, error)`,
})
}
}

// resetGlobals replaces the package-level package cache with a fresh, empty
// one so that a test starts without entries left over from earlier tests.
func resetGlobals() {
	const capacity = 10000
	stdlibCache = NewPackageCache(capacity)
}

// Test_isSysPkg exercises isSysPkg through the package-level stdlibCache:
// classification of import paths, concurrent callers, and LRU eviction.
func Test_isSysPkg(t *testing.T) {
	// The subtests replace stdlibCache and, in the eviction case, shrink its
	// capacity to 2. Put a fresh default-sized cache back when this test
	// finishes so that state does not leak into other tests in the package.
	t.Cleanup(resetGlobals)

	// Behavior when `go env GOROOT` can be executed successfully.
	t.Run("Group: Happy Path - GOROOT is found", func(t *testing.T) {
		resetGlobals()

		testCases := []struct {
			name       string
			importPath string
			want       bool
		}{
			{"standard library package", "fmt", true},
			{"nested standard library package", "net/http", true},
			{"third-party package", "github.com/google/uuid", false},
			{"extended library package", "golang.org/x/sync/errgroup", false},
			{"local-like package name", "myproject/utils", false},
			{"non-existent package", "non/existent/package", false},
			{"root-level package with dot", "gopkg.in/yaml.v2", false},
		}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				if got := isSysPkg(tc.importPath); got != tc.want {
					t.Errorf("isSysPkg(%q) = %v, want %v", tc.importPath, got, tc.want)
				}
			})
		}
	})

	// Behavior under concurrent calls. Nothing is asserted here; the subtest
	// only drives many goroutines through the shared cache at once.
	t.Run("Group: Concurrency Test", func(t *testing.T) {
		resetGlobals()
		var wg sync.WaitGroup
		numGoroutines := 50
		numOpsPerGoroutine := 100

		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for j := 0; j < numOpsPerGoroutine; j++ {
					isSysPkg("fmt")
					isSysPkg("github.com/cloudwego/abcoder")
					isSysPkg("net/http")
					isSysPkg("a/b/c")
				}
			}()
		}
		wg.Wait()
	})

	// Eviction policy of the LRU cache.
	t.Run("Group: LRU Eviction Test", func(t *testing.T) {
		resetGlobals()
		stdlibCache.lruCapacity = 2

		// 1. Fill the cache.
		isSysPkg("fmt")
		isSysPkg("os")
		assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache should be full")

		// 2. Touch "fmt" so that it becomes the most recently used entry.
		isSysPkg("fmt")
		assert.Equal(t, "fmt", stdlibCache.lru.Front().Value.(*cacheEntry).key, "fmt should be the most recently used")

		// 3. Insert "net"; "os" is now the oldest entry and should be evicted.
		isSysPkg("net")
		assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache size should remain at capacity")

		// 4. "fmt" should still be in the cache.
		_, foundFmt := stdlibCache.get("fmt")
		assert.True(t, foundFmt, "fmt should still be in the cache")

		// 5. "net" should be in the cache.
		_, foundNet := stdlibCache.get("net")
		assert.True(t, foundNet, "net should be in the cache")

		// 6. "os" should no longer be in the cache.
		_, foundOs := stdlibCache.get("os")
		assert.False(t, foundOs, "os should have been evicted from the cache")
	})
}
21 changes: 20 additions & 1 deletion lang/golang/writer/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
)

var (
	// Compile-time assertion that *Writer satisfies uniast.Writer.
	_ uniast.Writer = (*Writer)(nil)

	// testPkgPathRegex matches strings of the form "<text> [<text>]" and
	// captures the part before the bracket and the part inside it.
	testPkgPathRegex = regexp.MustCompile(`^(.+?) \[(.+)\]$`)
)

type Options struct {
// RepoDir string
Expand Down Expand Up @@ -81,6 +82,22 @@ func (w *Writer) WriteRepo(repo *uniast.Repository, outDir string) error {
return nil
}

// sanitizePkgPath strips a bracketed test suffix from a package path.
// When pkgPath matches testPkgPathRegex and the bracketed text equals the
// leading part followed by ".test", only the leading part is returned.
// Every other input is returned unchanged.
func sanitizePkgPath(pkgPath string) string {
	// A successful match yields exactly three parts: the whole string, the
	// text before the bracket, and the text inside the bracket.
	parts := testPkgPathRegex.FindStringSubmatch(pkgPath)
	if len(parts) != 3 {
		return pkgPath
	}

	name, bracketed := parts[1], parts[2]
	if bracketed != name+".test" {
		return pkgPath
	}
	return name
}
func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir string) error {
mod := repo.Modules[modPath]
if mod == nil {
Expand All @@ -94,7 +111,9 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir str

outdir := filepath.Join(outDir, mod.Dir)
for dir, pkg := range w.visited {
rel := strings.TrimPrefix(dir, mod.Name)
// sanitize the package path
cleanDir := sanitizePkgPath(dir)
rel := strings.TrimPrefix(cleanDir, mod.Name)
pkgDir := filepath.Join(outdir, rel)
if err := os.MkdirAll(pkgDir, 0755); err != nil {
return fmt.Errorf("mkdir %s failed: %v", pkgDir, err)
Expand Down
Loading