Skip to content

Enabling configuring IP Address for Dotnet Backend #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/csharp/Microsoft.Spark.UnitTest/SparkFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Moq;
Expand All @@ -27,7 +28,7 @@ public SparkFixture()

var mockJvmBridgeFactory = new Mock<IJvmBridgeFactory>();
mockJvmBridgeFactory
.Setup(m => m.Create(It.IsAny<int>()))
.Setup(m => m.Create(It.IsAny<IPAddress>(), It.IsAny<int>()))
.Returns(MockJvm.Object);

SparkEnvironment.JvmBridgeFactory = mockJvmBridgeFactory.Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ public void TestClosedStreamWithSocket()
PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
Payload payload = TestData.GetDefaultPayload();

using var serverListener = new DefaultSocketWrapper();
using var serverListener = new DefaultSocketWrapper(IPAddress.Loopback);
serverListener.Listen();

var port = (serverListener.LocalEndPoint as IPEndPoint).Port;
using var clientSocket = new DefaultSocketWrapper();
using var clientSocket = new DefaultSocketWrapper(IPAddress.Loopback);
clientSocket.Connect(IPAddress.Loopback, port, null);

using (ISocketWrapper serverSocket = serverListener.Accept())
Expand Down
4 changes: 2 additions & 2 deletions src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public class TaskRunnerTests
[Fact]
public void TestTaskRunner()
{
using var serverListener = new DefaultSocketWrapper();
using var serverListener = new DefaultSocketWrapper(IPAddress.Loopback);
serverListener.Listen();

var port = (serverListener.LocalEndPoint as IPEndPoint).Port;
var clientSocket = new DefaultSocketWrapper();
var clientSocket = new DefaultSocketWrapper(IPAddress.Loopback);
clientSocket.Connect(IPAddress.Loopback, port, null);

PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
Expand Down
3 changes: 2 additions & 1 deletion src/csharp/Microsoft.Spark/Interop/Ipc/CallbackServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ internal void Run(ISocketWrapper listener)
/// </summary>
private void Run()
{
Run(SocketFactory.CreateSocket());
IPAddress dotnetCallbackServerIpAddress = SparkEnvironment.ConfigurationService.GetCallbackServerIPAddress();
Run(SocketFactory.CreateSocket(dotnetCallbackServerIpAddress));
}

/// <summary>
Expand Down
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.Spark/Interop/Ipc/IJvmBridgeFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net;

namespace Microsoft.Spark.Interop.Ipc
{
/// <summary>
/// Factory abstraction that produces <see cref="IJvmBridge"/> instances.
/// </summary>
internal interface IJvmBridgeFactory
{
/// <summary>
/// Creates an <see cref="IJvmBridge"/> for the given port number.
/// </summary>
/// <param name="portNumber">Port number the bridge should use.</param>
/// <returns>An <see cref="IJvmBridge"/> instance.</returns>
IJvmBridge Create(int portNumber);

/// <summary>
/// Creates an <see cref="IJvmBridge"/> for the given IP address and port number.
/// </summary>
/// <param name="ip">IP address the bridge should use.</param>
/// <param name="portNumber">Port number the bridge should use.</param>
/// <returns>An <see cref="IJvmBridge"/> instance.</returns>
IJvmBridge Create(IPAddress ip, int portNumber);
}
}
16 changes: 11 additions & 5 deletions src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ internal sealed class JvmBridge : IJvmBridge
new ConcurrentQueue<ISocketWrapper>();
private readonly ILoggerService _logger =
LoggerServiceFactory.GetLogger(typeof(JvmBridge));
private readonly IPAddress _ipAddress;
private readonly int _portNumber;
private readonly JvmThreadPoolGC _jvmThreadPoolGC;
private readonly bool _isRunningRepl;

internal JvmBridge(int portNumber)
internal JvmBridge(int portNumber): this(IPAddress.Loopback, portNumber)
{
}

internal JvmBridge(IPAddress ipAddress, int portNumber)
{
if (portNumber == 0)
{
throw new Exception("Port number is not set.");
}

_ipAddress = ipAddress;
_portNumber = portNumber;
_logger.LogInfo($"JvMBridge port is {portNumber}");
_logger.LogInfo($"JvMBridge IP is {_ipAddress} port is {_portNumber}");

_jvmThreadPoolGC = new JvmThreadPoolGC(
_logger, this, SparkEnvironment.ConfigurationService.JvmThreadGCInterval, _processId);
Expand Down Expand Up @@ -83,8 +88,9 @@ private ISocketWrapper GetConnection()
_socketSemaphore.Wait();
if (!_sockets.TryDequeue(out ISocketWrapper socket))
{
socket = SocketFactory.CreateSocket();
socket.Connect(IPAddress.Loopback, _portNumber);
IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint();
socket = SocketFactory.CreateSocket(dotnetBackendIPEndpoint.Address);
socket.Connect(_ipAddress, _portNumber);
}

return socket;
Expand Down
7 changes: 7 additions & 0 deletions src/csharp/Microsoft.Spark/Interop/Ipc/JvmBridgeFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net;

namespace Microsoft.Spark.Interop.Ipc
{
internal class JvmBridgeFactory : IJvmBridgeFactory
Expand All @@ -10,5 +12,10 @@ public IJvmBridge Create(int portNumber)
{
return new JvmBridge(portNumber);
}

/// <summary>
/// Creates a <see cref="JvmBridge"/> for the given IP address and port number.
/// </summary>
/// <param name="ipAddress">IP address handed to the <see cref="JvmBridge"/> constructor.</param>
/// <param name="portNumber">Port number handed to the <see cref="JvmBridge"/> constructor.</param>
/// <returns>The newly constructed bridge.</returns>
public IJvmBridge Create(IPAddress ipAddress, int portNumber) =>
    new JvmBridge(ipAddress, portNumber);
}
}
4 changes: 3 additions & 1 deletion src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Services;

Expand Down Expand Up @@ -70,8 +71,9 @@ public static IJvmBridge JvmBridge
{
get
{
IPEndPoint jvmBackendEndPoint = ConfigurationService.GetBackendIPEndpoint();
return s_jvmBridge ??=
JvmBridgeFactory.Create(ConfigurationService.GetBackendPortNumber());
JvmBridgeFactory.Create(jvmBackendEndPoint.Address, jvmBackendEndPoint.Port);
}
set
{
Expand Down
9 changes: 5 additions & 4 deletions src/csharp/Microsoft.Spark/Network/DefaultSocketWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IO;
using System.Net;
using System.Net.Sockets;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Services;
using Microsoft.Spark.Utils;

Expand All @@ -24,12 +25,12 @@ internal sealed class DefaultSocketWrapper : ISocketWrapper
/// Default constructor that creates a new instance of DefaultSocket class which represents
/// a traditional socket (System.Net.Socket.Socket).
///
/// This socket is bound to Loopback with port 0.
/// This socket is bound to the provided IP address with port 0.
/// </summary>
public DefaultSocketWrapper() :
this(new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
public DefaultSocketWrapper(IPAddress ipAddress) :
this(new Socket(ipAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp))
{
_innerSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
_innerSocket.Bind(new IPEndPoint(ipAddress, 0));
}

/// <summary>
Expand Down
9 changes: 8 additions & 1 deletion src/csharp/Microsoft.Spark/Network/SocketFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net;

namespace Microsoft.Spark.Network
{
/// <summary>
Expand All @@ -17,7 +19,12 @@ internal static class SocketFactory
/// </returns>
public static ISocketWrapper CreateSocket()
{
return new DefaultSocketWrapper();
return new DefaultSocketWrapper(IPAddress.Loopback);
}

/// <summary>
/// Creates an <see cref="ISocketWrapper"/> backed by a <see cref="DefaultSocketWrapper"/>
/// that is constructed with the given IP address.
/// </summary>
/// <param name="ip">IP address passed to the <see cref="DefaultSocketWrapper"/>.</param>
/// <returns>The newly created socket wrapper.</returns>
public static ISocketWrapper CreateSocket(IPAddress ip) => new DefaultSocketWrapper(ip);
}
}
4 changes: 3 additions & 1 deletion src/csharp/Microsoft.Spark/RDD.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Network;
using Microsoft.Spark.Utils;
Expand Down Expand Up @@ -262,7 +263,8 @@ public IEnumerable<T> Collect()
{
(int port, string secret) = CollectAndServe();
using ISocketWrapper socket = SocketFactory.CreateSocket();
socket.Connect(IPAddress.Loopback, port, secret);
IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint();
socket.Connect(dotnetBackendIPEndpoint.Address, port, secret);

var collector = new RDD.Collector();
System.IO.Stream stream = socket.InputStream;
Expand Down
38 changes: 31 additions & 7 deletions src/csharp/Microsoft.Spark/Services/ConfigurationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.IO;
using System.Net;
using System.Runtime.InteropServices;
using static System.Environment;
using Microsoft.Spark.Utils;
Expand All @@ -24,6 +25,8 @@ internal sealed class ConfigurationService : IConfigurationService
internal const string WorkerVerDirEnvVarNameFormat = "DOTNET_WORKER_{0}_DIR";

private const string DotnetBackendPortEnvVarName = "DOTNETBACKEND_PORT";
private const string DotnetBackendIPAddressEnvVarName = "DOTNET_SPARK_BACKEND_IP_ADDRESS";
private const string DotnetCallbackServerIPAddressEnvVarName = "DOTNET_SPARK_CALLBACK_SERVER_IP_ADDRESS";
private const int DotnetBackendDebugPort = 5567;

private const string DotnetNumBackendThreadsEnvVarName = "DOTNET_SPARK_NUM_BACKEND_THREADS";
Expand Down Expand Up @@ -99,21 +102,26 @@ public TimeSpan JvmThreadGCInterval
!string.IsNullOrEmpty(GetEnvironmentVariable("DATABRICKS_RUNTIME_VERSION"));

/// <summary>
/// Returns the port number for socket communication between JVM and CLR.
/// Returns the IP Endpoint for socket communication between JVM and CLR.
/// </summary>
public int GetBackendPortNumber()
public IPEndPoint GetBackendIPEndpoint()
{
if (!int.TryParse(
GetEnvironmentVariable(DotnetBackendPortEnvVarName),
Environment.GetEnvironmentVariable(DotnetBackendPortEnvVarName),
out int portNumber))
{
_logger.LogInfo($"'{DotnetBackendPortEnvVarName}' environment variable is not set.");
portNumber = DotnetBackendDebugPort;
}

_logger.LogInfo($"Using port {portNumber} for connection.");

return portNumber;
string ipAddress = Environment.GetEnvironmentVariable(DotnetBackendIPAddressEnvVarName);
if (ipAddress == null)
{
_logger.LogInfo($"'{DotnetBackendIPAddressEnvVarName}' environment variable is not set.");
ipAddress = "127.0.0.1";
}
_logger.LogInfo($"Using IP address {ipAddress} and port {portNumber} for connection.");

return new IPEndPoint(IPAddress.Parse(ipAddress), portNumber);
}

/// <summary>
Expand All @@ -131,6 +139,22 @@ public int GetNumBackendThreads()
return numThreads;
}

/// <summary>
/// Returns the IP address for socket communication between JVM and the callback server.
/// </summary>
/// <remarks>
/// The address is read from the 'DOTNET_SPARK_CALLBACK_SERVER_IP_ADDRESS' environment
/// variable. When the variable is unset, empty, or whitespace, 127.0.0.1 is used.
/// </remarks>
/// <returns>The parsed IP address.</returns>
/// <exception cref="FormatException">
/// The environment variable is set to a value that is not a valid IP address.
/// </exception>
public IPAddress GetCallbackServerIPAddress()
{
    string ipAddress =
        Environment.GetEnvironmentVariable(DotnetCallbackServerIPAddressEnvVarName);

    // A variable can be defined but empty (e.g. `export VAR=`). Treat that the same as
    // unset instead of letting it fail in the parser with an unhelpful message.
    if (string.IsNullOrWhiteSpace(ipAddress))
    {
        _logger.LogInfo($"'{DotnetCallbackServerIPAddressEnvVarName}' environment variable is not set.");
        ipAddress = "127.0.0.1";
    }

    // Keep the exception type IPAddress.Parse would have thrown, but name the offending
    // variable and value so a misconfiguration is easy to diagnose.
    if (!IPAddress.TryParse(ipAddress, out IPAddress callbackServerAddress))
    {
        throw new FormatException(
            $"'{DotnetCallbackServerIPAddressEnvVarName}' environment variable value " +
            $"'{ipAddress}' is not a valid IP address.");
    }

    _logger.LogInfo($"Using IP address {ipAddress} for connection with Callback Server.");

    return callbackServerAddress;
}

/// <summary>
/// Returns the worker executable path.
/// </summary>
Expand Down
14 changes: 10 additions & 4 deletions src/csharp/Microsoft.Spark/Services/IConfigurationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;

namespace Microsoft.Spark.Services
{
Expand All @@ -17,14 +18,19 @@ internal interface IConfigurationService
TimeSpan JvmThreadGCInterval { get; }

/// <summary>
/// The port number used for communicating with the .NET backend process.
/// Returns the max number of threads for socket communication between JVM and CLR.
/// </summary>
int GetBackendPortNumber();
int GetNumBackendThreads();

/// <summary>
/// Returns the max number of threads for socket communication between JVM and CLR.
/// The IP Endpoint used for communicating with the .NET backend process.
/// </summary>
int GetNumBackendThreads();
IPEndPoint GetBackendIPEndpoint();

/// <summary>
/// The IP address used for communicating with the callback server.
/// </summary>
IPAddress GetCallbackServerIPAddress();

/// <summary>
/// The full path to the .NET worker executable.
Expand Down
16 changes: 10 additions & 6 deletions src/csharp/Microsoft.Spark/Sql/DataFrame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,8 @@ public IEnumerable<Row> ToLocalIterator(bool prefetchPartitions)
Reference.Invoke("toPythonIterator", prefetchPartitions),
true);
using ISocketWrapper socket = SocketFactory.CreateSocket();
socket.Connect(IPAddress.Loopback, port, secret);
IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint();
socket.Connect(dotnetBackendIPEndpoint.Address, port, secret);
foreach (Row row in new RowCollector().Collect(socket, server))
{
yield return row;
Expand Down Expand Up @@ -1077,15 +1078,18 @@ public int SemanticHash() =>
/// </summary>
/// <param name="funcName">String name of function to call</param>
/// <param name="args">Arguments to the function</param>
/// <returns>IEnumerable of Rows from Spark</returns>
/// <returns>IEnumerable of Rows from Spark</returns>
private IEnumerable<Row> GetRows(string funcName, params object[] args)
{
(int port, string secret, _) = GetConnectionInfo(funcName, args);
using ISocketWrapper socket = SocketFactory.CreateSocket();
socket.Connect(IPAddress.Loopback, port, secret);
foreach (Row row in new RowCollector().Collect(socket))
IPEndPoint dotnetBackendIPEndpoint = SparkEnvironment.ConfigurationService.GetBackendIPEndpoint();
using (ISocketWrapper socket = SocketFactory.CreateSocket())
{
yield return row;
socket.Connect(dotnetBackendIPEndpoint.Address, port, secret);
foreach (Row row in new RowCollector().Collect(socket))
{
yield return row;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DotnetBackend extends Logging {
@volatile
private[dotnet] var callbackClient: Option[CallbackClient] = None

def init(portNumber: Int): Int = {
def init(ipAddress: String, portNumber: Int): Int = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
Expand Down Expand Up @@ -63,7 +63,7 @@ class DotnetBackend extends Logging {
}
})

channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
channelFuture = bootstrap.bind(new InetSocketAddress(ipAddress, portNumber))
channelFuture.syncUninterruptibly()
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ object DotnetRunner extends Logging {
// In debug mode this runner will not launch a .NET process.
val runInDebugMode = settings._1
@volatile var dotnetBackendPortNumber = settings._2
val dotnetBackendIPAddress = sys.env.getOrElse("DOTNET_SPARK_BACKEND_IP_ADDRESS", "127.0.0.1")
var dotnetExecutable = ""
var otherArgs: Array[String] = null

Expand Down Expand Up @@ -110,8 +111,9 @@ object DotnetRunner extends Logging {
override def run() {
// need to get back dotnetBackendPortNumber because if the value passed to init is 0
// the port number is dynamically assigned in the backend
dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber)
logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber")
dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendIPAddress, dotnetBackendPortNumber)
logInfo(s"IP address used by DotnetBackend is $dotnetBackendIPAddress and " +
s"Port number used is $dotnetBackendPortNumber")
initialized.release()
dotnetBackend.run()
}
Expand Down
Loading