Skip to content

Commit 883c78c

Browse files
committed
Enhance interpolateParams to correctly handle placeholders in queries with comments, strings, and backticks.
* Add `findParamPositions` to identify real parameter positions * Update and expand related tests.
1 parent 76c00e3 commit 883c78c

File tree

2 files changed

+154
-30
lines changed

2 files changed

+154
-30
lines changed

connection.go

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -245,9 +245,106 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245245
return stmt, err
246246
}
247247

248+
// findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+
func findParamPositions(query string) []int {
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
var (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
RADICAL_BYTE = byte('`')
270+
)
271+
272+
paramPositions := make([]int, 0)
273+
state := stateNormal
274+
singleQuotes := false
275+
lastChar := byte(0)
276+
lenq := len(query)
277+
for i := 0; i < lenq; i++ {
278+
currentChar := query[i]
279+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
280+
state = stateString
281+
lastChar = currentChar
282+
continue
283+
}
284+
switch currentChar {
285+
case STAR_BYTE:
286+
if state == stateNormal && lastChar == SLASH_BYTE {
287+
state = stateSlashStarComment
288+
}
289+
case SLASH_BYTE:
290+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
291+
state = stateNormal
292+
} else if state == stateNormal && lastChar == SLASH_BYTE {
293+
state = stateEOLComment
294+
}
295+
case HASH_BYTE:
296+
if state == stateNormal {
297+
state = stateEOLComment
298+
}
299+
case MINUS_BYTE:
300+
if state == stateNormal && lastChar == MINUS_BYTE {
301+
state = stateEOLComment
302+
}
303+
case LINE_FEED_BYTE:
304+
if state == stateEOLComment {
305+
state = stateNormal
306+
}
307+
case DBL_QUOTE_BYTE:
308+
if state == stateNormal {
309+
state = stateString
310+
singleQuotes = false
311+
} else if state == stateString && !singleQuotes {
312+
state = stateNormal
313+
} else if state == stateEscape {
314+
state = stateString
315+
}
316+
case QUOTE_BYTE:
317+
if state == stateNormal {
318+
state = stateString
319+
singleQuotes = true
320+
} else if state == stateString && singleQuotes {
321+
state = stateNormal
322+
} else if state == stateEscape {
323+
state = stateString
324+
}
325+
case BACKSLASH_BYTE:
326+
if state == stateString {
327+
state = stateEscape
328+
}
329+
case QUESTION_MARK_BYTE:
330+
if state == stateNormal {
331+
paramPositions = append(paramPositions, i)
332+
}
333+
case RADICAL_BYTE:
334+
if state == stateBacktick {
335+
state = stateNormal
336+
} else if state == stateNormal {
337+
state = stateBacktick
338+
}
339+
}
340+
lastChar = currentChar
341+
}
342+
return paramPositions
343+
}
344+
248345
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
346+
paramPositions := findParamPositions(query)
347+
if len(paramPositions) != len(args) {
251348
return "", driver.ErrSkip
252349
}
253350

@@ -261,21 +358,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261358
}
262359
buf = buf[:0]
263360
argPos := 0
361+
lastIdx := 0
264362

265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
363+
for _, qmIdx := range paramPositions {
364+
buf = append(buf, query[lastIdx:qmIdx]...)
274365
arg := args[argPos]
275366
argPos++
276367

277368
if arg == nil {
278369
buf = append(buf, "NULL"...)
370+
lastIdx = qmIdx + 1
279371
continue
280372
}
281373

@@ -339,7 +431,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339431
if len(buf)+4 > mc.maxAllowedPacket {
340432
return "", driver.ErrSkip
341433
}
434+
lastIdx = qmIdx + 1
342435
}
436+
buf = append(buf, query[lastIdx:]...)
343437
if argPos != len(args) {
344438
return "", driver.ErrSkip
345439
}

connection_test.go

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7979
}
8080
}
8181

82-
// We don't support placeholder in string literal for now.
83-
// https://github.com/go-sql-driver/mysql/pull/490
84-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
86-
buf: newBuffer(),
87-
maxAllowedPacket: maxPacketSize,
88-
cfg: &Config{
89-
InterpolateParams: true,
90-
},
91-
}
92-
93-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
94-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
95-
if err != driver.ErrSkip {
96-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
97-
}
98-
}
99-
10082
func TestInterpolateParamsUint64(t *testing.T) {
10183
mc := &mysqlConn{
10284
buf: newBuffer(),
@@ -204,3 +186,51 @@ func (bc badConnection) Write(b []byte) (n int, err error) {
204186
func (bc badConnection) Close() error {
205187
return nil
206188
}
189+
190+
func TestInterpolateParamsWithComments(t *testing.T) {
191+
mc := &mysqlConn{
192+
buf: newBuffer(),
193+
maxAllowedPacket: maxPacketSize,
194+
cfg: &Config{
195+
InterpolateParams: true,
196+
},
197+
}
198+
199+
tests := []struct {
200+
query string
201+
args []driver.Value
202+
expected string
203+
shouldSkip bool
204+
}{
205+
// ? in single-line comment (--) should not be replaced
206+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
207+
// ? in single-line comment (#) should not be replaced
208+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
209+
// ? in multi-line comment should not be replaced
210+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
211+
// ? in string literal should not be replaced
212+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
213+
// ? in backtick identifier should not be replaced
214+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
215+
// Multiple comments and real placeholders
216+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
217+
}
218+
219+
for i, test := range tests {
220+
221+
q, err := mc.interpolateParams(test.query, test.args)
222+
if test.shouldSkip {
223+
if err != driver.ErrSkip {
224+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
225+
}
226+
continue
227+
}
228+
if err != nil {
229+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
230+
continue
231+
}
232+
if q != test.expected {
233+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)