Skip to content

Commit 8499a62

Browse files
jerrypeng
authored and viirya committed
[SPARK-53785][SS] Memory Source for RTM
### What changes were proposed in this pull request? Add a memory source implementation that supports Real-time Mode. This source is going to be used to test Real-time Mode. In this implementation of the memory source, an RPC server is set up on the driver and source tasks will constantly poll this RPC server for new data. This differs from the existing memory source implementation, where data is sent once to tasks as part of the Partition/Split metadata at the beginning of a batch. ### Why are the changes needed? To test Real-time Mode queries. ### Does this PR introduce _any_ user-facing change? No, this source is purely to help testing. ### How was this patch tested? Actual unit tests to test RTM will be added in the future once the engine supports actually running queries in RTM. ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#52502 from jerrypeng/SPARK-53785. Authored-by: Jerry Peng <jerry.peng@databricks.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
1 parent c041671 commit 8499a62

File tree

3 files changed

+346
-0
lines changed

3 files changed

+346
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.datasources.v2
19+
20+
import org.apache.spark.util.{Clock, SystemClock}
21+
22+
/**
 * Singleton holder for the clock used by low-latency (Real-time Mode) sources,
 * letting tests substitute a controllable clock for the wall clock.
 */
object LowLatencyClock {
  // Defaults to the wall clock; tests may swap in a manual clock via setClock().
  private var clock: Clock = new SystemClock

  /** Returns the clock currently in use. */
  def getClock: Clock = clock

  /** Current time in milliseconds according to the active clock. */
  def getTimeMillis(): Long = clock.getTimeMillis()

  /** Blocks until the active clock reaches `targetTime` (milliseconds). */
  def waitTillTime(targetTime: Long): Unit = clock.waitTillTime(targetTime)

