Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,5 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
val prunedSchema = SchemaPruning.pruneSchema(schema, rootFields)
assert(prunedSchema.head.metadata.getString("foo") == "bar")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ trait FileFormat {
*
* NOTE: Extractors are lazy, invoked only if the query actually selects their column at runtime.
*
* Return types: extractors may return either a raw value (which is converted to the column's
* catalyst form via [[Literal.create]]) or an already-built [[Literal]] (whose `.value` is
* used directly). For complex types ([[ArrayType]] / [[MapType]] / [[StructType]]), return the
* value in catalyst form ([[ArrayData]] / [[MapData]] / [[InternalRow]]).
*
* See also [[FileFormat.getFileConstantMetadataColumnValue]].
*/
def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
Expand Down Expand Up @@ -273,6 +278,9 @@ object FileFormat {
FileSourceConstantMetadataStructField(FILE_BLOCK_LENGTH, LongType, nullable = false),
FileSourceConstantMetadataStructField(FILE_MODIFICATION_TIME, TimestampType, nullable = false))

/** Lookup table from the name of each entry in [[BASE_METADATA_FIELDS]] to its data type. */
private val BASE_METADATA_NAME_TO_TYPE: Map[String, DataType] =
  BASE_METADATA_FIELDS.map { field => (field.name, field.dataType) }.toMap

/**
* All [[BASE_METADATA_FIELDS]] require custom extractors because they are derived directly from
* fields of the [[PartitionedFile]], and do not have entries in the file's metadata map.
Expand All @@ -299,16 +307,26 @@ object FileFormat {
* If an extractor is available, apply it. Otherwise, look up the column's name in the file's
* column value map and return the result (or null, if not found).
*
* Raw values (including null) are automatically converted to literals as a courtesy.
* Raw values (including null) are converted via [[Literal.create]], which accepts catalyst-form
* values directly. This lets a complex constant metadata column return an [[ArrayData]] /
* [[MapData]] / [[InternalRow]] whose element types only the caller knows. If the extractor
* returns an already-built [[Literal]] (allowed by the extractor contract), its value is
* unwrapped before delegating to [[Literal.create]] so the dataType validation in the
* case-class constructor is checked against the raw value.
*/
def getFileConstantMetadataColumnValue(
name: String,
file: PartitionedFile,
metadataExtractors: Map[String, PartitionedFile => Any]): Literal = {
metadataExtractors: Map[String, PartitionedFile => Any],
dataType: DataType): Literal = {
val extractor = metadataExtractors.getOrElse(name,
{ pf: PartitionedFile => pf.otherConstantMetadataColumnValues.get(name).orNull }
)
Literal(extractor.apply(file))
val rawValue = extractor.apply(file) match {
case lit: Literal => lit.value
case other => other
}
Literal.create(rawValue, dataType)
}

// create an internal row given required metadata fields and file information
Expand All @@ -334,17 +352,22 @@ object FileFormat {
modificationTime = fileModificationTime,
fileSize = fileSize,
otherConstantMetadataColumnValues = Map.empty)
updateMetadataInternalRow(new GenericInternalRow(fieldNames.length), fieldNames, pf, extractors)
val fieldDataTypes = fieldNames.map(BASE_METADATA_NAME_TO_TYPE)
updateMetadataInternalRow(
new GenericInternalRow(fieldNames.length), fieldNames, pf, extractors, fieldDataTypes)
}

