Skip to content

Fix db engine #32351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 67 additions & 47 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ package db
import (
"context"
"database/sql"
"errors"
"runtime"
"slices"
"sync"

"code.gitea.io/gitea/modules/setting"

"xorm.io/builder"
"xorm.io/xorm"
Expand All @@ -15,76 +21,90 @@ import (
// will be overwritten by Init with HammerContext
var DefaultContext context.Context

// contextKey is a value for use with context.WithValue.
type contextKey struct {
name string
}
type engineContextKeyType struct{}

// enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context
var (
enginedContextKey = &contextKey{"engined"}
_ Engined = &Context{}
)
var engineContextKey = engineContextKeyType{}

// Context represents a db context
type Context struct {
context.Context
e Engine
transaction bool
}

func newContext(ctx context.Context, e Engine, transaction bool) *Context {
return &Context{
Context: ctx,
e: e,
transaction: transaction,
}
}

// InTransaction if context is in a transaction
func (ctx *Context) InTransaction() bool {
return ctx.transaction
engine Engine
}

// Engine returns db engine
func (ctx *Context) Engine() Engine {
return ctx.e
func newContext(ctx context.Context, e Engine) *Context {
return &Context{Context: ctx, engine: e}
}

// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
func (ctx *Context) Value(key any) any {
if key == enginedContextKey {
if key == engineContextKey {
return ctx
}
return ctx.Context.Value(key)
}

// WithContext returns this engine tied to this context
func (ctx *Context) WithContext(other context.Context) *Context {
return newContext(ctx, ctx.e.Context(other), ctx.transaction)
return newContext(ctx, ctx.engine.Context(other))
}

// Engined structs provide an Engine
type Engined interface {
Engine() Engine
var (
contextSafetyOnce sync.Once
contextSafetyDeniedFuncPCs []uintptr
)

func contextSafetyCheck(e Engine) {
if setting.IsProd && !setting.IsInTesting {
return
}
if e == nil {
return
}
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
contextSafetyOnce.Do(func() {
// try to figure out the bad functions to deny
type m struct{}
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
callers := make([]uintptr, 32)
callerNum := runtime.Callers(1, callers)
for i := 0; i < callerNum; i++ {
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
}
}
return nil
})
if len(contextSafetyDeniedFuncPCs) != 1 {
panic(errors.New("unable to determine the functions to deny"))
}
})

// it should be very fast: xxxx ns/op
callers := make([]uintptr, 32)
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
for i := 0; i < callerNum; i++ {
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
panic(errors.New("using database context in an iterator would cause corrupted results"))
}
}
}

// GetEngine will get a db Engine from this context or return an Engine restricted to this context
// GetEngine gets an existing db Engine/Statement or creates a new Session
func GetEngine(ctx context.Context) Engine {
if e := getEngine(ctx); e != nil {
if e := getExistingEngine(ctx); e != nil {
return e
}
return x.Context(ctx)
}

// getEngine will get a db Engine from this context or return nil
func getEngine(ctx context.Context) Engine {
if engined, ok := ctx.(Engined); ok {
return engined.Engine()
// getExistingEngine gets an existing db Engine/Statement from this context or returns nil
func getExistingEngine(ctx context.Context) (e Engine) {
defer func() { contextSafetyCheck(e) }()
if engined, ok := ctx.(*Context); ok {
return engined.engine
}
enginedInterface := ctx.Value(enginedContextKey)
if enginedInterface != nil {
return enginedInterface.(Engined).Engine()
if engined, ok := ctx.Value(engineContextKey).(*Context); ok {
return engined.engine
}
return nil
}
Expand Down Expand Up @@ -132,23 +152,23 @@ func (c *halfCommitter) Close() error {
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
if sess, ok := inTransaction(parentCtx); ok {
return newContext(parentCtx, sess, true), &halfCommitter{committer: sess}, nil
return newContext(parentCtx, sess), &halfCommitter{committer: sess}, nil
}

sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
_ = sess.Close()
return nil, nil, err
}

return newContext(DefaultContext, sess, true), sess, nil
return newContext(DefaultContext, sess), sess, nil
}

// WithTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if sess, ok := inTransaction(parentCtx); ok {
err := f(newContext(parentCtx, sess, true))
err := f(newContext(parentCtx, sess))
if err != nil {
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
_ = sess.Close()
Expand All @@ -165,7 +185,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
return err
}

if err := f(newContext(parentCtx, sess, true)); err != nil {
if err := f(newContext(parentCtx, sess)); err != nil {
return err
}

Expand Down Expand Up @@ -312,7 +332,7 @@ func InTransaction(ctx context.Context) bool {
}

func inTransaction(ctx context.Context) (*xorm.Session, bool) {
e := getEngine(ctx)
e := getExistingEngine(ctx)
if e == nil {
return nil, false
}
Expand Down
44 changes: 44 additions & 0 deletions models/db/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,47 @@ func TestTxContext(t *testing.T) {
}))
}
}

func TestContextSafety(t *testing.T) {
type TestModel1 struct {
ID int64
}
type TestModel2 struct {
ID int64
}
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
assert.NoError(t, db.TruncateBeans(db.DefaultContext, &TestModel1{}, &TestModel2{}))
testCount := 10
for i := 1; i <= testCount; i++ {
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel1{ID: int64(i)}))
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel2{ID: int64(-i)}))
}