  /** Replaces the active clock; intended for tests that control time manually. */
  def setClock(inputClock: Clock): Unit = clock = inputClock
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,13 @@ case class MemoryStream[A : Encoder](
158158
id: Int,
159159
sqlContext: SQLContext,
160160
numPartitions: Option[Int] = None)
161+
extends MemoryStreamBaseClass[A](
162+
id, sqlContext, numPartitions = numPartitions)
163+
164+
abstract class MemoryStreamBaseClass[A: Encoder](
165+
id: Int,
166+
sqlContext: SQLContext,
167+
numPartitions: Option[Int] = None)
161168
extends MemoryStreamBase[A](sqlContext)
162169
with MicroBatchStream
163170
with SupportsTriggerAvailableNow
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.streaming
19+
20+
import java.util.concurrent.atomic.AtomicInteger
21+
import javax.annotation.concurrent.GuardedBy
22+
23+
import scala.collection.mutable.ListBuffer
24+
25+
import org.json4s.{Formats, NoTypeHints}
26+
import org.json4s.jackson.Serialization
27+
28+
import org.apache.spark.{SparkEnv, TaskContext}
29+
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
30+
import org.apache.spark.sql.{Encoder, SQLContext}
31+
import org.apache.spark.sql.catalyst.InternalRow
32+
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
33+
import org.apache.spark.sql.connector.read.InputPartition
34+
import org.apache.spark.sql.connector.read.PartitionReader
35+
import org.apache.spark.sql.connector.read.PartitionReaderFactory
36+
import org.apache.spark.sql.connector.read.streaming.{
37+
Offset => OffsetV2,
38+
PartitionOffset,
39+
ReadLimit,
40+
SupportsRealTimeMode,
41+
SupportsRealTimeRead
42+
}
43+
import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
44+
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
45+
import org.apache.spark.sql.execution.streaming.runtime._
46+
import org.apache.spark.util.{Clock, RpcUtils}
47+
48+
/**
 * A low latency memory source, only for unit test purpose.
 * This class is very similar to ContinuousMemoryStream, except that it implements the
 * interface of SupportsRealTimeMode, rather than ContinuousStream.
 * The overall strategy here is:
 * * LowLatencyMemoryStream maintains a list of records for each partition. addData() will
 *   distribute records evenly-ish across partitions.
 * * A ContinuousRecordEndpoint is set up as an RPC endpoint for executor-side
 *   LowLatencyMemoryStreamPartitionReader instances to poll. It returns the record at
 *   the specified offset within the list, or None if that offset doesn't yet have a record.
 *   This differs from the existing memory source implementation as data is sent once to
 *   tasks as part of the Partition/Split metadata at the beginning of a batch.
 */
class LowLatencyMemoryStream[A: Encoder](
    id: Int,
    sqlContext: SQLContext,
    numPartitions: Int = 2,
    clock: Clock = LowLatencyClock.getClock)
  // NOTE(review): the base class receives a hard-coded id of 0 while `id` is only used to
  // build a unique endpoint name in planInputPartitions — confirm this is intentional.
  extends MemoryStreamBaseClass[A](0, sqlContext)
  with SupportsRealTimeMode {
  // json4s formats used by deserializeOffset; no type hints needed for Map[Int, Int].
  private implicit val formats: Formats = Serialization.formats(NoTypeHints)

  // One record buffer per partition; all access is under `this`'s monitor.
  @GuardedBy("this")
  private val records = Seq.fill(numPartitions)(new ListBuffer[UnsafeRow])

  // Driver-side endpoint serving records out of `records` to executor readers.
  private val recordEndpoint = new ContinuousRecordEndpoint(records, this)
  // Set when planInputPartitions registers the endpoint; read by stop().
  @volatile private var endpointRef: RpcEndpointRef = _

  /**
   * Appends the given data, spread round-robin across all partitions, and returns the
   * offset at which every record added so far will have been processed.
   */
  override def addData(data: IterableOnce[A]): Offset = synchronized {
    // Distribute data evenly among partition lists.
    // NOTE(review): `.map` is used purely for its side effect here; `.foreach` would
    // express the intent more directly.
    data.iterator.to(Seq).zipWithIndex.map {
      case (item, index) =>
        // copy() is required because toRow may reuse its backing buffer across calls.
        records(index % numPartitions) += toRow(item).copy().asInstanceOf[UnsafeRow]
    }

    // The new target offset is the offset where all records in all partitions have been processed.
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

  /**
   * Appends the given data to a single, explicitly chosen partition.
   *
   * @throws IllegalArgumentException if `partitionId` is outside [0, numPartitions).
   */
  def addData(partitionId: Int, data: IterableOnce[A]): Offset = synchronized {
    require(
      partitionId >= 0 && partitionId < numPartitions,
      s"Partition ID $partitionId is out of bounds for $numPartitions partitions."
    )

    // Add data to the specified partition.
    records(partitionId) ++= data.iterator.map(item => toRow(item).copy().asInstanceOf[UnsafeRow])

    // The new target offset is the offset where all records in all partitions have been processed.
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

  // Every partition starts at record index 0.
  override def initialOffset(): OffsetV2 = {
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
  }

  // The latest offset is simply the current size of each partition's buffer.
  override def latestOffset(startOffset: OffsetV2, limit: ReadLimit): OffsetV2 = synchronized {
    LowLatencyMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
  }

  override def deserializeOffset(json: String): LowLatencyMemoryStreamOffset = {
    LowLatencyMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
  }

  // Combines per-partition offsets reported by readers into a single source offset.
  // NOTE(review): the match is non-exhaustive; any PartitionOffset other than
  // ContinuousRecordPartitionOffset will throw a MatchError — confirm that is acceptable.
  override def mergeOffsets(offsets: Array[PartitionOffset]): LowLatencyMemoryStreamOffset = {
    LowLatencyMemoryStreamOffset(
      offsets.map {
        case ContinuousRecordPartitionOffset(part, num) => (part, num)
      }.toMap
    )
  }

  /**
   * Real-time (unbounded) planning: each partition reads from its start offset with no
   * upper bound (Int.MaxValue). Registers a fresh, uniquely named RPC endpoint.
   * NOTE(review): each call registers a new endpoint but only the latest ref is retained,
   * so earlier endpoints are never stopped — confirm this cannot accumulate across batches.
   */
  override def planInputPartitions(start: OffsetV2): Array[InputPartition] = {
    val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset]
    synchronized {
      // UUID keeps the name unique across streams and across repeated plans of the same stream.
      val endpointName = s"LowLatencyRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
      endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)

      startOffset.partitionNums.map {
        case (part, index) =>
          LowLatencyMemoryStreamInputPartition(
            endpointName,
            endpointRef.address,
            part,
            index,
            Int.MaxValue
          )
      }.toArray
    }
  }

  /**
   * Bounded planning: each partition reads from its start offset up to (exclusive) the
   * end offset recorded for it in `end`.
   */
  override def planInputPartitions(start: OffsetV2, end: OffsetV2): Array[InputPartition] = {
    val startOffset = start.asInstanceOf[LowLatencyMemoryStreamOffset]
    val endOffset = end.asInstanceOf[LowLatencyMemoryStreamOffset]
    synchronized {
      val endpointName = s"LowLatencyRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
      endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)

      startOffset.partitionNums.map {
        case (part, index) =>
          LowLatencyMemoryStreamInputPartition(
            endpointName,
            endpointRef.address,
            part,
            index,
            endOffset.partitionNums(part)
          )
      }.toArray
    }
  }

