Skip to content

Commit 178ec41

Browse files
committed
CSHARP-1977: Cache SCRAM ClientKey
1 parent d2eda40 commit 178ec41

File tree

8 files changed

+421
-28
lines changed

8 files changed

+421
-28
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/* Copyright 2019–present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Linq;
18+
using System.Security;
19+
using MongoDB.Driver.Core.Misc;
20+
using MongoDB.Shared;
21+
22+
namespace MongoDB.Driver.Core.Authentication
23+
{
24+
/// <summary>
25+
/// A cache for Client and Server keys, to be used during authentication.
26+
/// </summary>
27+
internal class ScramCache
28+
{
29+
private ScramCacheKey _cacheKey;
30+
private ScramCacheEntry _cachedEntry;
31+
32+
/// <summary>
33+
/// Try to get a cached entry.
34+
/// </summary>
35+
/// <param name="key"></param>
36+
/// <param name="entry"></param>
37+
/// <returns></returns>
38+
public bool TryGet(ScramCacheKey key, out ScramCacheEntry entry)
39+
{
40+
if (key.Equals(_cacheKey))
41+
{
42+
entry = _cachedEntry;
43+
return true;
44+
}
45+
else
46+
{
47+
entry = null;
48+
return false;
49+
}
50+
}
51+
52+
/// <summary>
53+
/// Add a cached entry.
54+
/// </summary>
55+
/// <param name="key"></param>
56+
/// <param name="entry"></param>
57+
public void Add(ScramCacheKey key, ScramCacheEntry entry)
58+
{
59+
_cacheKey = key;
60+
_cachedEntry = entry;
61+
}
62+
}
63+
64+
internal class ScramCacheKey
65+
{
66+
private int _iterationCount;
67+
private SecureString _password;
68+
private byte[] _salt;
69+
70+
internal ScramCacheKey(SecureString password, byte[] salt, int iterationCount)
71+
{
72+
_iterationCount = iterationCount;
73+
_password = password;
74+
_salt = salt;
75+
}
76+
77+
public override bool Equals(object obj)
78+
{
79+
if (this == obj)
80+
{
81+
return true;
82+
}
83+
84+
if (obj == null || obj.GetType() != obj.GetType())
85+
{
86+
return false;
87+
}
88+
89+
ScramCacheKey other = (ScramCacheKey) obj;
90+
91+
return
92+
Equals(_password,other._password) &&
93+
_iterationCount == other._iterationCount &&
94+
_salt.SequenceEqual(other._salt);
95+
}
96+
97+
public override int GetHashCode()
98+
{
99+
// ignore _password when computing the hash code
100+
return new Hasher()
101+
.Hash(_iterationCount)
102+
.Hash(_salt)
103+
.GetHashCode();
104+
}
105+
106+
// private methods
107+
private bool Equals(SecureString x, SecureString y)
108+
{
109+
if (object.ReferenceEquals(x, y))
110+
{
111+
return true;
112+
}
113+
114+
if (object.ReferenceEquals(x, null) || object.ReferenceEquals(y, null))
115+
{
116+
return false;
117+
118+
}
119+
using (var dx = new DecryptedSecureString(x))
120+
using (var dy = new DecryptedSecureString(y))
121+
{
122+
var xchars = dx.GetChars();
123+
var ychars = dy.GetChars();
124+
return Equals(xchars, ychars);
125+
}
126+
}
127+
128+
private bool Equals(char[] x, char[] y)
129+
{
130+
if (x.Length != y.Length)
131+
{
132+
return false;
133+
}
134+
135+
for (var i = 0; i < x.Length; i++)
136+
{
137+
if (x[i] != y[i])
138+
{
139+
return false;
140+
}
141+
}
142+
143+
return true;
144+
}
145+
146+
}
147+
148+
internal class ScramCacheEntry
149+
{
150+
private byte[] _clientKey;
151+
private byte[] _serverKey;
152+
153+
public ScramCacheEntry(byte[] clientKey, byte[] serverKey)
154+
{
155+
_clientKey = clientKey;
156+
_serverKey = serverKey;
157+
}
158+
159+
public byte[] ClientKey => _clientKey;
160+
161+
public byte[] ServerKey => _serverKey;
162+
}
163+
}

src/MongoDB.Driver.Core/Core/Authentication/ScramSha1Authenticator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public ScramSha1Authenticator(UsernamePasswordCredential credential)
5555
}
5656

5757
internal ScramSha1Authenticator(UsernamePasswordCredential credential, IRandomStringGenerator randomStringGenerator)
58-
: base(credential, HashAlgorithmName.SHA1, randomStringGenerator, H1, Hi1, Hmac1)
58+
: base(credential, HashAlgorithmName.SHA1, randomStringGenerator, H1, Hi1, Hmac1, new ScramCache())
5959
{
6060
}
6161

src/MongoDB.Driver.Core/Core/Authentication/ScramSha256Authenticator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public ScramSha256Authenticator(UsernamePasswordCredential credential)
6060
}
6161

6262
internal ScramSha256Authenticator(UsernamePasswordCredential credential, IRandomStringGenerator randomStringGenerator)
63-
: base(credential, HashAlgorithmName.SHA256, randomStringGenerator, H256, Hi256, Hmac256)
63+
: base(credential, HashAlgorithmName.SHA256, randomStringGenerator, H256, Hi256, Hmac256, new ScramCache())
6464
{
6565
}
6666

src/MongoDB.Driver.Core/Core/Authentication/ScramShaAuthenticator.cs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2018–present MongoDB Inc.
1+
/* Copyright 2018–present MongoDB Inc.
22
*
33
* Licensed under the Apache License, Version 2.0 (the "License");
44
* you may not use this file except in compliance with the License.
@@ -72,8 +72,8 @@ protected ScramShaAuthenticator(UsernamePasswordCredential credential,
7272
H h,
7373
Hi hi,
7474
Hmac hmac)
75-
: this(credential, hashAlgorithmName, new DefaultRandomStringGenerator(), h, hi, hmac) { }
76-
75+
: this(credential, hashAlgorithmName, new DefaultRandomStringGenerator(), h, hi, hmac, new ScramCache()) { }
76+
7777
/// <summary>
7878
/// Initializes a new instance of the <see cref="ScramShaAuthenticator"/> class.
7979
/// </summary>
@@ -83,14 +83,16 @@ protected ScramShaAuthenticator(UsernamePasswordCredential credential,
8383
/// <param name="h">The H function to use.</param>
8484
/// <param name="hi">The Hi function to use.</param>
8585
/// <param name="hmac">The Hmac function to use.</param>
86+
/// <param name="cache">The cache to use.</param>
8687
internal ScramShaAuthenticator(
8788
UsernamePasswordCredential credential,
8889
HashAlgorithmName hashAlgorithName,
8990
IRandomStringGenerator randomStringGenerator,
9091
H h,
9192
Hi hi,
92-
Hmac hmac)
93-
: base(new ScramShaMechanism(credential, hashAlgorithName, randomStringGenerator, h, hi, hmac))
93+
Hmac hmac,
94+
ScramCache cache)
95+
: base(new ScramShaMechanism(credential, hashAlgorithName, randomStringGenerator, h, hi, hmac, cache))
9496
{
9597
_databaseName = credential.Source;
9698
}
@@ -108,14 +110,16 @@ private class ScramShaMechanism : ISaslMechanism
108110
private readonly Hi _hi;
109111
private readonly Hmac _hmac;
110112
private readonly string _name;
113+
private ScramCache _cache;
111114

112115
public ScramShaMechanism(
113116
UsernamePasswordCredential credential,
114117
HashAlgorithmName hashAlgorithmName,
115118
IRandomStringGenerator randomStringGenerator,
116119
H h,
117120
Hi hi,
118-
Hmac hmac)
121+
Hmac hmac,
122+
ScramCache cache)
119123
{
120124
_credential = Ensure.IsNotNull(credential, nameof(credential));
121125
_h = h;
@@ -127,6 +131,7 @@ public ScramShaMechanism(
127131
}
128132
_name = $"SCRAM-SHA-{hashAlgorithmName.ToString().Substring(3)}";
129133
_randomStringGenerator = Ensure.IsNotNull(randomStringGenerator, nameof(randomStringGenerator));
134+
_cache = cache;
130135
}
131136

132137
public string Name => _name;
@@ -143,9 +148,9 @@ public ISaslStep Initialize(IConnection connection, SaslConversation conversatio
143148

144149
var clientFirstMessageBare = username + "," + nonce;
145150
var clientFirstMessage = gs2Header + clientFirstMessageBare;
146-
var clientFirstMessageBytes = Utf8Encodings.Strict.GetBytes(clientFirstMessage);
151+
var clientFirstMessageBytes = Utf8Encodings.Strict.GetBytes(clientFirstMessage);
147152

148-
return new ClientFirst(clientFirstMessageBytes, clientFirstMessageBare, _credential, r, _h, _hi, _hmac);
153+
return new ClientFirst(clientFirstMessageBytes, clientFirstMessageBare, _credential, r, _h, _hi, _hmac, _cache);
149154
}
150155

151156
private string GenerateRandomString()
@@ -163,7 +168,7 @@ private string PrepUsername(string username)
163168

164169
private class ClientFirst : ISaslStep
165170
{
166-
171+
167172
private readonly byte[] _bytesToSendToServer;
168173
private readonly string _clientFirstMessageBare;
169174
private readonly UsernamePasswordCredential _credential;
@@ -173,14 +178,17 @@ private class ClientFirst : ISaslStep
173178
private readonly Hi _hi;
174179
private readonly Hmac _hmac;
175180

181+
private ScramCache _cache;
182+
176183
public ClientFirst(
177184
byte[] bytesToSendToServer,
178185
string clientFirstMessageBare,
179186
UsernamePasswordCredential credential,
180187
string rPrefix,
181188
H h,
182189
Hi hi,
183-
Hmac hmac)
190+
Hmac hmac,
191+
ScramCache cache)
184192
{
185193
_bytesToSendToServer = bytesToSendToServer;
186194
_clientFirstMessageBare = clientFirstMessageBare;
@@ -189,6 +197,7 @@ public ClientFirst(
189197
_hi = hi;
190198
_hmac = hmac;
191199
_rPrefix = rPrefix;
200+
_cache = cache;
192201
}
193202

194203
public byte[] BytesToSendToServer => _bytesToSendToServer;
@@ -214,19 +223,31 @@ public ISaslStep Transition(SaslConversation conversation, byte[] bytesReceivedF
214223
var nonce = "r=" + r;
215224
var clientFinalMessageWithoutProof = channelBinding + "," + nonce;
216225

217-
var saltedPassword = _hi(
218-
_credential,
219-
Convert.FromBase64String(s),
220-
int.Parse(i));
226+
var salt = Convert.FromBase64String(map['s']);
227+
var iterations = int.Parse(map['i']);
228+
229+
byte[] clientKey;
230+
byte[] serverKey;
231+
232+
var cacheKey = new ScramCacheKey(_credential.SaslPreppedPassword, salt, iterations);
233+
if (_cache.TryGet(cacheKey, out var cacheEntry))
234+
{
235+
clientKey = cacheEntry.ClientKey;
236+
serverKey = cacheEntry.ServerKey;
237+
}
238+
else
239+
{
240+
var saltedPassword = _hi( _credential, salt, iterations);
241+
clientKey = _hmac(encoding, saltedPassword, "Client Key");
242+
serverKey = _hmac(encoding, saltedPassword, "Server Key");
243+
_cache.Add(cacheKey, new ScramCacheEntry(clientKey, serverKey));
244+
}
221245

222-
var clientKey = _hmac(encoding, saltedPassword, "Client Key");
223246
var storedKey = _h(clientKey);
224247
var authMessage = _clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
225248
var clientSignature = _hmac(encoding, storedKey, authMessage);
226249
var clientProof = XOR(clientKey, clientSignature);
227-
var serverKey = _hmac(encoding, saltedPassword, "Server Key");
228250
var serverSignature = _hmac(encoding, serverKey, authMessage);
229-
230251
var proof = "p=" + Convert.ToBase64String(clientProof);
231252
var clientFinalMessage = clientFinalMessageWithoutProof + "," + proof;
232253

@@ -243,7 +264,6 @@ private byte[] XOR(byte[] a, byte[] b)
243264

244265
return result;
245266
}
246-
247267
}
248268

249269
private class ClientLast : ISaslStep
@@ -287,4 +307,4 @@ private bool ConstantTimeEquals(byte[] a, byte[] b)
287307
}
288308
}
289309
}
290-
}
310+
}

src/MongoDB.Driver.Core/Core/Connections/CommandEventHelper.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ private BsonDocument BuildFindCommandFromQuery(QueryMessage message)
10381038
default:
10391039
if (element.Name.StartsWith("$"))
10401040
{
1041+
// should we actually remove the $ or not?
10411042
command[element.Name.Substring(1)] = element.Value;
10421043
}
10431044
else

0 commit comments

Comments
 (0)