Skip to content

Commit 378be39

Browse files
committed
NH-3791 - Fix incorrect multiple batched command adjustments
1 parent f7980ee commit 378be39

File tree

3 files changed

+86
-49
lines changed

3 files changed

+86
-49
lines changed

src/NHibernate.Test/DialectTest/FirebirdDialectFixture.cs

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,72 +7,57 @@ namespace NHibernate.Test.DialectTest
77
[TestFixture]
88
public class FirebirdDialectFixture
99
{
10-
#region Tests
10+
readonly FirebirdDialect _dialect = new FirebirdDialect();
1111

1212
[Test]
1313
public void GetLimitString()
1414
{
15-
FirebirdDialect d = MakeDialect();
16-
17-
SqlString str = d.GetLimitString(new SqlString("SELECT * FROM fish"), null, new SqlString("10"));
15+
var str = _dialect.GetLimitString(new SqlString("SELECT * FROM fish"), null, new SqlString("10"));
1816
Assert.AreEqual("SELECT first 10 * FROM fish", str.ToString());
1917

20-
str = d.GetLimitString(new SqlString("SELECT * FROM fish ORDER BY name"), new SqlString("5"), new SqlString("15"));
18+
str = _dialect.GetLimitString(new SqlString("SELECT * FROM fish ORDER BY name"), new SqlString("5"), new SqlString("15"));
2119
Assert.AreEqual("SELECT first 15 skip 5 * FROM fish ORDER BY name", str.ToString());
2220

23-
str = d.GetLimitString(new SqlString("SELECT * FROM fish ORDER BY name DESC"), new SqlString("7"),
21+
str = _dialect.GetLimitString(new SqlString("SELECT * FROM fish ORDER BY name DESC"), new SqlString("7"),
2422
new SqlString("28"));
2523
Assert.AreEqual("SELECT first 28 skip 7 * FROM fish ORDER BY name DESC", str.ToString());
2624

27-
str = d.GetLimitString(new SqlString("SELECT DISTINCT fish.family FROM fish ORDER BY name DESC"), null,
25+
str = _dialect.GetLimitString(new SqlString("SELECT DISTINCT fish.family FROM fish ORDER BY name DESC"), null,
2826
new SqlString("28"));
2927
Assert.AreEqual("SELECT first 28 DISTINCT fish.family FROM fish ORDER BY name DESC", str.ToString());
3028

31-
str = d.GetLimitString(new SqlString("SELECT DISTINCT fish.family FROM fish ORDER BY name DESC"), new SqlString("7"),
29+
str = _dialect.GetLimitString(new SqlString("SELECT DISTINCT fish.family FROM fish ORDER BY name DESC"), new SqlString("7"),
3230
new SqlString("28"));
3331
Assert.AreEqual("SELECT first 28 skip 7 DISTINCT fish.family FROM fish ORDER BY name DESC", str.ToString());
3432
}
3533

3634
[Test]
3735
public void GetTypeName_DecimalWithoutPrecisionAndScale_ReturnsDecimalWithPrecisionOf18AndScaleOf5()
3836
{
39-
var fbDialect = MakeDialect();
40-
41-
var result = fbDialect.GetTypeName(NHibernateUtil.Decimal.SqlType);
37+
var result = _dialect.GetTypeName(NHibernateUtil.Decimal.SqlType);
4238

4339
Assert.AreEqual("DECIMAL(18, 5)", result);
4440
}
4541

4642
[Test]
4743
public void GetTypeName_DecimalWithPrecisionAndScale_ReturnsPrecisedAndScaledDecimal()
4844
{
49-
var fbDialect = MakeDialect();
50-
51-
var result = fbDialect.GetTypeName(NHibernateUtil.Decimal.SqlType, 0, 17, 2);
45+
var result = _dialect.GetTypeName(NHibernateUtil.Decimal.SqlType, 0, 17, 2);
5246

5347
Assert.AreEqual("DECIMAL(17, 2)", result);
5448
}
5549

5650
[Test]
5751
public void GetTypeName_DecimalWithPrecisionGreaterThanFbMaxPrecision_ReturnsDecimalWithFbMaxPrecision()
5852
{
59-
var fbDialect = MakeDialect();
60-
61-
var result = fbDialect.GetTypeName(NHibernateUtil.Decimal.SqlType, 0, 19, 2);
53+
var result = _dialect.GetTypeName(NHibernateUtil.Decimal.SqlType, 0, 19, 2);
6254
//Firebird allows a maximum precision of 18
6355

6456
Assert.AreEqual("DECIMAL(18, 2)", result);
6557
}
6658

67-
#endregion
68-
6959
#region Private Members
7060

71-
private static FirebirdDialect MakeDialect()
72-
{
73-
return new FirebirdDialect();
74-
}
75-
7661
#endregion
7762
}
78-
}
63+
}

src/NHibernate.Test/DriverTest/FirebirdClientDriverFixture.cs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ public void AdjustCommand_InsertWithParamsInSelect_ParameterIsCasted()
120120
Assert.That(cmd.CommandText, Is.EqualTo(expected));
121121
}
122122