  // Readers share the stream's clock so tests can control their timeout accounting.
  override def createReaderFactory(): PartitionReaderFactory = {
    new LowLatencyMemoryStreamReaderFactory(clock)
  }

  // Tears down the most recently registered endpoint, if any.
  override def stop(): Unit = {
    if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
  }

  // Records are kept in memory until reset(); nothing to do on commit.
  override def commit(end: OffsetV2): Unit = {}

  // Clears all buffered records in addition to the base class's reset behavior.
  override def reset(): Unit = synchronized {
    super.reset()
    records.foreach(_.clear())
  }
}
174+
175+
/** Factory for [[LowLatencyMemoryStream]] instances with process-unique ids. */
object LowLatencyMemoryStream {
  protected val memoryStreamId = new AtomicInteger(0)

  /** Creates a stream with the default partition count. */
  def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](nextStreamId(), sqlContext)

  /** Creates a stream with an explicit partition count. */
  def apply[A: Encoder](numPartitions: Int)(
      implicit
      sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](nextStreamId(), sqlContext, numPartitions = numPartitions)

  /** Convenience constructor for a stream backed by exactly one partition. */
  def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
    new LowLatencyMemoryStream[A](nextStreamId(), sqlContext, 1)

  // Hands out a fresh, monotonically increasing stream id.
  private def nextStreamId(): Int = memoryStreamId.getAndIncrement()
}
193+
194+
/**
 * An input partition for LowLatency memory stream.
 *
 * Carries the driver RPC endpoint coordinates plus the record-offset range
 * [startOffset, endOffset) that the executor-side reader should fetch.
 */
case class LowLatencyMemoryStreamInputPartition(
    driverEndpointName: String,        // name the record endpoint was registered under
    driverEndpointAddress: RpcAddress, // host/port of the driver's RpcEnv
    partition: Int,                    // partition index this split reads
    startOffset: Int,                  // first record offset to read (inclusive)
    endOffset: Int)                    // offset to stop at (exclusive); Int.MaxValue = unbounded
  extends InputPartition
204+
205+
/** Builds [[LowLatencyMemoryStreamPartitionReader]]s that share the stream's clock. */
class LowLatencyMemoryStreamReaderFactory(clock: Clock) extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    // Destructure the split's endpoint coordinates and offset range in one step.
    val LowLatencyMemoryStreamInputPartition(endpointName, endpointAddress, part, start, end) =
      partition.asInstanceOf[LowLatencyMemoryStreamInputPartition]
    new LowLatencyMemoryStreamPartitionReader(endpointName, endpointAddress, part, start, end, clock)
  }
}
218+
219+
/**
 * An input partition reader for LowLatency memory stream.
 *
 * Polls the driver endpoint for new records in [startOffset, endOffset) of `partition`.
 */
