@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172
172
}
173
173
174
174
// Closes the network connection and unsets internal variables. Do not call this
175
- // function after successfully authentication, call Close instead. This function
175
+ // function after successful authentication, call Close instead. This function
176
176
// is called before auth or on auth failure because MySQL will have already
177
177
// closed the network connection.
178
178
func (mc * mysqlConn ) cleanup () {
@@ -245,9 +245,105 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245
245
return stmt , err
246
246
}
247
247
248
+ // findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249
+ func findParamPositions (query string , noBackslashEscapes bool ) []int {
250
+ const (
251
+ stateNormal = iota
252
+ stateString
253
+ stateEscape
254
+ stateEOLComment
255
+ stateSlashStarComment
256
+ stateBacktick
257
+ )
258
+
259
+ var (
260
+ QUOTE_BYTE = byte ('\'' )
261
+ DBL_QUOTE_BYTE = byte ('"' )
262
+ BACKSLASH_BYTE = byte ('\\' )
263
+ QUESTION_MARK_BYTE = byte ('?' )
264
+ SLASH_BYTE = byte ('/' )
265
+ STAR_BYTE = byte ('*' )
266
+ HASH_BYTE = byte ('#' )
267
+ MINUS_BYTE = byte ('-' )
268
+ LINE_FEED_BYTE = byte ('\n' )
269
+ RADICAL_BYTE = byte ('`' )
270
+ )
271
+
272
+ paramPositions := make ([]int , 0 )
273
+ state := stateNormal
274
+ singleQuotes := false
275
+ lastChar := byte (0 )
276
+ lenq := len (query )
277
+ for i := 0 ; i < lenq ; i ++ {
278
+ currentChar := query [i ]
279
+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
280
+ state = stateString
281
+ lastChar = currentChar
282
+ continue
283
+ }
284
+ switch currentChar {
285
+ case STAR_BYTE :
286
+ if state == stateNormal && lastChar == SLASH_BYTE {
287
+ state = stateSlashStarComment
288
+ }
289
+ case SLASH_BYTE :
290
+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
291
+ state = stateNormal
292
+ }
293
+ case HASH_BYTE :
294
+ if state == stateNormal {
295
+ state = stateEOLComment
296
+ }
297
+ case MINUS_BYTE :
298
+ if state == stateNormal && lastChar == MINUS_BYTE {
299
+ state = stateEOLComment
300
+ }
301
+ case LINE_FEED_BYTE :
302
+ if state == stateEOLComment {
303
+ state = stateNormal
304
+ }
305
+ case DBL_QUOTE_BYTE :
306
+ if state == stateNormal {
307
+ state = stateString
308
+ singleQuotes = false
309
+ } else if state == stateString && ! singleQuotes {
310
+ state = stateNormal
311
+ } else if state == stateEscape {
312
+ state = stateString
313
+ }
314
+ case QUOTE_BYTE :
315
+ if state == stateNormal {
316
+ state = stateString
317
+ singleQuotes = true
318
+ } else if state == stateString && singleQuotes {
319
+ state = stateNormal
320
+ } else if state == stateEscape {
321
+ state = stateString
322
+ }
323
+ case BACKSLASH_BYTE :
324
+ if state == stateString && ! noBackslashEscapes {
325
+ state = stateEscape
326
+ }
327
+ case QUESTION_MARK_BYTE :
328
+ if state == stateNormal {
329
+ paramPositions = append (paramPositions , i )
330
+ }
331
+ case RADICAL_BYTE :
332
+ if state == stateBacktick {
333
+ state = stateNormal
334
+ } else if state == stateNormal {
335
+ state = stateBacktick
336
+ }
337
+ }
338
+ lastChar = currentChar
339
+ }
340
+ return paramPositions
341
+ }
342
+
248
343
func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249
- // Number of ? should be same to len(args)
250
- if strings .Count (query , "?" ) != len (args ) {
344
+ noBackslashEscapes := (mc .status & statusNoBackslashEscapes ) != 0
345
+ paramPositions := findParamPositions (query , noBackslashEscapes )
346
+ if len (paramPositions ) != len (args ) {
251
347
return "" , driver .ErrSkip
252
348
}
253
349
@@ -261,21 +357,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261
357
}
262
358
buf = buf [:0 ]
263
359
argPos := 0
360
+ lastIdx := 0
264
361
265
- for i := 0 ; i < len (query ); i ++ {
266
- q := strings .IndexByte (query [i :], '?' )
267
- if q == - 1 {
268
- buf = append (buf , query [i :]... )
269
- break
270
- }
271
- buf = append (buf , query [i :i + q ]... )
272
- i += q
273
-
362
+ for _ , qmIdx := range paramPositions {
363
+ buf = append (buf , query [lastIdx :qmIdx ]... )
274
364
arg := args [argPos ]
275
365
argPos ++
276
366
277
367
if arg == nil {
278
368
buf = append (buf , "NULL" ... )
369
+ lastIdx = qmIdx + 1
279
370
continue
280
371
}
281
372
@@ -306,30 +397,30 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
306
397
}
307
398
case json.RawMessage :
308
399
buf = append (buf , '\'' )
309
- if mc .status & statusNoBackslashEscapes == 0 {
310
- buf = escapeBytesBackslash (buf , v )
311
- } else {
400
+ if noBackslashEscapes {
312
401
buf = escapeBytesQuotes (buf , v )
402
+ } else {
403
+ buf = escapeBytesBackslash (buf , v )
313
404
}
314
405
buf = append (buf , '\'' )
315
406
case []byte :
316
407
if v == nil {
317
408
buf = append (buf , "NULL" ... )
318
409
} else {
319
410
buf = append (buf , "_binary'" ... )
320
- if mc .status & statusNoBackslashEscapes == 0 {
321
- buf = escapeBytesBackslash (buf , v )
322
- } else {
411
+ if noBackslashEscapes {
323
412
buf = escapeBytesQuotes (buf , v )
413
+ } else {
414
+ buf = escapeBytesBackslash (buf , v )
324
415
}
325
416
buf = append (buf , '\'' )
326
417
}
327
418
case string :
328
419
buf = append (buf , '\'' )
329
- if mc .status & statusNoBackslashEscapes == 0 {
330
- buf = escapeStringBackslash (buf , v )
331
- } else {
420
+ if noBackslashEscapes {
332
421
buf = escapeStringQuotes (buf , v )
422
+ } else {
423
+ buf = escapeStringBackslash (buf , v )
333
424
}
334
425
buf = append (buf , '\'' )
335
426
default :
@@ -339,7 +430,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339
430
if len (buf )+ 4 > mc .maxAllowedPacket {
340
431
return "" , driver .ErrSkip
341
432
}
433
+ lastIdx = qmIdx + 1
342
434
}
435
+ buf = append (buf , query [lastIdx :]... )
343
436
if argPos != len (args ) {
344
437
return "" , driver .ErrSkip
345
438
}
0 commit comments