
[SPARK-52495][SQL] Allow including partition columns in the single variant column #51206

Status: Open · wants to merge 11 commits into base: master
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -6797,6 +6797,13 @@
],
"sqlState" : "22023"
},
"VARIANT_DATA_SCHEMA_CONFLICT_WITH_PARTITION_SCHEMA" : {
"message" : [
"The data schema <dataSchema> conflicts with the partition schema <partitionSchema> in the variant column.",
"Please enable \"spark.sql.variant.allowDuplicateKeys\" to resolve the conflicts."
],
"sqlState" : "0A000"
},
"VIEW_ALREADY_EXISTS" : {
"message" : [
"Cannot create view <relationName> because it already exists.",
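A minimal sketch of how this conflict can arise (illustrative paths and values; the check's call site is not shown in this diff): a dataset partitioned by `date` whose JSON payload also carries a `date` key, read in single-variant-column mode with duplicate keys disallowed.

// Hedged sketch: /tmp/events/date=2024-06-01/part-0.json contains
// {"id": 1, "date": "1999-01-01"}. With the partition-column merge enabled and
// spark.sql.variant.allowDuplicateKeys left at false, the duplicate `date`
// field triggers VARIANT_DATA_SCHEMA_CONFLICT_WITH_PARTITION_SCHEMA.
spark.conf.set("spark.sql.variant.allowDuplicateKeys", "false")
spark.read
  .option("singleVariantColumn", "var")
  .json("/tmp/events")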
@@ -541,6 +541,108 @@ private void buildJson(JsonParser parser) throws IOException {
}
}

/**
* Parse JSON from the given parser as a Variant object and merge partition columns
* into the result. The parser must be positioned at the start of a JSON object.
*
* @param parser The JSON parser to parse
* @param allowDuplicateKeys Whether to allow duplicate keys
* @param partitionColumns Array of partition column names, can be null
* @param partitionValues Array of partition values corresponding to
* partitionColumns, can be null
* @return A Variant containing the merged JSON and partition data
* @throws IOException if any JSON parsing error happens
*/
public static Variant parseJsonWithPartitionValues(
JsonParser parser,
boolean allowDuplicateKeys,
String[] partitionColumns,
Object[] partitionValues) throws IOException {
VariantBuilder builder = new VariantBuilder(allowDuplicateKeys);

// Start building an object
ArrayList<FieldEntry> fields = new ArrayList<>();
int start = builder.getWritePos();

// First, parse the JSON content and collect its fields
while (parser.nextToken() != JsonToken.END_OBJECT) {
String key = parser.currentName();
parser.nextToken();
int id = builder.addKey(key);
fields.add(new FieldEntry(key, id, builder.getWritePos() - start));
builder.buildJson(parser);
}

// Then add partition columns if they exist
if (partitionColumns != null && partitionValues != null) {
appendPartitionValues(builder, start, fields, partitionColumns, partitionValues);
}

// Finish writing the object with all fields
builder.finishWritingObject(start, fields);
return builder.result();
}

/**
* Append partition values to the top-level variant object.
* @param builder The variant builder
* @param start The starting position of the top-level variant object in the builder
* @param fields The current data fields of the top-level variant object
* @param partitionColumns Array of partition column names (must not be null here)
* @param partitionValues Array of partition values corresponding to the partitionColumns
*/
public static void appendPartitionValues(
VariantBuilder builder,
int start,
ArrayList<FieldEntry> fields,
String[] partitionColumns,
Object[] partitionValues) {
for (int i = 0; i < partitionColumns.length; i++) {
if (partitionValues[i] != null) {
int id = builder.addKey(partitionColumns[i]);
fields.add(new FieldEntry(partitionColumns[i], id, builder.getWritePos() - start));
builder.appendPartitionValue(partitionValues[i]);
}
}
}

/**
* Append a partition value to the variant builder. This method handles
* different data types that partition values might have.
*/
public void appendPartitionValue(Object value) {
if (value == null) {
appendNull();
} else if (value instanceof String) {
appendString((String) value);
} else if (value instanceof Integer) {
appendLong((Integer) value);
} else if (value instanceof Long) {
appendLong((Long) value);
} else if (value instanceof Double) {
appendDouble((Double) value);
} else if (value instanceof Float) {
appendFloat((Float) value);
} else if (value instanceof Boolean) {
appendBoolean((Boolean) value);
} else if (value instanceof java.math.BigDecimal) {
appendDecimal((java.math.BigDecimal) value);
} else if (value instanceof java.sql.Date) {
// Convert java.sql.Date to days since epoch; floorDiv keeps pre-epoch dates correct
long millis = ((java.sql.Date) value).getTime();
int days = (int) Math.floorDiv(millis, 24 * 60 * 60 * 1000L);
appendDate(days);
} else if (value instanceof java.sql.Timestamp) {
// Convert java.sql.Timestamp to microseconds since epoch
long millis = ((java.sql.Timestamp) value).getTime();
long micros = millis * 1000L;
appendTimestamp(micros);
} else {
// For any other type, convert to string
appendString(value.toString());
}
}

// Choose the smallest unsigned integer type that can store `value`. It must be within
// `[0, SIZE_LIMIT]`.
private int getIntegerSize(int value) {
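A short sketch of calling the new entry point directly (illustrative, assuming the builder lives in org.apache.spark.types.variant and that the caller positions the Jackson parser on START_OBJECT, as the merge loop above expects):

import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.types.variant.VariantBuilder

val parser = new JsonFactory().createParser("""{"a": 1}""")
parser.nextToken() // move to START_OBJECT before handing off to the builder
val v = VariantBuilder.parseJsonWithPartitionValues(
  parser,
  false,                                 // allowDuplicateKeys
  Array("year"),                         // partition column names
  Array[AnyRef](Integer.valueOf(2024)))  // partition values, appended after the data fields
// v holds the equivalent of {"a": 1, "year": 2024}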
@@ -53,7 +53,9 @@ class UnivocityParser(
dataSchema: StructType,
requiredSchema: StructType,
val options: CSVOptions,
filters: Seq[Filter]) extends Logging {
filters: Seq[Filter],
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {
require(requiredSchema.toSet.subsetOf(dataSchema.toSet),
s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " +
s"dataSchema (${dataSchema.catalogString}).")
@@ -344,6 +346,8 @@
(tokens: Array[String], index: Int) => tokens(tokenIndexArr(index))
}

private val variantAllowDuplicateKeys = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)

/**
* The entire line of CSV data is collected into a single variant object. When `headerColumnNames`
* is defined, the field names will be extracted from it. Otherwise, the field names will have a
@@ -360,7 +364,7 @@
val extra = numFields - singleVariantFieldConverters.length
singleVariantFieldConverters.appendAll(Array.fill(extra)(new VariantValueConverter))
}
val builder = new VariantBuilder(false)
val builder = new VariantBuilder(variantAllowDuplicateKeys)
val start = builder.getWritePos
val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](numFields)
for (i <- 0 until numFields) {
Expand All @@ -369,6 +373,26 @@ class UnivocityParser(
fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
singleVariantFieldConverters(i).convertInput(builder, tokens(i))
}

// Add the partition columns to the variant object
if (partitionSchema.nonEmpty && SQLConf.get.includePartitionColumnsInSingleVariantColumn) {
val partitionColumnNames = partitionSchema.fields.map(_.name)
val partitionColumnValues = (0 until partitionValues.numFields).map { i =>
if (!partitionValues.isNullAt(i)) {
partitionValues.get(i, partitionSchema.fields(i).dataType)
} else {
null
}
}.toArray
VariantBuilder.appendPartitionValues(
builder,
start,
fields,
partitionColumnNames,
partitionColumnValues
)
}

builder.finishWritingObject(start, fields)
val v = builder.result()
row(0) = new VariantVal(v.getValue, v.getMetadata)
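End to end, the CSV path would behave roughly as follows. A hedged sketch: <merge-conf-key> is a placeholder for whichever SQL conf backs the includePartitionColumnsInSingleVariantColumn accessor used above; its key string does not appear in this diff.

// Illustrative: /tmp/metrics/region=eu/part-0.csv with a header line "id,score"
spark.conf.set("<merge-conf-key>", "true")
val df = spark.read
  .option("header", "true")
  .option("singleVariantColumn", "var")
  .csv("/tmp/metrics")
// Each row's variant also carries the directory-derived "region" field
df.selectExpr("variant_get(var, '$.region', 'string')").show()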
@@ -49,7 +49,9 @@ class JacksonParser(
schema: DataType,
val options: JSONOptions,
allowArrayAsStructs: Boolean,
filters: Seq[Filter] = Seq.empty) extends Logging {
filters: Seq[Filter] = Seq.empty,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {

import JacksonUtils._
import com.fasterxml.jackson.core.JsonToken._
@@ -112,7 +114,11 @@
Some(InternalRow(parseVariant(parser)))
}
case _: StructType if options.singleVariantColumn.isDefined => (parser: JsonParser) => {
Some(InternalRow(parseVariant(parser)))
if (partitionSchema.nonEmpty && SQLConf.get.includePartitionColumnsInSingleVariantColumn) {
Some(InternalRow(parseVariantWithPartitionValues(parser)))
} else {
Some(InternalRow(parseVariant(parser)))
}
}
case st: StructType => makeStructRootConverter(st)
case mt: MapType => makeMapRootConverter(mt)
@@ -138,6 +144,26 @@
}
}

// When singleVariantColumn is defined and we have partition columns, build the variant
// value with the partition fields
protected final def parseVariantWithPartitionValues(parser: JsonParser): VariantVal = {
val partitionColumnNames = partitionSchema.fields.map(_.name)
val partitionColumnValues = (0 until partitionValues.numFields).map { i =>
if (!partitionValues.isNullAt(i)) {
partitionValues.get(i, partitionSchema.fields(i).dataType)
} else {
null
}
}.toArray
val v = VariantBuilder.parseJsonWithPartitionValues(
parser,
variantAllowDuplicateKeys,
partitionColumnNames,
partitionColumnValues
)
new VariantVal(v.getValue, v.getMetadata)
}

private def makeStructRootConverter(st: StructType): JsonParser => Iterable[InternalRow] = {
val elementConverter = makeConverter(st)
val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
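Note the ordering consequence, inferred from the append order above and so best treated as an assumption: partition fields are written after the data fields, and with spark.sql.variant.allowDuplicateKeys enabled the later entry wins, so a colliding partition value should shadow the data value instead of raising the conflict error.

// Assumed behavior with duplicate keys allowed:
// /tmp/events/date=2024-06-01/part-0.json contains {"id": 1, "date": "1999-01-01"}
spark.conf.set("spark.sql.variant.allowDuplicateKeys", "true")
val df = spark.read.option("singleVariantColumn", "var").json("/tmp/events")
df.selectExpr("variant_get(var, '$.date', 'string')").show()
// Expected: 2024-06-01 (the partition value, appended last)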
@@ -54,7 +54,9 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

class StaxXmlParser(
schema: StructType,
val options: XmlOptions) extends Logging {
val options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {

private lazy val timestampFormatter = TimestampFormatter(
options.timestampFormatInRead,
@@ -146,7 +148,7 @@
options.singleVariantColumn match {
case Some(_) =>
// If the singleVariantColumn is specified, parse the entire xml string as a Variant
val v = StaxXmlParser.parseVariant(xml, options)
val v = StaxXmlParser.parseVariant(xml, options, partitionSchema, partitionValues)
Some(InternalRow(v))
case _ =>
// Otherwise, parse the xml string as Structs
@@ -928,10 +930,14 @@ object StaxXmlParser {
/**
* Parse the input XML string as a Variant value
*/
def parseVariant(xml: String, options: XmlOptions): VariantVal = {
def parseVariant(
xml: String,
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): VariantVal = {
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
val v = convertVariant(parser, rootAttributes, options)
val v = convertVariant(parser, rootAttributes, options, partitionSchema, partitionValues)
parser.close()
v
}
@@ -944,22 +950,30 @@
* @param parser The XML event stream reader positioned after the start element
* @param attributes The attributes of the current XML element to be included in the Variant
* @param options Configuration options that control how XML is parsed into Variants
* @param partitionSchema The schema of the partition columns, if any
* @param partitionValues The values of the partition columns, if any
* @return A Variant representing the XML element with its attributes and child content
*/
def convertVariant(
parser: XMLEventReader,
attributes: Array[Attribute],
options: XmlOptions): VariantVal = {
val v = convertVariantInternal(parser, attributes, options)
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): VariantVal = {
val v = convertVariantInternal(parser, attributes, options, partitionSchema, partitionValues)
new VariantVal(v.getValue, v.getMetadata)
}

private def convertVariantInternal(
parser: XMLEventReader,
attributes: Array[Attribute],
options: XmlOptions): Variant = {
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): Variant = {
val variantAllowDuplicateKeys = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)

// The variant builder for the root startElement
val rootBuilder = new VariantBuilder(false)
val rootBuilder = new VariantBuilder(variantAllowDuplicateKeys)
val start = rootBuilder.getWritePos

// Map to store the variant values of all child fields
Expand All @@ -977,7 +991,7 @@ object StaxXmlParser {
// Handle attributes first
StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options).foreach {
case (f, v) =>
val builder = new VariantBuilder(false)
val builder = new VariantBuilder(variantAllowDuplicateKeys)
appendXMLCharacterToVariant(builder, v, options)
val variants = fieldToVariants.getOrElseUpdate(f, new java.util.ArrayList[Variant]())
variants.add(builder.result())
Expand All @@ -997,7 +1011,7 @@ object StaxXmlParser {
case c: Characters if !c.isWhiteSpace =>
// Treat the character as a value tag field, where we use the [[XMLOptions.valueTag]] as
// the field key
val builder = new VariantBuilder(false)
val builder = new VariantBuilder(variantAllowDuplicateKeys)
appendXMLCharacterToVariant(builder, c.getData, options)
val variants = fieldToVariants.getOrElseUpdate(
options.valueTag,
Expand All @@ -1006,6 +1020,27 @@ object StaxXmlParser {
variants.add(builder.result())

case _: EndElement =>
// In the end, add partition values if they exist
if (partitionSchema.nonEmpty &&
SQLConf.get.includePartitionColumnsInSingleVariantColumn) {
partitionSchema.fields.zipWithIndex.foreach {
case (field, i) =>
val value = partitionValues.get(i, field.dataType)
if (value != null) {
val builder = new VariantBuilder(variantAllowDuplicateKeys)
appendXMLCharacterToVariant(builder, value.toString, options)
val variants = fieldToVariants.getOrElseUpdate(
field.name,
new java.util.ArrayList[Variant]()
)
// If the partition schema overlaps with the data schema, we **OVERRIDE** the
// data with the partition values.
Comment on lines +1036 to +1037
Contributor: What is the behavior during overlap without singleVariantColumn? Can the variant field be converted to an array in this case?
Contributor Author: As mentioned above, the partition values will overwrite the data values.
variants.clear()
variants.add(builder.result())
}
}
}

if (fieldToVariants.nonEmpty) {
val onlyValueTagField = fieldToVariants.keySet.forall(_ == options.valueTag)
if (onlyValueTagField) {
@@ -1036,6 +1071,8 @@
*
* @param builder The variant builder to write to
* @param fieldToVariants A map of field names to their corresponding variant values of the object
* The map is sorted by field names, and the ordering honors the case-sensitivity setting.
*/
private def writeVariantObject(
builder: VariantBuilder,
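The override semantics from the review thread above can be summarized with a hedged sketch (paths, tags, and values illustrative):

// /tmp/xml/region=eu/part-0.xml:
//   <row><id>1</id><region>us</region></row>
val df = spark.read
  .format("xml")
  .option("rowTag", "row")
  .option("singleVariantColumn", "var")
  .load("/tmp/xml")
df.selectExpr("variant_get(var, '$.region', 'string')").show()
// Expected: eu (the partition value replaces the same-named data field via the
// variants.clear() call above, unlike the CSV/JSON append path)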
@@ -213,7 +213,7 @@ object XmlOptions extends DataSourceOptions {
val INDENT = newOption("indent")
val PREFERS_DECIMAL = newOption("prefersDecimal")
val VALIDATE_NAME = newOption("validateName")
val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn")
val SINGLE_VARIANT_COLUMN = newOption(DataSourceOptions.SINGLE_VARIANT_COLUMN)
// Options with alternative
val ENCODING = "encoding"
val CHARSET = "charset"
@@ -3071,4 +3071,16 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionErrors {
)
)
}

def variantDataSchemaConflictWithPartitionSchema(
dataSchema: StructType,
partitionSchema: StructType): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "VARIANT_DATA_SCHEMA_CONFLICT_WITH_PARTITION_SCHEMA",
messageParameters = Map(
"dataSchema" -> toSQLType(dataSchema),
"partitionSchema" -> toSQLType(partitionSchema)
)
)
}
}
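For reference, a sketch of how the new helper surfaces to users (hypothetical invocation; the real call site is not included in this diff):

// Hedged example of constructing the error; the schemas are illustrative.
import org.apache.spark.sql.types.StructType
throw QueryExecutionErrors.variantDataSchemaConflictWithPartitionSchema(
  dataSchema = StructType.fromDDL("id INT, date STRING"),
  partitionSchema = StructType.fromDDL("date STRING"))
// Message (per error-conditions.json): The data schema "STRUCT<id: INT, date: STRING>"
// conflicts with the partition schema "STRUCT<date: STRING>" in the variant column. ...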