123+
[Test]
124+
public void AdjustCommand_InsertWithParamsInSelect_ParameterIsNotCasted_WhenColumnNameContainsSelect()
125+
{
126+
MakeDriver();
127+
var cmd = BuildInsertWithParamsInSelectCommandWithSelectInColumnName(SqlTypeFactory.Int32);
128+
129+
_driver.AdjustCommand(cmd);
130+
131+
var expected = "insert into table1 (col1_select_aaa) values(@p0) from table2";
132+
Assert.That(cmd.CommandText, Is.EqualTo(expected));
133+
}
134+
135+
[Test]
136+
public void AdjustCommand_InsertWithParamsInSelect_ParameterIsNotCasted_WhenColumnNameContainsWhere()
137+
{
138+
MakeDriver();
139+
var cmd = BuildInsertWithParamsInSelectCommandWithWhereInColumnName(SqlTypeFactory.Int32);
140+
141+
_driver.AdjustCommand(cmd);
142+
143+
var expected = "insert into table1 (col1_where_aaa) values(@p0) from table2";
144+
Assert.That(cmd.CommandText, Is.EqualTo(expected));
145+
}
146+
123147
private void MakeDriver()
124148
{
125149
var cfg = TestConfigurationHelper.GetDefaultConfiguration();
@@ -170,7 +194,7 @@ private IDbCommand BuildSelectCaseCommand(SqlType paramType)
170194
.Add(" end) from table")
171195
.ToSqlString();
172196

173-
return _driver.GenerateCommand(CommandType.Text, sqlString, new SqlType[] { paramType, paramType, paramType });
197+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType, paramType, paramType });
174198
}
175199

176200
private IDbCommand BuildSelectConcatCommand(SqlType paramType)
@@ -183,7 +207,7 @@ private IDbCommand BuildSelectConcatCommand(SqlType paramType)
183207
.Add("from table")
184208
.ToSqlString();
185209

186-
return _driver.GenerateCommand(CommandType.Text, sqlString, new SqlType[] { paramType });
210+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType });
187211
}
188212

189213
private IDbCommand BuildSelectAddCommand(SqlType paramType)
@@ -194,7 +218,7 @@ private IDbCommand BuildSelectAddCommand(SqlType paramType)
194218
.Add(" from table")
195219
.ToSqlString();
196220

197-
return _driver.GenerateCommand(CommandType.Text, sqlString, new SqlType[] { paramType });
221+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType });
198222
}
199223

200224
private IDbCommand BuildInsertWithParamsInSelectCommand(SqlType paramType)
@@ -206,8 +230,31 @@ private IDbCommand BuildInsertWithParamsInSelectCommand(SqlType paramType)
206230
.Add(" from table2")
207231
.ToSqlString();
208232

209-
return _driver.GenerateCommand(CommandType.Text, sqlString, new SqlType[] { paramType });
233+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType });
210234
}
211235

