Skip to content

Commit 64af933

Browse files
authored
fix: migrator force modification of fields with no default value (go-gorm#134)
* fix: migrator force modification of fields with no default value * test: re-migrate table fields with or without default value
1 parent 15fe45b commit 64af933

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

migrator.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,6 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
256256
column.DefaultValueValue.String = matches[1]
257257
matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String)
258258
}
259-
} else {
260-
column.DefaultValueValue.Valid = true
261259
}
262260

263261
for _, c := range rawColumnTypes {

migrator_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package sqlserver_test
22

33
import (
44
"os"
5+
"reflect"
56
"testing"
7+
"time"
68

79
"gorm.io/driver/sqlserver"
810
"gorm.io/gorm"
@@ -115,3 +117,74 @@ func TestCreateIndex(t *testing.T) {
115117
t.Error("couldn't drop table testtable", tx.Error)
116118
}
117119
}
120+
121+
type TestTableDefaultValue struct {
122+
ID string `gorm:"column:id;primaryKey"`
123+
Name string `gorm:"column:name"`
124+
Age uint `gorm:"column:age"`
125+
Birthday *time.Time `gorm:"column:birthday"`
126+
CompanyID *int `gorm:"column:company_id;default:0"`
127+
ManagerID *uint `gorm:"column:manager_id;default:0"`
128+
Active bool `gorm:"column:active;default:1"`
129+
}
130+
131+
func (*TestTableDefaultValue) TableName() string { return "test_table_default_value" }
132+
133+
func TestReMigrateTableFieldsWithoutDefaultValue(t *testing.T) {
134+
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
135+
if err != nil {
136+
t.Error(err)
137+
}
138+
139+
var (
140+
migrator = db.Migrator()
141+
tableModel = new(TestTableDefaultValue)
142+
fieldsWithDefault = []string{"company_id", "manager_id", "active"}
143+
fieldsWithoutDefault = []string{"id", "name", "age", "birthday"}
144+
145+
columnsWithDefault []string
146+
columnsWithoutDefault []string
147+
)
148+
149+
defer func() {
150+
if err = migrator.DropTable(tableModel); err != nil {
151+
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
152+
}
153+
}()
154+
if !migrator.HasTable(tableModel) {
155+
if err = migrator.AutoMigrate(tableModel); err != nil {
156+
t.Errorf("couldn't auto migrate table %q, got error: %v", tableModel.TableName(), err)
157+
}
158+
}
159+
// If in the `Migrator.ColumnTypes` method `column.DefaultValueValue.Valid = true`,
160+
// re-migrate the table will alter all fields without default value except for the primary key.
161+
if err = db.Debug().Migrator().AutoMigrate(tableModel); err != nil {
162+
t.Errorf("couldn't re-migrate table %q, got error: %v", tableModel.TableName(), err)
163+
}
164+
165+
columnsWithDefault, columnsWithoutDefault, err = testGetMigrateColumns(db, tableModel)
166+
if !reflect.DeepEqual(columnsWithDefault, fieldsWithDefault) {
167+
// If in the `Migrator.ColumnTypes` method `column.DefaultValueValue.Valid = true`,
168+
// fields with default value will include all fields: `[id name age birthday company_id manager_id active]`.
169+
t.Errorf("expected columns with default value %v, got %v", fieldsWithDefault, columnsWithDefault)
170+
}
171+
if !reflect.DeepEqual(columnsWithoutDefault, fieldsWithoutDefault) {
172+
t.Errorf("expected columns without default value %v, got %v", fieldsWithoutDefault, columnsWithoutDefault)
173+
}
174+
}
175+
176+
func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, columnsWithoutDefault []string, err error) {
177+
migrator := db.Migrator()
178+
var columnTypes []gorm.ColumnType
179+
if columnTypes, err = migrator.ColumnTypes(dst); err != nil {
180+
return
181+
}
182+
for _, columnType := range columnTypes {
183+
if _, ok := columnType.DefaultValue(); ok {
184+
columnsWithDefault = append(columnsWithDefault, columnType.Name())
185+
} else {
186+
columnsWithoutDefault = append(columnsWithoutDefault, columnType.Name())
187+
}
188+
}
189+
return
190+
}

0 commit comments

Comments
 (0)