Skip to content

Commit 50f682d

Browse files
committed
Enhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 50f682d

File tree

2 files changed

+162
-39
lines changed

2 files changed

+162
-39
lines changed

connection.go

Lines changed: 114 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -245,9 +245,105 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245245
return stmt, err
246246
}
247247

248+
// findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+
func findParamPositions(query string, noBackslashEscapes bool) []int {
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
RADICAL_BYTE = byte('`')
270+
)
271+
272+
paramPositions := make([]int, 0)
273+
state := stateNormal
274+
singleQuotes := false
275+
lastChar := byte(0)
276+
lenq := len(query)
277+
for i := 0; i < lenq; i++ {
278+
currentChar := query[i]
279+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
280+
state = stateString
281+
lastChar = currentChar
282+
continue
283+
}
284+
switch currentChar {
285+
case STAR_BYTE:
286+
if state == stateNormal && lastChar == SLASH_BYTE {
287+
state = stateSlashStarComment
288+
}
289+
case SLASH_BYTE:
290+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
291+
state = stateNormal
292+
}
293+
case HASH_BYTE:
294+
if state == stateNormal {
295+
state = stateEOLComment
296+
}
297+
case MINUS_BYTE:
298+
if state == stateNormal && lastChar == MINUS_BYTE {
299+
state = stateEOLComment
300+
}
301+
case LINE_FEED_BYTE:
302+
if state == stateEOLComment {
303+
state = stateNormal
304+
}
305+
case DBL_QUOTE_BYTE:
306+
if state == stateNormal {
307+
state = stateString
308+
singleQuotes = false
309+
} else if state == stateString && !singleQuotes {
310+
state = stateNormal
311+
} else if state == stateEscape {
312+
state = stateString
313+
}
314+
case QUOTE_BYTE:
315+
if state == stateNormal {
316+
state = stateString
317+
singleQuotes = true
318+
} else if state == stateString && singleQuotes {
319+
state = stateNormal
320+
} else if state == stateEscape {
321+
state = stateString
322+
}
323+
case BACKSLASH_BYTE:
324+
if state == stateString && !noBackslashEscapes {
325+
state = stateEscape
326+
}
327+
case QUESTION_MARK_BYTE:
328+
if state == stateNormal {
329+
paramPositions = append(paramPositions, i)
330+
}
331+
case RADICAL_BYTE:
332+
if state == stateBacktick {
333+
state = stateNormal
334+
} else if state == stateNormal {
335+
state = stateBacktick
336+
}
337+
}
338+
lastChar = currentChar
339+
}
340+
return paramPositions
341+
}
342+
248343
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
344+
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
345+
paramPositions := findParamPositions(query, noBackslashEscapes)
346+
if len(paramPositions) != len(args) {
251347
return "", driver.ErrSkip
252348
}
253349

@@ -261,21 +357,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261357
}
262358
buf = buf[:0]
263359
argPos := 0
360+
lastIdx := 0
264361

265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
362+
for _, qmIdx := range paramPositions {
363+
buf = append(buf, query[lastIdx:qmIdx]...)
274364
arg := args[argPos]
275365
argPos++
276366

277367
if arg == nil {
278368
buf = append(buf, "NULL"...)
369+
lastIdx = qmIdx + 1
279370
continue
280371
}
281372

@@ -306,30 +397,30 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
306397
}
307398
case json.RawMessage:
308399
buf = append(buf, '\'')
309-
if mc.status&statusNoBackslashEscapes == 0 {
310-
buf = escapeBytesBackslash(buf, v)
311-
} else {
400+
if noBackslashEscapes {
312401
buf = escapeBytesQuotes(buf, v)
402+
} else {
403+
buf = escapeBytesBackslash(buf, v)
313404
}
314405
buf = append(buf, '\'')
315406
case []byte:
316407
if v == nil {
317408
buf = append(buf, "NULL"...)
318409
} else {
319410
buf = append(buf, "_binary'"...)
320-
if mc.status&statusNoBackslashEscapes == 0 {
321-
buf = escapeBytesBackslash(buf, v)
322-
} else {
411+
if noBackslashEscapes {
323412
buf = escapeBytesQuotes(buf, v)
413+
} else {
414+
buf = escapeBytesBackslash(buf, v)
324415
}
325416
buf = append(buf, '\'')
326417
}
327418
case string:
328419
buf = append(buf, '\'')
329-
if mc.status&statusNoBackslashEscapes == 0 {
330-
buf = escapeStringBackslash(buf, v)
331-
} else {
420+
if noBackslashEscapes {
332421
buf = escapeStringQuotes(buf, v)
422+
} else {
423+
buf = escapeStringBackslash(buf, v)
333424
}
334425
buf = append(buf, '\'')
335426
default:
@@ -339,7 +430,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339430
if len(buf)+4 > mc.maxAllowedPacket {
340431
return "", driver.ErrSkip
341432
}
433+
lastIdx = qmIdx + 1
342434
}
435+
buf = append(buf, query[lastIdx:]...)
343436
if argPos != len(args) {
344437
return "", driver.ErrSkip
345438
}

connection_test.go

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
if err != driver.ErrSkip {
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
func TestInterpolateParamsUint64(t *testing.T) {
10183
mc := &mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,51 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
204186
func (bc badConnection) Close() error {
205187
return nil
206188
}
189+
190+
func TestInterpolateParamsWithComments(t *testing.T) {
191+
mc := &mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests := []struct {
200+
query string
201+
args []driver.Value
202+
expected string
203+
shouldSkip bool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// Multiple comments and real placeholders
216+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
217+
}
218+
219+
for i, test := range tests {
220+
221+
q, err := mc.interpolateParams(test.query, test.args)
222+
if test.shouldSkip {
223+
if err != driver.ErrSkip {
224+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
225+
}
226+
continue
227+
}
228+
if err != nil {
229+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
230+
continue
231+
}
232+
if q != test.expected {
233+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)