Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_SQL.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run ",
"modification": 1
"modification": 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public abstract class RowCoderGenerator {

private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS";
private static final String SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding";

static class WithStackTrace<T> {
private final T value;
Expand Down Expand Up @@ -407,8 +408,13 @@ static void encodeDelegate(
checkState(value.getFieldCount() == value.getSchema().getFieldCount());
checkState(encodingPosToIndex.length == value.getFieldCount());

boolean staticEncoding =
value.getSchema().getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false);

// Encode the field count. This allows us to handle compatible schema changes.
VAR_INT_CODER.encode(value.getFieldCount(), outputStream);
if (!staticEncoding) {
VAR_INT_CODER.encode(value.getFieldCount(), outputStream);
}

if (hasNullableFields) {
// If the row has null fields, extract the values out once so that both scanNullFields and
Expand All @@ -420,7 +426,9 @@ static void encodeDelegate(
}

// Encode a bitmap for the null fields to save having to encode a bunch of nulls.
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
if (!staticEncoding) {
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
}
for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) {
@Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]];
if (fieldValue != null) {
Expand All @@ -430,7 +438,9 @@ static void encodeDelegate(
} else {
// Otherwise, we know all fields are non-null, so the null list is always empty.

NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream);
if (!staticEncoding) {
NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream);
}
for (int encodingPos = 0; encodingPos < value.getFieldCount(); ++encodingPos) {
@Nullable Object fieldValue = value.getValue(encodingPosToIndex[encodingPos]);
if (fieldValue != null) {
Expand Down Expand Up @@ -511,9 +521,15 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) {
static Row decodeDelegate(
Schema schema, Coder[] coders, int[] encodingPosToIndex, InputStream inputStream)
throws IOException {
int fieldCount = VAR_INT_CODER.decode(inputStream);

BitSet nullFields = NULL_LIST_CODER.decode(inputStream);
int fieldCount;
BitSet nullFields;
if (schema.getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false)) {
fieldCount = schema.getFieldCount();
nullFields = new BitSet();
} else {
fieldCount = VAR_INT_CODER.decode(inputStream);
nullFields = NULL_LIST_CODER.decode(inputStream);
}
Object[] fieldValues = new Object[coders.length];
for (int encodingPos = 0; encodingPos < fieldCount; ++encodingPos) {
// In the case of a schema change going backwards, fieldCount might be > coders.length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -626,4 +626,23 @@ public void testEncodingPositionRemoveFields() throws Exception {
Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray()));
assertEquals(expected, decoded);
}

/**
 * Encodes a two-field row whose schema has the {@code beam:option:row:static_encoding} option set
 * to {@code true}, asserts the exact encoded size, and then runs the coder through
 * {@code CoderProperties.coderDecodeEncodeEqual} for the same row.
 */
@Test
public void testStaticEncoding() throws Exception {
  // Build the options separately so the schema definition below stays easy to scan.
  Schema.Options staticEncodingOption =
      Schema.Options.builder()
          .setOption("beam:option:row:static_encoding", FieldType.BOOLEAN, true)
          .build();
  Schema staticSchema =
      Schema.builder()
          .addInt32Field("f_int32")
          .addStringField("f_string")
          .setOptions(staticEncodingOption)
          .build();
  Row input = Row.withSchema(staticSchema).addValues(42, "hello world!").build();

  ByteArrayOutputStream encoded = new ByteArrayOutputStream();
  RowCoder.of(staticSchema).encode(input, encoded);
  // NOTE(review): 14 presumably = 1 byte for the int value 42 + 1 length byte + 12 bytes for
  // "hello world!", i.e. no field-count or null-bitmap prefix; confirm against the field coders.
  assertEquals(14, encoded.size());

  CoderProperties.coderDecodeEncodeEqual(RowCoder.of(staticSchema), input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
BeamSqlPipelineOptions options =
pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class);

String builderString = builder.toBlock().toString();
CalcFn calcFn =
new CalcFn(
builder.toBlock().toString(),
builderString,
outputSchema,
options.getVerifyRowValues(),
getJarPaths(program),
Expand Down Expand Up @@ -502,120 +503,109 @@ FieldAccessDescriptor getFieldAccess() {
@Override
public Expression field(BlockBuilder list, int index, Type storageType) {
this.referencedColumns.add(index);
return getBeamField(list, index, input, inputSchema);
return getBeamField(list, index, input, inputSchema, true);
}

// Read field from Beam Row
private static Expression getBeamField(
BlockBuilder list, int index, Expression input, Schema schema) {
BlockBuilder list, int index, Expression input, Schema schema, boolean useByteString) {
if (index >= schema.getFieldCount() || index < 0) {
throw new IllegalArgumentException("Unable to find value #" + index);
}

final Expression expression = list.append(list.newName("current"), input);

final Field field = schema.getField(index);
final FieldType fieldType = field.getType();
final Expression fieldName = Expressions.constant(field.getName());
Expression value = getBeamField(list, expression, fieldName, fieldType);

return toCalciteValue(value, fieldType, useByteString);
}

private static Expression getBeamField(
BlockBuilder list, Expression expression, Expression fieldName, FieldType fieldType) {
final Expression value;
switch (fieldType.getTypeName()) {
case BYTE:
value = Expressions.call(expression, "getByte", fieldName);
break;
return Expressions.call(expression, "getByte", fieldName);
case INT16:
value = Expressions.call(expression, "getInt16", fieldName);
break;
return Expressions.call(expression, "getInt16", fieldName);
case INT32:
value = Expressions.call(expression, "getInt32", fieldName);
break;
return Expressions.call(expression, "getInt32", fieldName);
case INT64:
value = Expressions.call(expression, "getInt64", fieldName);
break;
return Expressions.call(expression, "getInt64", fieldName);
case DECIMAL:
value = Expressions.call(expression, "getDecimal", fieldName);
break;
return Expressions.call(expression, "getDecimal", fieldName);
case FLOAT:
value = Expressions.call(expression, "getFloat", fieldName);
break;
return Expressions.call(expression, "getFloat", fieldName);
case DOUBLE:
value = Expressions.call(expression, "getDouble", fieldName);
break;
return Expressions.call(expression, "getDouble", fieldName);
case STRING:
value = Expressions.call(expression, "getString", fieldName);
break;
return Expressions.call(expression, "getString", fieldName);
case DATETIME:
value = Expressions.call(expression, "getDateTime", fieldName);
break;
return Expressions.call(expression, "getDateTime", fieldName);
case BOOLEAN:
value = Expressions.call(expression, "getBoolean", fieldName);
break;
return Expressions.call(expression, "getBoolean", fieldName);
case BYTES:
value = Expressions.call(expression, "getBytes", fieldName);
break;
return Expressions.call(expression, "getBytes", fieldName);
case ARRAY:
value = Expressions.call(expression, "getArray", fieldName);
break;
return Expressions.call(expression, "getArray", fieldName);
case MAP:
value = Expressions.call(expression, "getMap", fieldName);
break;
return Expressions.call(expression, "getMap", fieldName);
case ROW:
value = Expressions.call(expression, "getRow", fieldName);
break;
return Expressions.call(expression, "getRow", fieldName);
case ITERABLE:
value = Expressions.call(expression, "getIterable", fieldName);
break;
return Expressions.call(expression, "getIterable", fieldName);
case LOGICAL_TYPE:
String identifier = fieldType.getLogicalType().getIdentifier();
LogicalType logicalType = fieldType.getLogicalType();
String identifier = logicalType.getIdentifier();
if (FixedString.IDENTIFIER.equals(identifier)
|| VariableString.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getString", fieldName);
return Expressions.call(expression, "getString", fieldName);
} else if (FixedBytes.IDENTIFIER.equals(identifier)
|| VariableBytes.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getBytes", fieldName);
return Expressions.call(expression, "getBytes", fieldName);
} else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getDateTime", fieldName);
return Expressions.call(expression, "getDateTime", fieldName);
} else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalDate.class)),
LocalDate.class);
return Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalDate.class)),
LocalDate.class);
} else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalTime.class)),
LocalTime.class);
return Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalTime.class)),
LocalTime.class);
} else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalDateTime.class)),
LocalDateTime.class);
return Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
fieldName,
Expressions.constant(LocalDateTime.class)),
LocalDateTime.class);
} else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getDecimal", fieldName);
return Expressions.call(expression, "getDecimal", fieldName);
} else if (logicalType instanceof PassThroughLogicalType) {
return getBeamField(list, expression, fieldName, logicalType.getBaseType());
} else {
throw new UnsupportedOperationException("Unable to get logical type " + identifier);
}
break;
default:
throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName());
}