actualCount := 0
// here: db.GetEngine(db.DefaultContext) is a new *Session created from *Engine
_ = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
m1 := bean.(*TestModel1)
assert.EqualValues(t, i+1, m1.ID)

// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
// and it conflicts with the "Iterate"'s internal states.
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})

actualCount++
return nil
})
return nil
})
assert.EqualValues(t, testCount, actualCount)

// deny the bad usages
assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() {
_ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error {
_ = db.GetEngine(db.DefaultContext)
return nil
})
})
}
5 changes: 1 addition & 4 deletions models/db/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,7 @@ func InitEngine(ctx context.Context) error {
// SetDefaultEngine sets the default engine for db
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
x = eng
DefaultContext = &Context{
Context: ctx,
e: x,
}
DefaultContext = &Context{Context: ctx, engine: x}
}

// UnsetDefaultEngine closes and unsets the default engine
Expand Down
2 changes: 1 addition & 1 deletion models/db/install/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func getXORMEngine() *xorm.Engine {
return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
}

// CheckDatabaseConnection checks the database connection
Expand Down
2 changes: 1 addition & 1 deletion models/db/iterate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"xorm.io/builder"
)

// Iterate iterate all the Bean object
// Iterate iterates all the Bean object
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
var start int
batchSize := setting.Database.IterateBufferSize
Expand Down
31 changes: 16 additions & 15 deletions models/packages/debian/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,27 @@ func ExistPackages(ctx context.Context, opts *PackageSearchOptions) (bool, error
}

// SearchPackages gets the packages matching the search options
func SearchPackages(ctx context.Context, opts *PackageSearchOptions, iter func(*packages.PackageFileDescriptor)) error {
return db.GetEngine(ctx).
func SearchPackages(ctx context.Context, opts *PackageSearchOptions) ([]*packages.PackageFileDescriptor, error) {
var pkgFiles []*packages.PackageFile
err := db.GetEngine(ctx).
Table("package_file").
Select("package_file.*").
Join("INNER", "package_version", "package_version.id = package_file.version_id").
Join("INNER", "package", "package.id = package_version.package_id").
Where(opts.toCond()).
Asc("package.lower_name", "package_version.created_unix").
Iterate(new(packages.PackageFile), func(_ int, bean any) error {
pf := bean.(*packages.PackageFile)

pfd, err := packages.GetPackageFileDescriptor(ctx, pf)
if err != nil {
return err
}

iter(pfd)

return nil
})
Asc("package.lower_name", "package_version.created_unix").Find(&pkgFiles)
if err != nil {
return nil, err
}
pfds := make([]*packages.PackageFileDescriptor, 0, len(pkgFiles))
for _, pf := range pkgFiles {
pfd, err := packages.GetPackageFileDescriptor(ctx, pf)
if err != nil {
return nil, err
}
pfds = append(pfds, pfd)
}
return pfds, nil
}

// GetDistributions gets all available distributions
Expand Down
2 changes: 1 addition & 1 deletion models/unittest/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
if len(engine) == 1 {
return engine[0]
}
return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
}

// InitFixtures initialize test fixtures for a test database
Expand Down
2 changes: 1 addition & 1 deletion services/packages/cleanup/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
rpm_service "code.gitea.io/gitea/services/packages/rpm"
)

// Task method to execute cleanup rules and cleanup expired package data
// CleanupTask executes cleanup rules and cleanup expired package data
func CleanupTask(ctx context.Context, olderThan time.Duration) error {
if err := ExecuteCleanupRules(ctx); err != nil {
return err
Expand Down
9 changes: 5 additions & 4 deletions services/packages/debian/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa
w := io.MultiWriter(packagesContent, gzw, xzw)

addSeparator := false
if err := debian_model.SearchPackages(ctx, opts, func(pfd *packages_model.PackageFileDescriptor) {
pfds, err := debian_model.SearchPackages(ctx, opts)
if err != nil {
return err
}
for _, pfd := range pfds {
if addSeparator {
fmt.Fprintln(w)
}
Expand All @@ -220,10 +224,7 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa
fmt.Fprintf(w, "SHA1: %s\n", pfd.Blob.HashSHA1)
fmt.Fprintf(w, "SHA256: %s\n", pfd.Blob.HashSHA256)
fmt.Fprintf(w, "SHA512: %s\n", pfd.Blob.HashSHA512)
}); err != nil {
return err
}

gzw.Close()
xzw.Close()

Expand Down
Loading
Loading