Skip to content

sql reader在构建数据库表名时根据具体的数据库类型构建 #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions reader/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ func (r *Reader) getValidData(connectStr, curDB, matchData, matchStr string,
continue
}

rawSql, err := getRawSqls(queryType, s)
rawSql, err := r.getRawSqls(queryType, s)
if err != nil {
return validData, sqls, err
}
Expand Down Expand Up @@ -1678,14 +1678,34 @@ func (r *Reader) getCheckAll(queryType int) (checkAll bool, err error) {

return true, nil
}
//根据数据库类型返回表名
func getWrappedTableName(dbtype string, table string) (tableName string, err error) {
switch dbtype {
case reader.ModeMySQL:
tableName = "`" + table + "`"
case reader.ModeMSSQL, reader.ModePostgreSQL:
tableName = "\"" + table + "\""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就针对pg的加上双引号就行了,把MSSQL单独弄出来

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqlserver 如果表名有横杆的话,也需要加双引号,否则会报语法错误,比如
rows ,err:=sql.Query("select * from log-2018-08-01") 会有问题

default:
err = fmt.Errorf("%v mode not support in sql reader", dbtype)
}
return
}

// 根据 queryType 获取表中所有记录或者表中所有数据的条数的sql语句
func getRawSqls(queryType int, table string) (sqls string, err error) {
func (r *Reader) getRawSqls(queryType int, table string) (sqls string, err error) {
switch queryType {
case TABLE:
sqls += "Select * From `" + table + "`;"
tableName, err := getWrappedTableName(r.dbtype, table)
if err != nil {
return "", err
}
sqls += "Select * From " + tableName + ";"
case COUNT:
sqls += "Select Count(*) From `" + table + "`;"
tableName, err := getWrappedTableName(r.dbtype, table)
if err != nil {
return "", err
}
sqls += "Select Count(*) From " + tableName + ";"
case DATABASE:
default:
return "", fmt.Errorf("%v queryType is not support get sql now", queryType)
Expand Down
45 changes: 42 additions & 3 deletions reader/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,9 +1017,24 @@ func Test_getCheckAll(t *testing.T) {
assert.EqualValues(t, test.expRes, checkHistory)
}
}
func Test_getWrappedTableName(t *testing.T) {
dbtype := reader.ModeMySQL
tname, err := getWrappedTableName(dbtype, "my_table")
expRes := "`my_table`"
assert.NoError(t, err)
assert.EqualValues(t, expRes, tname)

dbtype = reader.ModePostgreSQL
tname, err = getWrappedTableName(dbtype, "my_table")
expRes = "\"my_table\""
assert.NoError(t, err)
assert.EqualValues(t, expRes, tname)
}
func Test_getRawSQLs(t *testing.T) {
tests := []struct {
r := &Reader{
dbtype: reader.ModeMySQL,
}
mysqltests := []struct {
queryType int
expSQLs string
}{
Expand All @@ -1037,11 +1052,35 @@ func Test_getRawSQLs(t *testing.T) {
},
}

for _, test := range tests {
sqls, err := getRawSqls(test.queryType, "my_table")
for _, test := range mysqltests {
sqls, err := r.getRawSqls(test.queryType, "my_table")
assert.NoError(t, err)
assert.EqualValues(t, test.expSQLs, sqls)
}
r.dbtype = reader.ModePostgreSQL
pgtests := []struct {
queryType int
expSQLs string
}{
{
queryType: TABLE,
expSQLs: "Select * From \"my_table\";",
},
{
queryType: COUNT,
expSQLs: "Select Count(*) From \"my_table\";",
},
{
queryType: DATABASE,
expSQLs: "",
},
}
for _, test := range pgtests {
sqls, err := r.getRawSqls(test.queryType, "my_table")
assert.NoError(t, err)
assert.EqualValues(t, test.expSQLs, sqls)
}

}

func Test_getConnectStr(t *testing.T) {
Expand Down