return toCalciteValue(value, fieldType);
}

// Value conversion: Beam => Calcite
private static Expression toCalciteValue(Expression value, FieldType fieldType) {
private static Expression toCalciteValue(
Expression value, FieldType fieldType, boolean useByteString) {
switch (fieldType.getTypeName()) {
case BYTE:
return Expressions.convert_(value, Byte.class);
Expand All @@ -642,7 +632,10 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
Expressions.call(Expressions.convert_(value, AbstractInstant.class), "getMillis"));
case BYTES:
return nullOr(
value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class)));
value,
useByteString
? Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))
: Expressions.convert_(value, byte[].class));
case ARRAY:
case ITERABLE:
return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType()));
Expand All @@ -651,7 +644,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
case ROW:
return nullOr(value, toCalciteRow(value, fieldType.getRowSchema()));
case LOGICAL_TYPE:
String identifier = fieldType.getLogicalType().getIdentifier();
LogicalType logicalType = fieldType.getLogicalType();
String identifier = logicalType.getIdentifier();
if (FixedString.IDENTIFIER.equals(identifier)
|| VariableString.IDENTIFIER.equals(identifier)) {
return Expressions.convert_(value, String.class);
Expand Down Expand Up @@ -692,6 +686,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType)
return nullOr(value, returnValue);
} else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) {
return Expressions.convert_(value, BigDecimal.class);
} else if (logicalType instanceof PassThroughLogicalType) {
return toCalciteValue(value, logicalType.getBaseType(), useByteString);
} else {
throw new UnsupportedOperationException("Unable to convert logical type " + identifier);
}
Expand All @@ -704,7 +700,7 @@ private static Expression toCalciteList(Expression input, FieldType elementType)
ParameterExpression value = Expressions.parameter(Object.class);

BlockBuilder block = new BlockBuilder();
block.add(toCalciteValue(value, elementType));
block.add(toCalciteValue(value, elementType, false));

return Expressions.new_(
WrappedList.class,
Expand All @@ -722,7 +718,7 @@ private static Expression toCalciteMap(Expression input, FieldType mapValueType)
ParameterExpression value = Expressions.parameter(Object.class);

BlockBuilder block = new BlockBuilder();
block.add(toCalciteValue(value, mapValueType));
block.add(toCalciteValue(value, mapValueType, false));

return Expressions.new_(
WrappedMap.class,
Expand All @@ -745,7 +741,8 @@ private static Expression toCalciteRow(Expression input, Schema schema) {

for (int i = 0; i < schema.getFieldCount(); i++) {
BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body);
Expression returnValue = getBeamField(list, i, row, schema);
// instruct conversion of BYTES to byte[], required by BeamJavaTypeFactory
Expression returnValue = getBeamField(list, i, row, schema, false);

list.append(returnValue);

Expand Down
Loading
Loading