Skip to content

Commit 9c2097f

Browse files
committed
CSHARP-1736: Fix issue with .NET Core not supporting DnsEndPoint in Socket.Connect due to bsd socket differences.
1 parent ad81ac6 commit 9c2097f

File tree

1 file changed

+110
-5
lines changed

1 file changed

+110
-5
lines changed

src/MongoDB.Driver.Core/Core/Connections/TcpStreamFactory.cs

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Generic;
1718
using System.IO;
19+
using System.Linq;
1820
using System.Net;
1921
using System.Net.Sockets;
2022
using System.Threading;
@@ -46,16 +48,68 @@ public TcpStreamFactory(TcpStreamSettings settings)
4648
// methods
4749
public Stream CreateStream(EndPoint endPoint, CancellationToken cancellationToken)
4850
{
51+
#if NET45
4952
var socket = CreateSocket(endPoint);
5053
Connect(socket, endPoint, cancellationToken);
5154
return CreateNetworkStream(socket);
55+
#else
56+
// ugh... I know... but there isn't a non-async version of dns resolution
57+
// in .NET Core
58+
var resolved = ResolveEndPointsAsync(endPoint).GetAwaiter().GetResult();
59+
for (int i = 0; i < resolved.Length; i++)
60+
{
61+
try
62+
{
63+
var socket = CreateSocket(resolved[i]);
64+
Connect(socket, resolved[i], cancellationToken);
65+
return CreateNetworkStream(socket);
66+
}
67+
catch
68+
{
69+
// if we have tried all of them and still failed,
70+
// then blow up.
71+
if (i == resolved.Length - 1)
72+
{
73+
throw;
74+
}
75+
}
76+
}
77+
78+
// we should never get here...
79+
throw new InvalidOperationException("Unabled to resolve endpoint.");
80+
#endif
5281
}
5382

5483
public async Task<Stream> CreateStreamAsync(EndPoint endPoint, CancellationToken cancellationToken)
5584
{
85+
#if NET45
5686
var socket = CreateSocket(endPoint);
5787
await ConnectAsync(socket, endPoint, cancellationToken).ConfigureAwait(false);
5888
return CreateNetworkStream(socket);
89+
#else
90+
var resolved = await ResolveEndPointsAsync(endPoint).ConfigureAwait(false);
91+
for (int i = 0; i < resolved.Length; i++)
92+
{
93+
try
94+
{
95+
var socket = CreateSocket(resolved[i]);
96+
await ConnectAsync(socket, resolved[i], cancellationToken).ConfigureAwait(false);
97+
return CreateNetworkStream(socket);
98+
}
99+
catch
100+
{
101+
// if we have tried all of them and still failed,
102+
// then blow up.
103+
if (i == resolved.Length - 1)
104+
{
105+
throw;
106+
}
107+
}
108+
}
109+
110+
// we should never get here...
111+
throw new InvalidOperationException("Unabled to resolve endpoint.");
112+
#endif
59113
}
60114

61115
// non-public methods
@@ -65,11 +119,7 @@ private void ConfigureConnectedSocket(Socket socket)
65119
socket.ReceiveBufferSize = _settings.ReceiveBufferSize;
66120
socket.SendBufferSize = _settings.SendBufferSize;
67121

68-
var socketConfigurator = _settings.SocketConfigurator;
69-
if (socketConfigurator != null)
70-
{
71-
socketConfigurator(socket);
72-
}
122+
_settings.SocketConfigurator?.Invoke(socket);
73123
}
74124

75125
private void Connect(Socket socket, EndPoint endPoint, CancellationToken cancellationToken)
@@ -202,7 +252,62 @@ private Socket CreateSocket(EndPoint endPoint)
202252
{
203253
addressFamily = _settings.AddressFamily;
204254
}
255+
205256
return new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
206257
}
258+
259+
private async Task<EndPoint[]> ResolveEndPointsAsync(EndPoint initial)
260+
{
261+
var dnsInitial = initial as DnsEndPoint;
262+
if (dnsInitial == null)
263+
{
264+
return new[] { initial };
265+
}
266+
267+
IPAddress address;
268+
if (IPAddress.TryParse(dnsInitial.Host, out address))
269+
{
270+
return new[] { new IPEndPoint(address, dnsInitial.Port) };
271+
}
272+
273+
var preferred = initial.AddressFamily;
274+
if (preferred == AddressFamily.Unspecified || preferred == AddressFamily.Unknown)
275+
{
276+
preferred = _settings.AddressFamily;
277+
}
278+
279+
return (await Dns.GetHostAddressesAsync(dnsInitial.Host).ConfigureAwait(false))
280+
.Select(x => new IPEndPoint(x, dnsInitial.Port))
281+
.OrderBy(x => x, new PreferredAddressFamilyComparer(preferred))
282+
.ToArray();
283+
}
284+
285+
private class PreferredAddressFamilyComparer : IComparer<EndPoint>
286+
{
287+
private readonly AddressFamily _preferred;
288+
289+
public PreferredAddressFamilyComparer(AddressFamily preferred)
290+
{
291+
_preferred = preferred;
292+
}
293+
294+
public int Compare(EndPoint x, EndPoint y)
295+
{
296+
if (x.AddressFamily == y.AddressFamily)
297+
{
298+
return 0;
299+
}
300+
if (x.AddressFamily == _preferred)
301+
{
302+
return -1;
303+
}
304+
else if (y.AddressFamily == _preferred)
305+
{
306+
return 1;
307+
}
308+
309+
return 0;
310+
}
311+
}
207312
}
208313
}

0 commit comments

Comments
 (0)