@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172
172
}
173
173
174
174
// Closes the network connection and unsets internal variables. Do not call this
175
- // function after successfully authentication, call Close instead. This function
175
+ // function after successful authentication, call Close instead. This function
176
176
// is called before auth or on auth failure because MySQL will have already
177
177
// closed the network connection.
178
178
func (mc * mysqlConn ) cleanup () {
@@ -245,9 +245,106 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245
245
return stmt , err
246
246
}
247
247
248
+ // findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249
+ func findParamPositions (query string ) []int {
250
+ const (
251
+ stateNormal = iota
252
+ stateString
253
+ stateEscape
254
+ stateEOLComment
255
+ stateSlashStarComment
256
+ stateBacktick
257
+ )
258
+
259
+ var (
260
+ QUOTE_BYTE = byte ('\'' )
261
+ DBL_QUOTE_BYTE = byte ('"' )
262
+ BACKSLASH_BYTE = byte ('\\' )
263
+ QUESTION_MARK_BYTE = byte ('?' )
264
+ SLASH_BYTE = byte ('/' )
265
+ STAR_BYTE = byte ('*' )
266
+ HASH_BYTE = byte ('#' )
267
+ MINUS_BYTE = byte ('-' )
268
+ LINE_FEED_BYTE = byte ('\n' )
269
+ RADICAL_BYTE = byte ('`' )
270
+ )
271
+
272
+ paramPositions := make ([]int , 0 )
273
+ state := stateNormal
274
+ singleQuotes := false
275
+ lastChar := byte (0 )
276
+ lenq := len (query )
277
+ for i := 0 ; i < lenq ; i ++ {
278
+ currentChar := query [i ]
279
+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
280
+ state = stateString
281
+ lastChar = currentChar
282
+ continue
283
+ }
284
+ switch currentChar {
285
+ case STAR_BYTE :
286
+ if state == stateNormal && lastChar == SLASH_BYTE {
287
+ state = stateSlashStarComment
288
+ }
289
+ case SLASH_BYTE :
290
+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
291
+ state = stateNormal
292
+ } else if state == stateNormal && lastChar == SLASH_BYTE {
293
+ state = stateEOLComment
294
+ }
295
+ case HASH_BYTE :
296
+ if state == stateNormal {
297
+ state = stateEOLComment
298
+ }
299
+ case MINUS_BYTE :
300
+ if state == stateNormal && lastChar == MINUS_BYTE {
301
+ state = stateEOLComment
302
+ }
303
+ case LINE_FEED_BYTE :
304
+ if state == stateEOLComment {
305
+ state = stateNormal
306
+ }
307
+ case DBL_QUOTE_BYTE :
308
+ if state == stateNormal {
309
+ state = stateString
310
+ singleQuotes = false
311
+ } else if state == stateString && ! singleQuotes {
312
+ state = stateNormal
313
+ } else if state == stateEscape {
314
+ state = stateString
315
+ }
316
+ case QUOTE_BYTE :
317
+ if state == stateNormal {
318
+ state = stateString
319
+ singleQuotes = true
320
+ } else if state == stateString && singleQuotes {
321
+ state = stateNormal
322
+ } else if state == stateEscape {
323
+ state = stateString
324
+ }
325
+ case BACKSLASH_BYTE :
326
+ if state == stateString {
327
+ state = stateEscape
328
+ }
329
+ case QUESTION_MARK_BYTE :
330
+ if state == stateNormal {
331
+ paramPositions = append (paramPositions , i )
332
+ }
333
+ case RADICAL_BYTE :
334
+ if state == stateBacktick {
335
+ state = stateNormal
336
+ } else if state == stateNormal {
337
+ state = stateBacktick
338
+ }
339
+ }
340
+ lastChar = currentChar
341
+ }
342
+ return paramPositions
343
+ }
344
+
248
345
func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249
- // Number of ? should be same to len(args )
250
- if strings . Count ( query , "?" ) != len (args ) {
346
+ paramPositions := findParamPositions ( query )
347
+ if len ( paramPositions ) != len (args ) {
251
348
return "" , driver .ErrSkip
252
349
}
253
350
@@ -261,21 +358,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261
358
}
262
359
buf = buf [:0 ]
263
360
argPos := 0
361
+ lastIdx := 0
264
362
265
- for i := 0 ; i < len (query ); i ++ {
266
- q := strings .IndexByte (query [i :], '?' )
267
- if q == - 1 {
268
- buf = append (buf , query [i :]... )
269
- break
270
- }
271
- buf = append (buf , query [i :i + q ]... )
272
- i += q
273
-
363
+ for _ , qmIdx := range paramPositions {
364
+ buf = append (buf , query [lastIdx :qmIdx ]... )
274
365
arg := args [argPos ]
275
366
argPos ++
276
367
277
368
if arg == nil {
278
369
buf = append (buf , "NULL" ... )
370
+ lastIdx = qmIdx + 1
279
371
continue
280
372
}
281
373
@@ -339,7 +431,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339
431
if len (buf )+ 4 > mc .maxAllowedPacket {
340
432
return "" , driver .ErrSkip
341
433
}
434
+ lastIdx = qmIdx + 1
342
435
}
436
+ buf = append (buf , query [lastIdx :]... )
343
437
if argPos != len (args ) {
344
438
return "" , driver .ErrSkip
345
439
}
0 commit comments