236+
private IDbCommand BuildInsertWithParamsInSelectCommandWithSelectInColumnName(SqlType paramType)
237+
{
238+
var sqlString = new SqlStringBuilder()
239+
.Add("insert into table1 (col1_select_aaa) ")
240+
.Add("values(")
241+
.AddParameter()
242+
.Add(") from table2")
243+
.ToSqlString();
244+
245+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType });
246+
}
247+
248+
private IDbCommand BuildInsertWithParamsInSelectCommandWithWhereInColumnName(SqlType paramType)
249+
{
250+
var sqlString = new SqlStringBuilder()
251+
.Add("insert into table1 (col1_where_aaa) ")
252+
.Add("values(")
253+
.AddParameter()
254+
.Add(") from table2")
255+
.ToSqlString();
256+
257+
return _driver.GenerateCommand(CommandType.Text, sqlString, new[] { paramType });
258+
}
212259
}
213260
}

src/NHibernate/Driver/FirebirdClientDriver.cs

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq;
44
using System.Text.RegularExpressions;
55
using NHibernate.Dialect;
6+
using NHibernate.SqlCommand;
67
using NHibernate.SqlTypes;
78
using NHibernate.Util;
89

@@ -14,7 +15,7 @@ namespace NHibernate.Driver
1415
/// </summary>
1516
public class FirebirdClientDriver : ReflectionBasedDriver
1617
{
17-
private const string SELECT_CLAUSE_EXP = "(?<=select|where).*";
18+
private const string SELECT_CLAUSE_EXP = @"(?<=\bselect|\bwhere).*";
1819
private const string CAST_PARAMS_EXP = @"(?<![=<>]\s?|first\s?|skip\s?|between\s|between\s@\bp\w+\b\sand\s)@\bp\w+\b(?!\s?[=<>])";
1920
private readonly Regex _statementRegEx = new Regex(SELECT_CLAUSE_EXP, RegexOptions.IgnoreCase);
2021
private readonly Regex _castCandidateRegEx = new Regex(CAST_PARAMS_EXP, RegexOptions.IgnoreCase);
@@ -59,35 +60,39 @@ protected override void InitializeParameter(IDbDataParameter dbParam, string nam
5960
base.InitializeParameter(dbParam, name, convertedSqlType);
6061
}
6162

62-
public override void AdjustCommand(IDbCommand command)
63+
public override IDbCommand GenerateCommand(CommandType type, SqlString sqlString, SqlType[] parameterTypes)
6364
{
64-
var expWithParams = GetStatementsWithCastCandidates(command.CommandText);
65-
if (string.IsNullOrWhiteSpace(expWithParams))
66-
return;
65+
var command = base.GenerateCommand(type, sqlString, parameterTypes);
6766

68-
var candidates = GetCastCandidates(expWithParams);
69-
var castParams = from IDbDataParameter p in command.Parameters
70-
where candidates.Contains(p.ParameterName)
71-
select p;
72-
foreach (IDbDataParameter param in castParams)
67+
var expWithParams = GetStatementsWithCastCandidates(command.CommandText);
68+
if (!string.IsNullOrWhiteSpace(expWithParams))
7369
{
74-
TypeCastParam(param, command);
70+
var candidates = GetCastCandidates(expWithParams);
71+
var castParams = from IDbDataParameter p in command.Parameters
72+
where candidates.Contains(p.ParameterName)
73+
select p;
74+
foreach (IDbDataParameter param in castParams)
75+
{
76+
TypeCastParam(param, command);
77+
}
7578
}
79+
80+
return command;
7681
}
7782

7883
private string GetStatementsWithCastCandidates(string commandText)
7984
{
80-
var match = _statementRegEx.Match(commandText);
81-
return match.Value;
85+
return _statementRegEx.Match(commandText).Value;
8286
}
8387

84-
private IEnumerable<string> GetCastCandidates(string statement)
88+
private HashSet<string> GetCastCandidates(string statement)
8589
{
86-
var matches = _castCandidateRegEx.Matches(statement);
87-
foreach (Match match in matches)
88-
{
89-
yield return match.Value;
90-
}
90+
var candidates =
91+
_castCandidateRegEx
92+
.Matches(statement)
93+
.Cast<Match>()
94+
.Select(match => match.Value);
95+
return new HashSet<string>(candidates);
9196
}
9297

9398
private void TypeCastParam(IDbDataParameter param, IDbCommand command)
@@ -101,4 +106,4 @@ private string GetFbTypeFromDbType(DbType dbType)
101106
return _fbDialect.GetCastTypeName(new SqlType(dbType));
102107
}
103108
}
104-
}
109+
}

0 commit comments

Comments
 (0)