Skip to content

Commit aef868a

Browse files
committed
fix: failed to modify field comments when empty due to incorrect conditions
Ref: go-gorm#140
1 parent 4937266 commit aef868a

File tree

2 files changed

+38
-40
lines changed

2 files changed

+38
-40
lines changed

migrator.go

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
4848
}
4949
for _, fieldName := range stmt.Schema.DBNames {
5050
field := stmt.Schema.FieldsByDBName[fieldName]
51-
if field.Comment == "" {
51+
if _, ok := field.TagSettings["COMMENT"]; !ok {
5252
continue
5353
}
5454
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -65,17 +65,18 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
6565

6666
func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error {
6767
schemaName := m.getTableSchemaName(stmt.Schema)
68+
commentExpr := gorm.Expr(strings.ReplaceAll(field.Comment, "'", "''"))
6869
// add field comment
6970
if add {
7071
return m.DB.Exec(
71-
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
72-
field.Comment, schemaName, stmt.Table, field.DBName,
72+
"EXEC sp_addextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
73+
commentExpr, schemaName, stmt.Table, field.DBName,
7374
).Error
7475
}
7576
// update field comment
7677
return m.DB.Exec(
77-
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
78-
field.Comment, schemaName, stmt.Table, field.DBName,
78+
"EXEC sp_updateextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
79+
commentExpr, schemaName, stmt.Table, field.DBName,
7980
).Error
8081
}
8182

@@ -121,7 +122,7 @@ func getFullQualifiedTableName(stmt *gorm.Statement) string {
121122

122123
func (m Migrator) HasTable(value interface{}) bool {
123124
var count int
124-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
125+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
125126
schemaName := getTableSchemaName(stmt.Schema)
126127
if schemaName == "" {
127128
schemaName = "%"
@@ -202,7 +203,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
202203
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
203204
if stmt.Schema != nil {
204205
if field := stmt.Schema.LookUpField(name); field != nil {
205-
if field.Comment == "" {
206+
if _, ok := field.TagSettings["COMMENT"]; !ok {
206207
return
207208
}
208209
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -216,7 +217,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
216217

217218
func (m Migrator) HasColumn(value interface{}, field string) bool {
218219
var count int64
219-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
220+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
220221
currentDatabase := m.DB.Migrator().CurrentDatabase()
221222
name := field
222223
if stmt.Schema != nil {
@@ -273,17 +274,13 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
273274
})
274275
}
275276

276-
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
277+
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (comment sql.NullString) {
277278
queryTx := m.DB.Session(&gorm.Session{Logger: m.DB.Logger.LogMode(logger.Warn)})
278279
if m.DB.DryRun {
279280
queryTx.DryRun = false
280281
}
281-
var comment sql.NullString
282282
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
283283
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
284-
if comment.Valid {
285-
description = comment.String
286-
}
287284
return
288285
}
289286

@@ -293,12 +290,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
293290
}
294291

295292
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
296-
description := m.GetColumnComment(stmt, field.DBName)
297-
if field.Comment != description {
298-
if description == "" {
299-
err = m.setColumnComment(stmt, field, true)
300-
} else {
293+
comment := m.GetColumnComment(stmt, field.DBName)
294+
if field.Comment != comment.String {
295+
if comment.Valid {
301296
err = m.setColumnComment(stmt, field, false)
297+
} else {
298+
err = m.setColumnComment(stmt, field, true)
302299
}
303300
}
304301
return
@@ -317,7 +314,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
317314
}
318315

319316
rawColumnTypes, _ := rows.ColumnTypes()
320-
rows.Close()
317+
_ = rows.Close()
321318

322319
{
323320
_, schemaName, tableName := splitFullQualifiedName(stmt.Table)
@@ -394,7 +391,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
394391
columnTypes = append(columnTypes, column)
395392
}
396393

397-
columns.Close()
394+
_ = columns.Close()
398395
}
399396

400397
{
@@ -415,7 +412,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
415412

416413
for columnTypeRows.Next() {
417414
var name, columnType string
418-
columnTypeRows.Scan(&name, &columnType)
415+
_ = columnTypeRows.Scan(&name, &columnType)
419416
for idx, c := range columnTypes {
420417
mc := c.(migrator.ColumnType)
421418
if mc.NameValue.String == name {
@@ -431,7 +428,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
431428
}
432429
}
433430

434-
columnTypeRows.Close()
431+
_ = columnTypeRows.Close()
435432
}
436433

437434
return
@@ -473,7 +470,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
473470

474471
func (m Migrator) HasIndex(value interface{}, name string) bool {
475472
var count int
476-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
473+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
477474
if stmt.Schema != nil {
478475
if idx := stmt.Schema.LookIndex(name); idx != nil {
479476
name = idx.Name
@@ -538,34 +535,34 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
538535

539536
func (m Migrator) HasConstraint(value interface{}, name string) bool {
540537
var count int64
541-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
538+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
542539
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
543540
if constraint != nil {
544541
name = constraint.GetName()
545542
}
546543

547-
tableCatalog, schema, tableName := splitFullQualifiedName(table)
544+
tableCatalog, tableSchema, tableName := splitFullQualifiedName(table)
548545
if tableCatalog == "" {
549546
tableCatalog = m.CurrentDatabase()
550547
}
551-
if schema == "" {
552-
schema = "%"
548+
if tableSchema == "" {
549+
tableSchema = "%"
553550
}
554551

555552
return m.DB.Raw(
556553
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
557-
name, tableName, schema, tableCatalog,
554+
name, tableName, tableSchema, tableCatalog,
558555
).Row().Scan(&count)
559556
})
560557
return count > 0
561558
}
562559

563560
func (m Migrator) CurrentDatabase() (name string) {
564-
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
561+
_ = m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
565562
return
566563
}
567564

568565
func (m Migrator) DefaultSchema() (name string) {
569-
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
566+
_ = m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
570567
return
571568
}

migrator_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,56 +190,57 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
190190
}
191191

192192
type TestTableFieldComment struct {
193-
ID string `gorm:"column:id;primaryKey"`
193+
ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string
194194
Name string `gorm:"column:name;comment:姓名"`
195195
Age uint `gorm:"column:age;comment:年龄"`
196196
}
197197

198198
func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }
199199

200200
type TestTableFieldCommentUpdate struct {
201-
ID string `gorm:"column:id;primaryKey"`
201+
ID string `gorm:"column:id;primaryKey;comment:ID"`
202202
Name string `gorm:"column:name;comment:姓名"`
203203
Age uint `gorm:"column:age;comment:周岁"`
204204
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
205+
Quote string `gorm:"column:quote;comment:注释中包含'单引号'和特殊符号❤️"`
205206
}
206207

207208
func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" }
208209

209210
func TestMigrator_MigrateColumnComment(t *testing.T) {
210211
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
211212
if err != nil {
212-
t.Error(err)
213+
t.Fatal(err)
213214
}
214-
migrator := db.Debug().Migrator()
215+
dm := db.Debug().Migrator()
215216

216217
tableModel := new(TestTableFieldComment)
217218
defer func() {
218-
if err = migrator.DropTable(tableModel); err != nil {
219+
if err = dm.DropTable(tableModel); err != nil {
219220
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
220221
}
221222
}()
222223

223-
if err = migrator.AutoMigrate(tableModel); err != nil {
224+
if err = dm.AutoMigrate(tableModel); err != nil {
224225
t.Fatal(err)
225226
}
226227
tableModelUpdate := new(TestTableFieldCommentUpdate)
227-
if err = migrator.AutoMigrate(tableModelUpdate); err != nil {
228+
if err = dm.AutoMigrate(tableModelUpdate); err != nil {
228229
t.Error(err)
229230
}
230231

231-
if m, ok := migrator.(sqlserver.Migrator); ok {
232+
if m, ok := dm.(sqlserver.Migrator); ok {
232233
stmt := db.Model(tableModelUpdate).Find(nil).Statement
233234
if stmt == nil || stmt.Schema == nil {
234235
t.Fatal("expected Statement.Schema, got nil")
235236
}
236237

237-
wantComments := []string{"", "姓名", "周岁", "生日"}
238+
wantComments := []string{"ID", "姓名", "周岁", "生日", "注释中包含'单引号'和特殊符号❤️"}
238239
gotComments := make([]string, len(stmt.Schema.DBNames))
239240

240241
for i, fieldDBName := range stmt.Schema.DBNames {
241242
comment := m.GetColumnComment(stmt, fieldDBName)
242-
gotComments[i] = comment
243+
gotComments[i] = comment.String
243244
}
244245

245246
if !reflect.DeepEqual(wantComments, gotComments) {

0 commit comments

Comments
 (0)