
Commit 9e33954

ueshin authored and cloud-fan committed
[SPARK-21745][SQL] Refactor ColumnVector hierarchy to make ColumnVector read-only and to introduce WritableColumnVector.
## What changes were proposed in this pull request?

This is a refactoring of the `ColumnVector` hierarchy and related classes.

1. make `ColumnVector` read-only
2. introduce `WritableColumnVector` with write interface
3. remove `ReadOnlyColumnVector`

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <[email protected]>

Closes apache#18958 from ueshin/issues/SPARK-21745.
1 parent dc5d34d commit 9e33954

File tree

18 files changed: +1078 −1099 lines changed

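The split described in the commit message can be summarized with a short sketch. This is illustrative only, drawn from the description above rather than from the patch itself; only the four class names come from the commit, and the method subsets shown are examples.

```java
// Sketch of the refactored hierarchy (illustrative; the real classes live in
// org.apache.spark.sql.execution.vectorized and expose many more methods).
abstract class ColumnVector {                       // read-only view handed to consumers
  abstract int getInt(int rowId);                   // typed getters only
  abstract boolean isNullAt(int rowId);
}

abstract class WritableColumnVector extends ColumnVector {  // adds the write interface
  abstract void putInt(int rowId, int value);
  abstract void putNull(int rowId);
  abstract void setIsConstant();                    // helper used by the readers below
}

// OnHeapColumnVector (Java arrays) and OffHeapColumnVector (off-heap memory) are the
// concrete WritableColumnVector implementations allocated by the vectorized readers.
```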

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 13 additions & 15 deletions
@@ -464,14 +464,13 @@ class CodegenContext
   /**
    * Returns the specialized code to set a given value in a column vector for a given `DataType`.
    */
-  def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
-      value: String): String = {
+  def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
     val jt = javaType(dataType)
     dataType match {
       case _ if isPrimitiveType(jt) =>
-        s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
-      case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
-      case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
+        s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
+      case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
+      case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
       case _ =>
         throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
     }
@@ -482,37 +481,36 @@ class CodegenContext
    * that could potentially be nullable.
    */
   def updateColumn(
-      batch: String,
-      row: String,
+      vector: String,
+      rowId: String,
       dataType: DataType,
-      ordinal: Int,
       ev: ExprCode,
       nullable: Boolean): String = {
     if (nullable) {
       s"""
         if (!${ev.isNull}) {
-          ${setValue(batch, row, dataType, ordinal, ev.value)}
+          ${setValue(vector, rowId, dataType, ev.value)}
         } else {
-          $batch.column($ordinal).putNull($row);
+          $vector.putNull($rowId);
         }
       """
     } else {
-      s"""${setValue(batch, row, dataType, ordinal, ev.value)};"""
+      s"""${setValue(vector, rowId, dataType, ev.value)};"""
     }
   }

   /**
    * Returns the specialized code to access a value from a column vector for a given `DataType`.
    */
-  def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = {
+  def getValue(vector: String, rowId: String, dataType: DataType): String = {
     val jt = javaType(dataType)
     dataType match {
       case _ if isPrimitiveType(jt) =>
-        s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)"
+        s"$vector.get${primitiveTypeName(jt)}($rowId)"
       case t: DecimalType =>
-        s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})"
+        s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
       case StringType =>
-        s"$batch.column($ordinal).getUTF8String($row)"
+        s"$vector.getUTF8String($rowId)"
       case _ =>
         throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
     }
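To make the codegen change concrete, here is an illustrative, made-up example of the Java fragments these helpers emit; the identifiers batch, colVar, rowIdx, value and the ordinal 2 are hypothetical, but the before/after shapes follow the string templates in the diff above.

```java
// Before: every emitted access re-resolved the column from the batch by ordinal.
//   batch.column(2).putInt(rowIdx, value);
//   batch.column(2).getInt(rowIdx)
//
// After: the helpers receive the vector expression itself, so the ordinal parameter
// is gone and the call site decides how the vector variable (here colVar) is obtained.
//   colVar.putInt(rowIdx, value);
//   colVar.getInt(rowIdx)
```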

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java

Lines changed: 19 additions & 12 deletions
@@ -30,6 +30,7 @@

 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DecimalType;

@@ -135,9 +136,9 @@ private boolean next() throws IOException {
   /**
    * Reads `total` values from this columnReader into column.
    */
-  void readBatch(int total, ColumnVector column) throws IOException {
+  void readBatch(int total, WritableColumnVector column) throws IOException {
     int rowId = 0;
-    ColumnVector dictionaryIds = null;
+    WritableColumnVector dictionaryIds = null;
     if (dictionary != null) {
       // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
       // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
@@ -219,8 +220,11 @@ void readBatch(int total, ColumnVector column) throws IOException {
   /**
    * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
    */
-  private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
-                                   ColumnVector dictionaryIds) {
+  private void decodeDictionaryIds(
+      int rowId,
+      int num,
+      WritableColumnVector column,
+      ColumnVector dictionaryIds) {
     switch (descriptor.getType()) {
       case INT32:
         if (column.dataType() == DataTypes.IntegerType ||
@@ -346,13 +350,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
    * is guaranteed that num is smaller than the number of values left in the current page.
    */

-  private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readBooleanBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     assert(column.dataType() == DataTypes.BooleanType);
     defColumn.readBooleans(
         num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
   }

-  private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
     if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
@@ -370,7 +374,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce
     }
   }

-  private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     if (column.dataType() == DataTypes.LongType ||
         DecimalType.is64BitDecimalType(column.dataType())) {
@@ -389,7 +393,7 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc
     }
   }

-  private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     // TODO: support implicit cast to double?
     if (column.dataType() == DataTypes.FloatType) {
@@ -400,7 +404,7 @@ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOEx
     }
   }

-  private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
     if (column.dataType() == DataTypes.DoubleType) {
@@ -411,7 +415,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE
     }
   }

-  private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+  private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException {
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
@@ -432,8 +436,11 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE
     }
   }

-  private void readFixedLenByteArrayBatch(int rowId, int num,
-      ColumnVector column, int arrayLen) throws IOException {
+  private void readFixedLenByteArrayBatch(
+      int rowId,
+      int num,
+      WritableColumnVector column,
+      int arrayLen) throws IOException {
     VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
     // This is where we implement support for the valid type conversions.
     // TODO: implement remaining type conversions
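As a minimal sketch of why decodeDictionaryIds can now mix the two types, the loop below (hypothetical, not the body of the real method) reads ids through the read-only ColumnVector interface and writes decoded values through WritableColumnVector; getInt stands in for however the real code fetches an id.

```java
// Hypothetical INT32 decode loop, for illustration of the read/write split only.
static void decodeIntDictionarySketch(
    int rowId,
    int num,
    WritableColumnVector column,          // decoded values are written here
    ColumnVector dictionaryIds,           // ids are only read, so read-only access suffices
    org.apache.parquet.column.Dictionary dictionary) {
  for (int i = rowId; i < rowId + num; ++i) {
    if (!column.isNullAt(i)) {
      column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
    }
  }
}
```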

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java

Lines changed: 17 additions & 6 deletions
@@ -31,6 +31,9 @@
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
 import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;

@@ -90,6 +93,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
   */
  private ColumnarBatch columnarBatch;

+  private WritableColumnVector[] columnVectors;
+
  /**
   * If true, this class returns batches instead of rows.
   */
@@ -172,20 +177,26 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns,
       }
     }

-    columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
+    int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
+    if (memMode == MemoryMode.OFF_HEAP) {
+      columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema);
+    } else {
+      columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
+    }
+    columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity);
     if (partitionColumns != null) {
       int partitionIdx = sparkSchema.fields().length;
       for (int i = 0; i < partitionColumns.fields().length; i++) {
-        ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i);
-        columnarBatch.column(i + partitionIdx).setIsConstant();
+        ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
+        columnVectors[i + partitionIdx].setIsConstant();
       }
     }

     // Initialize missing columns with nulls.
     for (int i = 0; i < missingColumns.length; i++) {
       if (missingColumns[i]) {
-        columnarBatch.column(i).putNulls(0, columnarBatch.capacity());
-        columnarBatch.column(i).setIsConstant();
+        columnVectors[i].putNulls(0, columnarBatch.capacity());
+        columnVectors[i].setIsConstant();
       }
     }
   }
@@ -226,7 +237,7 @@ public boolean nextBatch() throws IOException {
     int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
     for (int i = 0; i < columnReaders.length; ++i) {
       if (columnReaders[i] == null) continue;
-      columnReaders[i].readBatch(num, columnarBatch.column(i));
+      columnReaders[i].readBatch(num, columnVectors[i]);
     }
     rowsReturned += num;
     columnarBatch.setNumRows(num);
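A hedged consumer-side sketch (not part of this commit): code that only reads the resulting batch can keep working against the read-only ColumnVector returned by ColumnarBatch.column, for example a hypothetical helper like the one below, which never needs the put* interface.

```java
// Hypothetical read-only consumer of a ColumnarBatch produced by the reader above.
static long sumFirstIntColumn(ColumnarBatch batch) {
  ColumnVector col0 = batch.column(0);        // read-only view
  long sum = 0;
  for (int r = 0; r < batch.numRows(); r++) {
    if (!col0.isNullAt(r)) {
      sum += col0.getInt(r);                  // typed getter; no write API in scope
    }
  }
  return sum;
}
```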

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java

Lines changed: 8 additions & 8 deletions
@@ -20,7 +20,7 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;

-import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.unsafe.Platform;

 import org.apache.parquet.column.values.ValuesReader;
@@ -56,39 +56,39 @@ public void skip() {
   }

   @Override
-  public final void readBooleans(int total, ColumnVector c, int rowId) {
+  public final void readBooleans(int total, WritableColumnVector c, int rowId) {
     // TODO: properly vectorize this
     for (int i = 0; i < total; i++) {
       c.putBoolean(rowId + i, readBoolean());
     }
   }

   @Override
-  public final void readIntegers(int total, ColumnVector c, int rowId) {
+  public final void readIntegers(int total, WritableColumnVector c, int rowId) {
     c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 4 * total;
   }

   @Override
-  public final void readLongs(int total, ColumnVector c, int rowId) {
+  public final void readLongs(int total, WritableColumnVector c, int rowId) {
     c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 8 * total;
   }

   @Override
-  public final void readFloats(int total, ColumnVector c, int rowId) {
+  public final void readFloats(int total, WritableColumnVector c, int rowId) {
     c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 4 * total;
   }

   @Override
-  public final void readDoubles(int total, ColumnVector c, int rowId) {
+  public final void readDoubles(int total, WritableColumnVector c, int rowId) {
     c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 8 * total;
   }

   @Override
-  public final void readBytes(int total, ColumnVector c, int rowId) {
+  public final void readBytes(int total, WritableColumnVector c, int rowId) {
     for (int i = 0; i < total; i++) {
       // Bytes are stored as a 4-byte little endian int. Just read the first byte.
       // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
@@ -159,7 +159,7 @@ public final double readDouble() {
   }

   @Override
-  public final void readBinary(int total, ColumnVector v, int rowId) {
+  public final void readBinary(int total, WritableColumnVector v, int rowId) {
     for (int i = 0; i < total; i++) {
       int len = readInteger();
       int start = offset;
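For readers unfamiliar with the bulk put methods used above, here is a hedged conceptual sketch of what a call such as putIntsLittleEndian(rowId, total, buffer, srcOffset) amounts to; this is an assumption for illustration, and the real OnHeapColumnVector/OffHeapColumnVector implementations copy the bytes more directly.

```java
// Conceptual sketch only: decode `total` little-endian ints from the byte array and
// store them at consecutive row ids in the writable vector. Uses java.nio.ByteBuffer
// and ByteOrder, which this file already imports.
static void putIntsLittleEndianSketch(
    WritableColumnVector c, int rowId, int total, byte[] buffer, int srcOffset) {
  ByteBuffer bb = ByteBuffer.wrap(buffer, srcOffset, 4 * total).order(ByteOrder.LITTLE_ENDIAN);
  for (int i = 0; i < total; i++) {
    c.putInt(rowId + i, bb.getInt());
  }
}
```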