// update an internal row given required metadata fields and file information
def updateMetadataInternalRow(
row: InternalRow,
fieldNames: Seq[String],
file: PartitionedFile,
metadataExtractors: Map[String, PartitionedFile => Any]): InternalRow = {
metadataExtractors: Map[String, PartitionedFile => Any],
fieldDataTypes: Seq[DataType]): InternalRow = {
require(fieldDataTypes.length == fieldNames.length,
s"fieldDataTypes length ${fieldDataTypes.length} != fieldNames length ${fieldNames.length}")
fieldNames.zipWithIndex.foreach { case (name, i) =>
getFileConstantMetadataColumnValue(name, file, metadataExtractors) match {
getFileConstantMetadataColumnValue(name, file, metadataExtractors, fieldDataTypes(i)) match {
case Literal(null, _) => row.setNullAt(i)
case literal => row.update(i, literal.value)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.FileFormat._
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
Expand Down Expand Up @@ -173,7 +172,8 @@ class FileScanRDD(
private def updateMetadataRow(): Unit =
if (metadataColumns.nonEmpty && currentFile != null) {
updateMetadataInternalRow(
metadataRow, metadataColumns.map(_.name), currentFile, metadataExtractors)
metadataRow, metadataColumns.map(_.name), currentFile, metadataExtractors,
metadataColumns.map(_.dataType))
}

/**
Expand All @@ -183,11 +183,11 @@ class FileScanRDD(
val tmpRow = new GenericInternalRow(1)
metadataColumns.map { attr =>
// Populate each metadata column by passing the resulting value through `tmpRow`.
getFileConstantMetadataColumnValue(attr.name, currentFile, metadataExtractors) match {
getFileConstantMetadataColumnValue(
attr.name, currentFile, metadataExtractors, attr.dataType) match {
case Literal(null, _) =>
tmpRow.setNullAt(0)
case literal =>
require(PhysicalDataType(attr.dataType) == PhysicalDataType(literal.dataType))
tmpRow.update(0, literal.value)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
val metadataSchema =
relation.output.collect { case FileSourceMetadataAttribute(attr) => attr }.toStructType
val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
pruneSchema(metadataSchema, requestedRootFields)
pruneMetadataSchema(metadataSchema, requestedRootFields)
} else {
metadataSchema
}
Expand Down Expand Up @@ -114,6 +114,44 @@ object SchemaPruning extends Rule[LogicalPlan] {
fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
fsRelation.fileFormat.isInstanceOf[OrcFileFormat])

/**
 * Prunes the file-source metadata schema, i.e. the `StructType` holding one field per
 * `FileSourceMetadataAttribute`.
 *
 * For every top-level field whose data type is a struct, only the direct children whose names
 * match (under `conf.resolver`) a child name of some requested root field with the same
 * top-level name are kept. The data types of the kept children are left exactly as declared:
 * the extractor produces a complete catalyst value, so removing nested fields would shift
 * positions inside that value. Top-level fields that are not structs are returned unchanged.
 *
 * Note that a struct-typed field with no matching requested root field keeps no children and
 * is therefore returned with an empty struct type.
 *
 * @param metadataSchema the metadata schema to prune
 * @param requestedRootFields the root fields requested by the query
 * @return a schema with the same top-level fields, in the same order, where unrequested
 *         direct children of struct-typed fields have been dropped
 */
private def pruneMetadataSchema(
    metadataSchema: StructType,
    requestedRootFields: Seq[RootField]): StructType = {
  val sameName = conf.resolver

  // Names of the direct children requested for `attr`, gathered from every root field whose
  // name resolves to `attr`. Root fields that are not struct-typed contribute nothing.
  def requestedChildNames(attr: StructField): Set[String] =
    requestedRootFields
      .filter(root => sameName(root.field.name, attr.name))
      .flatMap { root =>
        root.field.dataType match {
          case requested: StructType => requested.fieldNames.toSeq
          case _ => Seq.empty[String]
        }
      }
      .toSet

  // Drops the unrequested direct children of a struct-typed attribute; anything nested below
  // a kept child is intentionally left alone.
  def pruneAttribute(attr: StructField): StructField = attr.dataType match {
    case struct: StructType =>
      val wanted = requestedChildNames(attr)
      val survivors = struct.fields.filter(child => wanted.exists(sameName(_, child.name)))
      if (survivors.length != struct.fields.length) {
        attr.copy(dataType = StructType(survivors))
      } else {
        // Every child is requested, so hand back the original field untouched.
        attr
      }
    case _ => attr
  }

  StructType(metadataSchema.map(pruneAttribute))
}

/**
* Normalizes the names of the attribute references in the given expressions to reflect
* the names in the given logical relation. This makes it possible to compare attributes and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, FileSourceConstantMetadataStructField, FileSourceGeneratedMetadataStructField, Literal}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.functions.{col, lit, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

/** Verifies the ability for a FileFormat to define custom metadata types */
Expand Down Expand Up @@ -336,6 +338,59 @@ class FileSourceCustomMetadataStructSuite extends SharedSparkSession {
}
}

test("[SPARK-56931] complex constant metadata fields (array<struct>, struct) on row path") {
  withTempData("parquet", FILE_SCHEMA) { (_, f0, f1) =>
    // Catalyst-form builders: strings become UTF8String, a two-field struct becomes an
    // InternalRow, and an array of structs becomes a GenericArrayData of InternalRows.
    def stringPair(first: String, second: String): InternalRow =
      InternalRow(UTF8String.fromString(first), UTF8String.fromString(second))
    def permList(entries: (String, String)*): GenericArrayData =
      new GenericArrayData(
        entries.map { case (email, role) => stringPair(email, role) }.toArray[Any])

    // Two complex constant metadata columns: `perms` is array<struct<email, role>> and
    // `location` is struct<country, city>.
    val permEntryType = StructType(Seq(
      StructField("email", StringType),
      StructField("role", StringType)))
    val locationType = StructType(Seq(
      StructField("country", StringType),
      StructField("city", StringType)))
    val metadataFields = Seq(
      FileSourceConstantMetadataStructField(
        "perms", ArrayType(permEntryType, containsNull = true)),
      FileSourceConstantMetadataStructField("location", locationType))
    val format = new TestFileFormat(metadataFields)

    // Per-file metadata values, already in catalyst form.
    val filesWithMetadata = Seq(
      FileStatusWithMetadata(f0, Map(
        "perms" -> permList("a@x" -> "r", "b@x" -> "w"),
        "location" -> stringPair("US", "SFO"))),
      FileStatusWithMetadata(f1, Map(
        "perms" -> permList("c@x" -> "r", "d@x" -> "o"),
        "location" -> stringPair("CA", "YYZ"))))
    val df = createDF(format, filesWithMetadata)

    // Disable the vectorized Parquet reader so the row materialization path is used, which
    // runs updateMetadataInternalRow -> getFileConstantMetadataColumnValue -> Literal.create
    // end-to-end.
    withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
      // Only the second field of each struct (`role`, `city`) is selected. The extractor
      // still produces the full two-field rows, so if pruning narrowed the kept column's
      // inner schema to just the selected field, that field would be read at ordinal 0 and
      // return the first value (email / country) instead of the second.
      val projected = df.selectExpr(
        "fileNum",
        "_metadata.perms[1].role AS second_role",
        "_metadata.location.city AS city",
        "size(_metadata.perms) AS perms_count")

      // The expected answer lists each file's row twice.
      val expected = Seq((0, "w", "SFO"), (1, "o", "YYZ")).flatMap {
        case (fileNum, role, city) => Seq.fill(2)(Row(fileNum, role, city, 2))
      }
      checkAnswer(projected, expected)
    }
  }
}

test("generated columns and extractors take precedence over metadata map values") {
withTempData("parquet", FILE_SCHEMA) { (_, f0, f1) =>
import FileFormat.{FILE_NAME, FILE_SIZE}
Expand Down