
Commit 0d7618a

MaxGekk authored and HyukjinKwon committed
[SPARK-42585][CONNECT] Streaming of local relations
### What changes were proposed in this pull request? In the PR, I propose to transfer a local relation to the server in streaming way when it exceeds some size which is defined by the SQL config `spark.sql.session.localRelationCacheThreshold`. The config value is 64MB by default. In particular: 1. The client applies the `sha256` function over the arrow form of the local relation; 2. It checks presents of the relation at the server side by sending the relation hash to the server; 3. If the server doesn't have the local relation, the client transfers the local relation as an artefact with the name `cache/<sha256>`; 4. As soon as the relation has presented at the server already, or transferred recently, the client transform the logical plan by replacing the `LocalRelation` node by `CachedLocalRelation` with the hash. 5. On another hand, the server converts `CachedLocalRelation` back to `LocalRelation` by retrieving the relation body from the local cache. #### Details of the implementation The client sends new command `ArtifactStatusesRequest` to check either the local relation is cached at the server or not. New command comes via new RPC endpoint `ArtifactStatus`. And the server answers by new message `ArtifactStatusesResponse`, see **base.proto**. The client transfers serialized (in avro) body of local relation and its schema via the RPC endpoint `AddArtifacts`. On another hand, the server stores the received artifact in the block manager using the id `CacheId`. The last one has 3 parts: - `userId` - the identifier of the user that created the local relation, - `sessionId` - the identifier of the session which the relation belongs to, - `hash` - a `sha-256` hash over relation body. See **SparkConnectArtifactManager.addArtifact()**. The current query is blocked till the local relation is cached at the server side. When the server receives the query, it retrieves `userId`, `sessionId` and `hash` from `CachedLocalRelation`, and gets the local relation data from the block manager. See **SparkConnectPlanner.transformCachedLocalRelation()**. The occupied blocks at the block manager are removed when an user session is invalidated in `userSessionMapping`. See **SparkConnectService.RemoveSessionListener** and **BlockManager.removeCache()`**. ### Why are the changes needed? To allow creating a dataframe from a large local collection. `spark.createDataFrame(...)` fails with the following error w/o the changes: ```java 23/04/21 20:32:20 WARN NettyServerStream: Exception processing message org.sparkproject.connect.grpc.StatusRuntimeException: RESOURCE_EXHAUSTED: gRPC message exceeds maximum size 134217728: 268435456 at org.sparkproject.connect.grpc.Status.asRuntimeException(Status.java:526) ``` ### Does this PR introduce _any_ user-facing change? No. The changes extend the existing proto API. ### How was this patch tested? By running the new tests: ``` $ build/sbt "test:testOnly *.ArtifactManagerSuite" $ build/sbt "test:testOnly *.ClientE2ETestSuite" $ build/sbt "test:testOnly *.ArtifactStatusesHandlerSuite" ``` Closes apache#40827 from MaxGekk/streaming-createDataFrame-2. Authored-by: Max Gekk <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent d26292c commit 0d7618a
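For orientation, here is a minimal, hypothetical sketch (not part of the commit) of how the artifact name of a cached local relation is derived on the client. The object and method names are illustrative; the real logic lives in `ArtifactManager.cacheArtifact` and `Artifact.newCacheArtifact` in the diffs below.

```scala
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

// Hypothetical illustration of the naming convention for cached local relations:
// the client hashes the Arrow-serialized relation body with SHA-256 and uploads it
// as an artifact under the "cache/" prefix; the server later strips that prefix to
// build the CacheId it stores the blob under.
object LocalRelationCacheKey {
  def artifactName(arrowBytes: Array[Byte]): String = {
    val hash = sha256Hex(arrowBytes) // hex-encoded SHA-256 of the relation body
    s"cache/$hash"                   // mirrors Artifact.CACHE_PREFIX + "/" + hash
  }
}
```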

File tree

23 files changed: +922 / -208 lines

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 15 additions & 4 deletions
@@ -119,12 +119,23 @@ class SparkSession private[sql] (

   private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
     newDataset(encoder) { builder =>
-      val localRelationBuilder = builder.getLocalRelationBuilder
-        .setSchema(encoder.schema.json)
       if (data.nonEmpty) {
         val timeZoneId = conf.get("spark.sql.session.timeZone")
-        val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator)
-        localRelationBuilder.setData(arrowData)
+        val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator)
+        if (arrowDataSize <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
+          builder.getLocalRelationBuilder
+            .setSchema(encoder.schema.json)
+            .setData(arrowData)
+        } else {
+          val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
+          builder.getCachedLocalRelationBuilder
+            .setUserId(client.userId)
+            .setSessionId(client.sessionId)
+            .setHash(hash)
+        }
+      } else {
+        builder.getLocalRelationBuilder
+          .setSchema(encoder.schema.json)
       }
     }
   }
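For context, a hypothetical usage snippet that exercises the new `CachedLocalRelation` branch above by lowering the threshold, similar to what the E2E test added later in this commit does via `withSQLConf`. The session bootstrap and the `spark.conf.set` call are assumptions about the client API, not part of this diff.

```scala
// Hypothetical usage sketch: with the threshold lowered to 1 MiB, a local collection whose
// Arrow form exceeds that size is uploaded through AddArtifacts instead of being inlined
// into the plan. Assumes `spark` is a Spark Connect SparkSession.
spark.conf.set("spark.sql.session.localRelationCacheThreshold", (1024 * 1024).toString)

val bigString = "x" * (2 * 1024 * 1024) // ~2 MiB per row, well above the lowered threshold
val df = spark.createDataFrame(Seq((1, bigString), (2, bigString)))
assert(df.count() == 2)
```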

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala

Lines changed: 41 additions & 1 deletion
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.client
 import java.io.{ByteArrayInputStream, InputStream}
 import java.net.URI
 import java.nio.file.{Files, Path, Paths}
+import java.util.Arrays
 import java.util.concurrent.CopyOnWriteArrayList
 import java.util.zip.{CheckedInputStream, CRC32}
@@ -32,6 +33,7 @@ import Artifact._
 import com.google.protobuf.ByteString
 import io.grpc.ManagedChannel
 import io.grpc.stub.StreamObserver
+import org.apache.commons.codec.digest.DigestUtils.sha256Hex

 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.AddArtifactsResponse
@@ -42,14 +44,20 @@ import org.apache.spark.util.{ThreadUtils, Utils}
  * The Artifact Manager is responsible for handling and transferring artifacts from the local
  * client to the server (local/remote).
  * @param userContext
+ * @param sessionId
+ *   An unique identifier of the session which the artifact manager belongs to.
  * @param channel
  */
-class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
+class ArtifactManager(
+    userContext: proto.UserContext,
+    sessionId: String,
+    channel: ManagedChannel) {
   // Using the midpoint recommendation of 32KiB for chunk size as specified in
   // https://github.com/grpc/grpc.github.io/issues/371.
   private val CHUNK_SIZE: Int = 32 * 1024

   private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
+  private[this] val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
   private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]

   /**
@@ -100,6 +108,31 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
    */
   def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))

+  private def isCachedArtifact(hash: String): Boolean = {
+    val artifactName = CACHE_PREFIX + "/" + hash
+    val request = proto.ArtifactStatusesRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .addAllNames(Arrays.asList(artifactName))
+      .build()
+    val statuses = bstub.artifactStatus(request).getStatusesMap
+    if (statuses.containsKey(artifactName)) {
+      statuses.get(artifactName).getExists
+    } else false
+  }
+
+  /**
+   * Cache the give blob at the session.
+   */
+  def cacheArtifact(blob: Array[Byte]): String = {
+    val hash = sha256Hex(blob)
+    if (!isCachedArtifact(hash)) {
+      addArtifacts(newCacheArtifact(hash, new InMemory(blob)) :: Nil)
+    }
+    hash
+  }
+
   /**
    * Upload all class file artifacts from the local REPL(s) to the server.
    *
@@ -182,6 +215,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
     val builder = proto.AddArtifactsRequest
       .newBuilder()
       .setUserContext(userContext)
+      .setSessionId(sessionId)
     artifacts.foreach { artifact =>
       val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
       try {
@@ -236,6 +270,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
     val builder = proto.AddArtifactsRequest
       .newBuilder()
       .setUserContext(userContext)
+      .setSessionId(sessionId)

     val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
     try {
@@ -289,6 +324,7 @@ class Artifact private (val path: Path, val storage: LocalData) {
 object Artifact {
   val CLASS_PREFIX: Path = Paths.get("classes")
   val JAR_PREFIX: Path = Paths.get("jars")
+  val CACHE_PREFIX: Path = Paths.get("cache")

   def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
     newArtifact(JAR_PREFIX, ".jar", fileName, storage)
@@ -298,6 +334,10 @@ object Artifact {
     newArtifact(CLASS_PREFIX, ".class", fileName, storage)
   }

+  def newCacheArtifact(id: String, storage: LocalData): Artifact = {
+    newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
+  }
+
   private def newArtifact(
       prefix: Path,
       requiredSuffix: String,

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 22 additions & 4 deletions
@@ -17,8 +17,11 @@

 package org.apache.spark.sql.connect.client

+import com.google.protobuf.ByteString
 import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, ManagedChannelBuilder, Metadata, MethodDescriptor, Status, TlsChannelCredentials}
 import java.net.URI
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
 import java.util.UUID
 import java.util.concurrent.Executor
 import scala.language.existentials
@@ -39,19 +42,21 @@ private[sql] class SparkConnectClient(

   private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)

-  private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)
-
   /**
    * Placeholder method.
    * @return
    *   User ID.
    */
-  private[client] def userId: String = userContext.getUserId()
+  private[sql] def userId: String = userContext.getUserId()

   // Generate a unique session ID for this client. This UUID must be unique to allow
   // concurrent Spark sessions of the same user. If the channel is closed, creating
   // a new client will create a new session ID.
-  private[client] val sessionId: String = UUID.randomUUID.toString
+  private[sql] val sessionId: String = UUID.randomUUID.toString
+
+  private[client] val artifactManager: ArtifactManager = {
+    new ArtifactManager(userContext, sessionId, channel)
+  }

   /**
    * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
@@ -215,6 +220,19 @@
   def shutdown(): Unit = {
     channel.shutdownNow()
   }
+
+  /**
+   * Cache the given local relation at the server, and return its key in the remote cache.
+   */
+  def cacheLocalRelation(size: Int, data: ByteString, schema: String): String = {
+    val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
+    val locRelData = data.toByteArray
+    val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
+    locRel.putInt(size)
+    locRel.put(locRelData)
+    locRel.put(schemaBytes)
+    artifactManager.cacheArtifact(locRel.array())
+  }
 }

 object SparkConnectClient {
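The `cacheLocalRelation` method above packs a 4-byte big-endian length, the Arrow payload, and the UTF-8 schema JSON into a single blob. Below is a hypothetical decoder for that layout; the real decoding happens server-side in `SparkConnectPlanner.transformCachedLocalRelation()`, which is not part of this diff.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Hypothetical decoder for the blob layout produced by cacheLocalRelation:
// [4-byte big-endian length of the Arrow payload][Arrow IPC stream][schema JSON, UTF-8].
def decodeCachedRelation(blob: Array[Byte]): (Array[Byte], String) = {
  val buf = ByteBuffer.wrap(blob)
  val dataSize = buf.getInt()                        // length of the Arrow payload
  val arrowData = new Array[Byte](dataSize)
  buf.get(arrowData)                                 // the Arrow IPC stream bytes
  val schemaBytes = new Array[Byte](buf.remaining()) // whatever remains is the schema JSON
  buf.get(schemaBytes)
  (arrowData, new String(schemaBytes, StandardCharsets.UTF_8))
}
```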

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala

Lines changed: 3 additions & 3 deletions
@@ -34,13 +34,13 @@ import org.apache.spark.sql.util.ArrowUtils
 private[sql] object ConvertToArrow {

   /**
-   * Convert an iterator of common Scala objects into a sinlge Arrow IPC Stream.
+   * Convert an iterator of common Scala objects into a single Arrow IPC Stream.
    */
   def apply[T](
       encoder: AgnosticEncoder[T],
       data: Iterator[T],
       timeZoneId: String,
-      bufferAllocator: BufferAllocator): ByteString = {
+      bufferAllocator: BufferAllocator): (ByteString, Int) = {
     val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId)
     val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
     val writer: ArrowWriter = ArrowWriter.create(root)
@@ -64,7 +64,7 @@
     ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)

     // Done
-    bytes.toByteString
+    (bytes.toByteString, bytes.size)
   } finally {
     root.close()
   }

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala

Lines changed: 14 additions & 0 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
@@ -853,6 +854,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     }.getMessage
     assert(message.contains("PARSE_SYNTAX_ERROR"))
   }
+
+  test("SparkSession.createDataFrame - large data set") {
+    val threshold = 1024 * 1024
+    withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) {
+      val count = 2
+      val suffix = "abcdef"
+      val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix
+      val data = Seq.tabulate(count)(i => (i, str))
+      val df = spark.createDataFrame(data)
+      assert(df.count() === count)
+      assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
+    }
+  }
 }

 private[sql] case class MyType(id: Long, a: Double, b: Double)

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {

   private def createArtifactManager(): Unit = {
     channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
-    artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel)
+    artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), "", channel)
   }

   override def beforeEach(): Unit = {

connector/connect/common/src/main/protobuf/spark/connect/base.proto

Lines changed: 40 additions & 0 deletions
@@ -542,6 +542,43 @@ message AddArtifactsResponse {
   repeated ArtifactSummary artifacts = 1;
 }

+// Request to get current statuses of artifacts at the server side.
+message ArtifactStatusesRequest {
+  // (Required)
+  //
+  // The session_id specifies a spark session for a user id (which is specified
+  // by user_context.user_id). The session_id is set by the client to be able to
+  // collate streaming responses from different queries within the dedicated session.
+  string session_id = 1;
+
+  // User context
+  UserContext user_context = 2;
+
+  // Provides optional information about the client sending the request. This field
+  // can be used for language or version specific information and is only intended for
+  // logging purposes and will not be interpreted by the server.
+  optional string client_type = 3;
+
+  // The name of the artifact is expected in the form of a "Relative Path" that is made up of a
+  // sequence of directories and the final file element.
+  // Examples of "Relative Path"s: "jars/test.jar", "classes/xyz.class", "abc.xyz", "a/b/X.jar".
+  // The server is expected to maintain the hierarchy of files as defined by their name. (i.e
+  // The relative path of the file on the server's filesystem will be the same as the name of
+  // the provided artifact)
+  repeated string names = 4;
+}
+
+// Response to checking artifact statuses.
+message ArtifactStatusesResponse {
+  message ArtifactStatus {
+    // Exists or not particular artifact at the server.
+    bool exists = 1;
+  }
+
+  // A map of artifact names to their statuses.
+  map<string, ArtifactStatus> statuses = 1;
+}
+
 // Main interface for the SparkConnect service.
 service SparkConnectService {
@@ -559,5 +596,8 @@
   // Add artifacts to the session and returns a [[AddArtifactsResponse]] containing metadata about
   // the added artifacts.
   rpc AddArtifacts(stream AddArtifactsRequest) returns (AddArtifactsResponse) {}
+
+  // Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]
+  rpc ArtifactStatus(ArtifactStatusesRequest) returns (ArtifactStatusesResponse) {}
 }
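For illustration, a hypothetical sketch of how a server-side handler could populate `ArtifactStatusesResponse` from the names in the request. The actual handler is exercised by `ArtifactStatusesHandlerSuite` and is not shown in this diff; the `exists` predicate (for example a block-manager lookup keyed by the hash after the `cache/` prefix) is assumed.

```scala
import scala.collection.JavaConverters._
import org.apache.spark.connect.proto

// Hypothetical sketch: answer an ArtifactStatusesRequest by evaluating an existence
// predicate for every requested artifact name and putting the result into the map field.
def buildStatuses(
    request: proto.ArtifactStatusesRequest,
    exists: String => Boolean): proto.ArtifactStatusesResponse = {
  val builder = proto.ArtifactStatusesResponse.newBuilder()
  request.getNamesList.asScala.foreach { name =>
    val status = proto.ArtifactStatusesResponse.ArtifactStatus
      .newBuilder()
      .setExists(exists(name)) // e.g. true if the block manager holds cache/<hash>
      .build()
    builder.putStatuses(name, status)
  }
  builder.build()
}
```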

connector/connect/common/src/main/protobuf/spark/connect/relations.proto

Lines changed: 13 additions & 0 deletions
@@ -68,6 +68,7 @@ message Relation {
     WithWatermark with_watermark = 33;
     ApplyInPandasWithState apply_in_pandas_with_state = 34;
     HtmlString html_string = 35;
+    CachedLocalRelation cached_local_relation = 36;

     // NA functions
     NAFill fill_na = 90;
@@ -381,6 +382,18 @@ message LocalRelation {
   optional string schema = 2;
 }

+// A local relation that has been cached already.
+message CachedLocalRelation {
+  // (Required) An identifier of the user which created the local relation
+  string userId = 1;
+
+  // (Required) An identifier of the Spark SQL session in which the user created the local relation.
+  string sessionId = 2;
+
+  // (Required) A sha-256 hash of the serialized local relation.
+  string hash = 3;
+}
+
 // Relation of type [[Sample]] that samples a fraction of the dataset.
 message Sample {
   // (Required) Input relation for a Sample.

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala

Lines changed: 23 additions & 4 deletions
@@ -22,9 +22,11 @@ import java.nio.file.{Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList

 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag

 import org.apache.spark.{SparkContext, SparkEnv}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.storage.{CacheId, StorageLevel}
 import org.apache.spark.util.Utils

 /**
@@ -87,11 +89,28 @@ class SparkConnectArtifactManager private[connect] {
   * @param serverLocalStagingPath
   */
  private[connect] def addArtifact(
-     session: SparkSession,
+     sessionHolder: SessionHolder,
      remoteRelativePath: Path,
      serverLocalStagingPath: Path): Unit = {
    require(!remoteRelativePath.isAbsolute)
-    if (remoteRelativePath.startsWith("classes/")) {
+    if (remoteRelativePath.startsWith("cache/")) {
+      val tmpFile = serverLocalStagingPath.toFile
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        val blockManager = sessionHolder.session.sparkContext.env.blockManager
+        val blockId = CacheId(
+          userId = sessionHolder.userId,
+          sessionId = sessionHolder.sessionId,
+          hash = remoteRelativePath.toString.stripPrefix("cache/"))
+        val updater = blockManager.TempFileBasedBlockStoreUpdater(
+          blockId = blockId,
+          level = StorageLevel.MEMORY_AND_DISK_SER,
+          classTag = implicitly[ClassTag[Array[Byte]]],
+          tmpFile = tmpFile,
+          blockSize = tmpFile.length(),
+          tellMaster = false)
+        updater.save()
+      }(catchBlock = { tmpFile.delete() })
+    } else if (remoteRelativePath.startsWith("classes/")) {
      // Move class files to common location (shared among all users)
      val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
      Files.createDirectories(target.getParent)
@@ -110,7 +129,7 @@
      Files.move(serverLocalStagingPath, target)
      if (remoteRelativePath.startsWith("jars")) {
        // Adding Jars to the underlying spark context (visible to all users)
-        session.sessionState.resourceLoader.addJar(target.toString)
+        sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
        jarsList.add(target)
      }
    }
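The write path above stores the blob under a `CacheId`. Below is a hypothetical sketch of the matching read path, as it might be used by `SparkConnectPlanner.transformCachedLocalRelation()` (not part of this diff). The `getLocalBytes`/`releaseLock` usage is an assumption about the `BlockManager` API and is only meant to show how the pieces fit together.

```scala
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.storage.CacheId

// Hypothetical sketch: rebuild the CacheId from the CachedLocalRelation fields and look the
// blob up in the same block manager the artifact was saved to by addArtifact above.
def fetchCachedRelation(sessionHolder: SessionHolder, hash: String): Array[Byte] = {
  val blockManager = sessionHolder.session.sparkContext.env.blockManager
  val blockId = CacheId(
    userId = sessionHolder.userId,
    sessionId = sessionHolder.sessionId,
    hash = hash)
  val blockData = blockManager
    .getLocalBytes(blockId)
    .getOrElse(throw new IllegalStateException(s"Local relation $blockId is not cached"))
  try {
    val buffer = blockData.toByteBuffer()
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    bytes // the blob: [size][Arrow payload][schema JSON], as written by the client
  } finally {
    blockManager.releaseLock(blockId) // getLocalBytes pins the block; release it when done
  }
}
```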