class LowLatencyMemoryStreamPartitionReader(
    driverEndpointName: String,
    driverEndpointAddress: RpcAddress,
    partition: Int,
    startOffset: Int,
    endOffset: Int,
    clock: Clock)
  extends SupportsRealTimeRead[InternalRow] {
  // Avoid tracking the ref, given that we create a new one for each partition reader
  // because a new driver endpoint is created for each LowLatencyMemoryStream. If we track the ref,
  // we can end up with a lot of refs (1000s) if a test suite has so many test cases and can lead to
  // issues with the tracking array. Causing the test suite to be flaky.
  private val endpoint = RpcUtils.makeDriverRef(
    driverEndpointName,
    driverEndpointAddress.host,
    driverEndpointAddress.port,
    SparkEnv.get.rpcEnv
  )

  // Next record offset to request from the driver.
  private var currentOffset = startOffset
  // Most recently fetched record; get() reads this after a successful next*/nextWithTimeout.
  private var current: Option[InternalRow] = None

  // Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
  // we have to do a bit of error prone work to get it into every thread used by LowLatency
  // processing. We hope that some unit test will end up instantiating a LowLatency memory stream
  // in such cases.
  if (TaskContext.get() == null) {
    throw new IllegalStateException("Task context was not set!")
  }

  /**
   * Blocks until a record is available or `timeout` milliseconds (as measured by `clock`)
   * have elapsed, polling the driver every 10ms. Returns a RecordStatus whose flag says
   * whether a record was obtained; on success the record is readable via get().
   * NOTE(review): elapsed time is measured with the injected clock while the sleep is real
   * Thread.sleep — with a manual clock the loop only exits if a test advances the clock or
   * data arrives; confirm this is the intended test-control behavior.
   */
  override def nextWithTimeout(timeout: java.lang.Long): RecordStatus = {
    val startReadTime = clock.nanoTime()
    var elapsedTimeMs = 0L
    current = getRecord
    while (current.isEmpty) {
      val POLL_TIME = 10L
      if (elapsedTimeMs >= timeout) {
        return RecordStatus.newStatusWithoutArrivalTime(false)
      }
      Thread.sleep(POLL_TIME)
      current = getRecord
      // Convert the nanosecond delta to milliseconds.
      elapsedTimeMs = (clock.nanoTime() - startReadTime) / 1000 / 1000
    }
    currentOffset += 1
    RecordStatus.newStatusWithoutArrivalTime(true)
  }

  /** Non-blocking variant: fetches at most one record and reports whether it was available. */
  override def next(): Boolean = {
    current = getRecord
    if (current.isDefined) {
      currentOffset += 1
      true
    } else {
      false
    }
  }

  // Assumes a successful next()/nextWithTimeout() call has populated `current`.
  override def get(): InternalRow = current.get

  // No resources to release; the driver ref is intentionally untracked (see above).
  override def close(): Unit = {}

  // Reports how far this reader has progressed, for offset merging on the driver.
  override def getOffset: ContinuousRecordPartitionOffset =
    ContinuousRecordPartitionOffset(partition, currentOffset)

  // Fetches the record at currentOffset from the driver; None when the bounded range is
  // exhausted or the driver has no record at that offset yet.
  private def getRecord: Option[InternalRow] = {
    if (currentOffset >= endOffset) {
      return None
    }
    endpoint.askSync[Option[InternalRow]](
      GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset))
    )
  }
}
296+
297+
/**
 * Offset for [[LowLatencyMemoryStream]]: maps each partition id to the number of records
 * buffered for that partition (i.e. the exclusive upper bound of readable record indices).
 */
case class LowLatencyMemoryStreamOffset(partitionNums: Map[Int, Int]) extends Offset {
  // json4s formats for serializing the partition map; no type hints needed for Map[Int, Int].
  private implicit val formats: Formats = Serialization.formats(NoTypeHints)
  override def json(): String = Serialization.write(partitionNums)
}

0 commit comments

Comments
 (0)