Skip to content

Commit d7be38c

Browse files
authored
Merge branch 'go-gorm:master' into master
2 parents dd8088c + 0a5395f commit d7be38c

File tree

7 files changed

+89
-136
lines changed

7 files changed

+89
-136
lines changed

create.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,18 @@ func Create(db *gorm.DB) {
123123
if db.AddError(err) == nil {
124124
defer rows.Close()
125125
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
126+
if db.Statement.Result != nil {
127+
db.Statement.Result.RowsAffected = db.RowsAffected
128+
}
126129
}
127130
} else {
128131
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
129132
if db.AddError(err) == nil {
130133
db.RowsAffected, _ = result.RowsAffected()
134+
if db.Statement.Result != nil {
135+
db.Statement.Result.Result = result
136+
db.Statement.Result.RowsAffected = db.RowsAffected
137+
}
131138
}
132139
}
133140
}

error_translator.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
// The error codes to map mssql errors to gorm errors, here is a reference about error codes for mssql https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors?view=sql-server-ver16
1010
var errCodes = map[int32]error{
1111
2627: gorm.ErrDuplicatedKey,
12+
2601: gorm.ErrDuplicatedKey,
1213
547: gorm.ErrForeignKeyViolated,
1314
}
1415

error_translator_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ func TestDialector_Translate(t *testing.T) {
2727
args: args{err: mssql.Error{Number: 2627}},
2828
want: gorm.ErrDuplicatedKey,
2929
},
30+
{
31+
name: "it should return ErrDuplicatedKey error if the error number is 2601",
32+
args: args{err: mssql.Error{Number: 2601}},
33+
want: gorm.ErrDuplicatedKey,
34+
},
3035
{
3136
name: "it should return ErrForeignKeyViolated the error number is 547",
3237
args: args{err: mssql.Error{Number: 547}},

go.mod

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module gorm.io/driver/sqlserver
33
go 1.14
44

55
require (
6-
github.com/microsoft/go-mssqldb v1.8.0
7-
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
6+
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
7+
github.com/microsoft/go-mssqldb v0.19.0
8+
gorm.io/gorm v1.30.0
89
)

go.sum

Lines changed: 41 additions & 100 deletions
Large diffs are not rendered by default.

migrator.go

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
4848
}
4949
for _, fieldName := range stmt.Schema.DBNames {
5050
field := stmt.Schema.FieldsByDBName[fieldName]
51-
if field.Comment == "" {
51+
if _, ok := field.TagSettings["COMMENT"]; !ok {
5252
continue
5353
}
5454
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -65,17 +65,18 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
6565

6666
func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error {
6767
schemaName := m.getTableSchemaName(stmt.Schema)
68+
commentExpr := gorm.Expr(strings.ReplaceAll(field.Comment, "'", "''"))
6869
// add field comment
6970
if add {
7071
return m.DB.Exec(
71-
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
72-
field.Comment, schemaName, stmt.Table, field.DBName,
72+
"EXEC sp_addextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
73+
commentExpr, schemaName, stmt.Table, field.DBName,
7374
).Error
7475
}
7576
// update field comment
7677
return m.DB.Exec(
77-
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
78-
field.Comment, schemaName, stmt.Table, field.DBName,
78+
"EXEC sp_updateextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
79+
commentExpr, schemaName, stmt.Table, field.DBName,
7980
).Error
8081
}
8182

@@ -121,7 +122,7 @@ func getFullQualifiedTableName(stmt *gorm.Statement) string {
121122

122123
func (m Migrator) HasTable(value interface{}) bool {
123124
var count int
124-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
125+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
125126
schemaName := getTableSchemaName(stmt.Schema)
126127
if schemaName == "" {
127128
schemaName = "%"
@@ -202,7 +203,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
202203
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
203204
if stmt.Schema != nil {
204205
if field := stmt.Schema.LookUpField(name); field != nil {
205-
if field.Comment == "" {
206+
if _, ok := field.TagSettings["COMMENT"]; !ok {
206207
return
207208
}
208209
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -216,7 +217,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
216217

217218
func (m Migrator) HasColumn(value interface{}, field string) bool {
218219
var count int64
219-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
220+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
220221
currentDatabase := m.DB.Migrator().CurrentDatabase()
221222
name := field
222223
if stmt.Schema != nil {
@@ -273,17 +274,13 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
273274
})
274275
}
275276

276-
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
277+
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (comment sql.NullString) {
277278
queryTx := m.DB.Session(&gorm.Session{Logger: m.DB.Logger.LogMode(logger.Warn)})
278279
if m.DB.DryRun {
279280
queryTx.DryRun = false
280281
}
281-
var comment sql.NullString
282282
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
283283
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
284-
if comment.Valid {
285-
description = comment.String
286-
}
287284
return
288285
}
289286

@@ -293,12 +290,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
293290
}
294291

295292
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
296-
description := m.GetColumnComment(stmt, field.DBName)
297-
if field.Comment != description {
298-
if description == "" {
299-
err = m.setColumnComment(stmt, field, true)
300-
} else {
293+
comment := m.GetColumnComment(stmt, field.DBName)
294+
if field.Comment != comment.String {
295+
if comment.Valid {
301296
err = m.setColumnComment(stmt, field, false)
297+
} else {
298+
err = m.setColumnComment(stmt, field, true)
302299
}
303300
}
304301
return
@@ -317,7 +314,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
317314
}
318315

319316
rawColumnTypes, _ := rows.ColumnTypes()
320-
rows.Close()
317+
_ = rows.Close()
321318

322319
{
323320
_, schemaName, tableName := splitFullQualifiedName(stmt.Table)
@@ -394,7 +391,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
394391
columnTypes = append(columnTypes, column)
395392
}
396393

397-
columns.Close()
394+
_ = columns.Close()
398395
}
399396

400397
{
@@ -415,7 +412,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
415412

416413
for columnTypeRows.Next() {
417414
var name, columnType string
418-
columnTypeRows.Scan(&name, &columnType)
415+
_ = columnTypeRows.Scan(&name, &columnType)
419416
for idx, c := range columnTypes {
420417
mc := c.(migrator.ColumnType)
421418
if mc.NameValue.String == name {
@@ -431,7 +428,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
431428
}
432429
}
433430

434-
columnTypeRows.Close()
431+
_ = columnTypeRows.Close()
435432
}
436433

437434
return
@@ -473,7 +470,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
473470

474471
func (m Migrator) HasIndex(value interface{}, name string) bool {
475472
var count int
476-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
473+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
477474
if stmt.Schema != nil {
478475
if idx := stmt.Schema.LookIndex(name); idx != nil {
479476
name = idx.Name
@@ -537,34 +534,34 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
537534

538535
func (m Migrator) HasConstraint(value interface{}, name string) bool {
539536
var count int64
540-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
537+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
541538
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
542539
if constraint != nil {
543540
name = constraint.GetName()
544541
}
545542

546-
tableCatalog, schema, tableName := splitFullQualifiedName(table)
543+
tableCatalog, tableSchema, tableName := splitFullQualifiedName(table)
547544
if tableCatalog == "" {
548545
tableCatalog = m.CurrentDatabase()
549546
}
550-
if schema == "" {
551-
schema = "%"
547+
if tableSchema == "" {
548+
tableSchema = "%"
552549
}
553550

554551
return m.DB.Raw(
555552
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
556-
name, tableName, schema, tableCatalog,
553+
name, tableName, tableSchema, tableCatalog,
557554
).Row().Scan(&count)
558555
})
559556
return count > 0
560557
}
561558

562559
func (m Migrator) CurrentDatabase() (name string) {
563-
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
560+
_ = m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
564561
return
565562
}
566563

567564
func (m Migrator) DefaultSchema() (name string) {
568-
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
565+
_ = m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
569566
return
570567
}

migrator_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,18 +191,19 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
191191
}
192192

193193
type TestTableFieldComment struct {
194-
ID string `gorm:"column:id;primaryKey"`
194+
ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string
195195
Name string `gorm:"column:name;comment:姓名"`
196196
Age uint `gorm:"column:age;comment:年龄"`
197197
}
198198

199199
func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }
200200

201201
type TestTableFieldCommentUpdate struct {
202-
ID string `gorm:"column:id;primaryKey"`
202+
ID string `gorm:"column:id;primaryKey;comment:ID"`
203203
Name string `gorm:"column:name;comment:姓名"`
204204
Age uint `gorm:"column:age;comment:周岁"`
205205
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
206+
Quote string `gorm:"column:quote;comment:注释中包含'单引号'和特殊符号❤️"`
206207
}
207208

208209
func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" }
@@ -235,12 +236,12 @@ func TestMigrator_MigrateColumnComment(t *testing.T) {
235236
t.Fatal("expected Statement.Schema, got nil")
236237
}
237238

238-
wantComments := []string{"", "姓名", "周岁", "生日"}
239+
wantComments := []string{"ID", "姓名", "周岁", "生日", "注释中包含'单引号'和特殊符号❤️"}
239240
gotComments := make([]string, len(stmt.Schema.DBNames))
240241

241242
for i, fieldDBName := range stmt.Schema.DBNames {
242243
comment := m.GetColumnComment(stmt, fieldDBName)
243-
gotComments[i] = comment
244+
gotComments[i] = comment.String
244245
}
245246

246247
if !reflect.DeepEqual(wantComments, gotComments) {

0 commit comments

Comments
 (0)