diff --git a/.github/trigger_files/beam_PostCommit_SQL.json b/.github/trigger_files/beam_PostCommit_SQL.json index 6cc79a7a0325..833fd9b0d174 100644 --- a/.github/trigger_files/beam_PostCommit_SQL.json +++ b/.github/trigger_files/beam_PostCommit_SQL.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run ", - "modification": 1 + "modification": 2 } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java index 63a796141d5f..dc6a28fdf6b8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java @@ -113,6 +113,7 @@ public abstract class RowCoderGenerator { private static final String CODERS_FIELD_NAME = "FIELD_CODERS"; private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS"; + private static final String SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding"; static class WithStackTrace { private final T value; @@ -407,8 +408,13 @@ static void encodeDelegate( checkState(value.getFieldCount() == value.getSchema().getFieldCount()); checkState(encodingPosToIndex.length == value.getFieldCount()); + boolean staticEncoding = + value.getSchema().getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false); + // Encode the field count. This allows us to handle compatible schema changes. - VAR_INT_CODER.encode(value.getFieldCount(), outputStream); + if (!staticEncoding) { + VAR_INT_CODER.encode(value.getFieldCount(), outputStream); + } if (hasNullableFields) { // If the row has null fields, extract the values out once so that both scanNullFields and @@ -420,7 +426,9 @@ static void encodeDelegate( } // Encode a bitmap for the null fields to save having to encode a bunch of nulls. - NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream); + if (!staticEncoding) { + NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream); + } for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) { @Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]]; if (fieldValue != null) { @@ -430,7 +438,9 @@ static void encodeDelegate( } else { // Otherwise, we know all fields are non-null, so the null list is always empty. - NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream); + if (!staticEncoding) { + NULL_LIST_CODER.encode(EMPTY_BIT_SET, outputStream); + } for (int encodingPos = 0; encodingPos < value.getFieldCount(); ++encodingPos) { @Nullable Object fieldValue = value.getValue(encodingPosToIndex[encodingPos]); if (fieldValue != null) { @@ -511,9 +521,15 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { static Row decodeDelegate( Schema schema, Coder[] coders, int[] encodingPosToIndex, InputStream inputStream) throws IOException { - int fieldCount = VAR_INT_CODER.decode(inputStream); - - BitSet nullFields = NULL_LIST_CODER.decode(inputStream); + int fieldCount; + BitSet nullFields; + if (schema.getOptions().getValueOrDefault(SCHEMA_OPTION_STATIC_ENCODING, false)) { + fieldCount = schema.getFieldCount(); + nullFields = new BitSet(); + } else { + fieldCount = VAR_INT_CODER.decode(inputStream); + nullFields = NULL_LIST_CODER.decode(inputStream); + } Object[] fieldValues = new Object[coders.length]; for (int encodingPos = 0; encodingPos < fieldCount; ++encodingPos) { // In the case of a schema change going backwards, fieldCount might be > coders.length, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java index 885ff8f1491a..7818cba9a818 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java @@ -626,4 +626,23 @@ public void testEncodingPositionRemoveFields() throws Exception { Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray())); assertEquals(expected, decoded); } + + @Test + public void testStaticEncoding() throws Exception { + Schema schema = + Schema.builder() + .addInt32Field("f_int32") + .addStringField("f_string") + .setOptions( + Schema.Options.builder() + .setOption("beam:option:row:static_encoding", FieldType.BOOLEAN, true) + .build()) + .build(); + Row row = Row.withSchema(schema).addValues(42, "hello world!").build(); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + RowCoder.of(schema).encode(row, bos); + assertEquals(14, bos.toByteArray().length); + + CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row); + } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index 3c1c2579dedf..fad96abb29a5 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -221,9 +221,10 @@ public PCollection expand(PCollectionList pinput) { BeamSqlPipelineOptions options = pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class); + String builderString = builder.toBlock().toString(); CalcFn calcFn = new CalcFn( - builder.toBlock().toString(), + builderString, outputSchema, options.getVerifyRowValues(), getJarPaths(program), @@ -502,120 +503,109 @@ FieldAccessDescriptor getFieldAccess() { @Override public Expression field(BlockBuilder list, int index, Type storageType) { this.referencedColumns.add(index); - return getBeamField(list, index, input, inputSchema); + return getBeamField(list, index, input, inputSchema, true); } // Read field from Beam Row private static Expression getBeamField( - BlockBuilder list, int index, Expression input, Schema schema) { + BlockBuilder list, int index, Expression input, Schema schema, boolean useByteString) { if (index >= schema.getFieldCount() || index < 0) { throw new IllegalArgumentException("Unable to find value #" + index); } final Expression expression = list.append(list.newName("current"), input); - final Field field = schema.getField(index); final FieldType fieldType = field.getType(); final Expression fieldName = Expressions.constant(field.getName()); + Expression value = getBeamField(list, expression, fieldName, fieldType); + + return toCalciteValue(value, fieldType, useByteString); + } + + private static Expression getBeamField( + BlockBuilder list, Expression expression, Expression fieldName, FieldType fieldType) { final Expression value; switch (fieldType.getTypeName()) { case BYTE: - value = Expressions.call(expression, "getByte", fieldName); - break; + return Expressions.call(expression, "getByte", fieldName); case INT16: - value = Expressions.call(expression, "getInt16", fieldName); - break; + return Expressions.call(expression, "getInt16", fieldName); case INT32: - value = Expressions.call(expression, "getInt32", fieldName); - break; + return Expressions.call(expression, "getInt32", fieldName); case INT64: - value = Expressions.call(expression, "getInt64", fieldName); - break; + return Expressions.call(expression, "getInt64", fieldName); case DECIMAL: - value = Expressions.call(expression, "getDecimal", fieldName); - break; + return Expressions.call(expression, "getDecimal", fieldName); case FLOAT: - value = Expressions.call(expression, "getFloat", fieldName); - break; + return Expressions.call(expression, "getFloat", fieldName); case DOUBLE: - value = Expressions.call(expression, "getDouble", fieldName); - break; + return Expressions.call(expression, "getDouble", fieldName); case STRING: - value = Expressions.call(expression, "getString", fieldName); - break; + return Expressions.call(expression, "getString", fieldName); case DATETIME: - value = Expressions.call(expression, "getDateTime", fieldName); - break; + return Expressions.call(expression, "getDateTime", fieldName); case BOOLEAN: - value = Expressions.call(expression, "getBoolean", fieldName); - break; + return Expressions.call(expression, "getBoolean", fieldName); case BYTES: - value = Expressions.call(expression, "getBytes", fieldName); - break; + return Expressions.call(expression, "getBytes", fieldName); case ARRAY: - value = Expressions.call(expression, "getArray", fieldName); - break; + return Expressions.call(expression, "getArray", fieldName); case MAP: - value = Expressions.call(expression, "getMap", fieldName); - break; + return Expressions.call(expression, "getMap", fieldName); case ROW: - value = Expressions.call(expression, "getRow", fieldName); - break; + return Expressions.call(expression, "getRow", fieldName); case ITERABLE: - value = Expressions.call(expression, "getIterable", fieldName); - break; + return Expressions.call(expression, "getIterable", fieldName); case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + LogicalType logicalType = fieldType.getLogicalType(); + String identifier = logicalType.getIdentifier(); if (FixedString.IDENTIFIER.equals(identifier) || VariableString.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getString", fieldName); + return Expressions.call(expression, "getString", fieldName); } else if (FixedBytes.IDENTIFIER.equals(identifier) || VariableBytes.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getBytes", fieldName); + return Expressions.call(expression, "getBytes", fieldName); } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getDateTime", fieldName); + return Expressions.call(expression, "getDateTime", fieldName); } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { - value = - Expressions.convert_( - Expressions.call( - expression, - "getLogicalTypeValue", - fieldName, - Expressions.constant(LocalDate.class)), - LocalDate.class); + return Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + fieldName, + Expressions.constant(LocalDate.class)), + LocalDate.class); } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { - value = - Expressions.convert_( - Expressions.call( - expression, - "getLogicalTypeValue", - fieldName, - Expressions.constant(LocalTime.class)), - LocalTime.class); + return Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + fieldName, + Expressions.constant(LocalTime.class)), + LocalTime.class); } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { - value = - Expressions.convert_( - Expressions.call( - expression, - "getLogicalTypeValue", - fieldName, - Expressions.constant(LocalDateTime.class)), - LocalDateTime.class); + return Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + fieldName, + Expressions.constant(LocalDateTime.class)), + LocalDateTime.class); } else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getDecimal", fieldName); + return Expressions.call(expression, "getDecimal", fieldName); + } else if (logicalType instanceof PassThroughLogicalType) { + return getBeamField(list, expression, fieldName, logicalType.getBaseType()); } else { throw new UnsupportedOperationException("Unable to get logical type " + identifier); } - break; default: throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName()); } - - return toCalciteValue(value, fieldType); } // Value conversion: Beam => Calcite - private static Expression toCalciteValue(Expression value, FieldType fieldType) { + private static Expression toCalciteValue( + Expression value, FieldType fieldType, boolean useByteString) { switch (fieldType.getTypeName()) { case BYTE: return Expressions.convert_(value, Byte.class); @@ -642,7 +632,10 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType) Expressions.call(Expressions.convert_(value, AbstractInstant.class), "getMillis")); case BYTES: return nullOr( - value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))); + value, + useByteString + ? Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class)) + : Expressions.convert_(value, byte[].class)); case ARRAY: case ITERABLE: return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType())); @@ -651,7 +644,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType) case ROW: return nullOr(value, toCalciteRow(value, fieldType.getRowSchema())); case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + LogicalType logicalType = fieldType.getLogicalType(); + String identifier = logicalType.getIdentifier(); if (FixedString.IDENTIFIER.equals(identifier) || VariableString.IDENTIFIER.equals(identifier)) { return Expressions.convert_(value, String.class); @@ -692,6 +686,8 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType) return nullOr(value, returnValue); } else if (FixedPrecisionNumeric.IDENTIFIER.equals(identifier)) { return Expressions.convert_(value, BigDecimal.class); + } else if (logicalType instanceof PassThroughLogicalType) { + return toCalciteValue(value, logicalType.getBaseType(), useByteString); } else { throw new UnsupportedOperationException("Unable to convert logical type " + identifier); } @@ -704,7 +700,7 @@ private static Expression toCalciteList(Expression input, FieldType elementType) ParameterExpression value = Expressions.parameter(Object.class); BlockBuilder block = new BlockBuilder(); - block.add(toCalciteValue(value, elementType)); + block.add(toCalciteValue(value, elementType, false)); return Expressions.new_( WrappedList.class, @@ -722,7 +718,7 @@ private static Expression toCalciteMap(Expression input, FieldType mapValueType) ParameterExpression value = Expressions.parameter(Object.class); BlockBuilder block = new BlockBuilder(); - block.add(toCalciteValue(value, mapValueType)); + block.add(toCalciteValue(value, mapValueType, false)); return Expressions.new_( WrappedMap.class, @@ -745,7 +741,8 @@ private static Expression toCalciteRow(Expression input, Schema schema) { for (int i = 0; i < schema.getFieldCount(); i++) { BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body); - Expression returnValue = getBeamField(list, i, row, schema); + // instruct conversion of BYTES to byte[], required by BeamJavaTypeFactory + Expression returnValue = getBeamField(list, i, row, schema, false); list.append(returnValue); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java index 5bcac6ad256f..3aaa91680999 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java @@ -22,6 +22,7 @@ import java.lang.reflect.Type; import java.util.Date; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.IntStream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -169,6 +170,12 @@ public static boolean isStringType(FieldType fieldType) { FieldType.DATETIME, SqlTypeName.TIMESTAMP, FieldType.STRING, SqlTypeName.VARCHAR); + // Associating FieldType to generated RelDataType objects for Beam logical types. Used for + // recovering the original type in output schema after full Beam FieldType->Calcite Type->Beam + // FieldType trip + private static final Map LOGICAL_TYPE_REL_DATA_MAPPING = + new ConcurrentHashMap<>(); + /** Generate {@link Schema} from {@code RelDataType} which is used to create table. */ public static Schema toSchema(RelDataType tableInfo) { return tableInfo.getFieldList().stream().map(CalciteUtils::toField).collect(Schema.toSchema()); @@ -254,6 +261,9 @@ public static Schema.Field toField(String name, RelDataType calciteType) { } public static FieldType toFieldType(RelDataType calciteType) { + if (LOGICAL_TYPE_REL_DATA_MAPPING.containsKey(calciteType)) { + return LOGICAL_TYPE_REL_DATA_MAPPING.get(calciteType); + } switch (calciteType.getSqlTypeName()) { case ARRAY: case MULTISET: @@ -315,6 +325,29 @@ public static RelDataType toRelDataType(RelDataTypeFactory dataTypeFactory, Fiel Schema schema = fieldType.getRowSchema(); Preconditions.checkArgumentNotNull(schema); return toCalciteRowType(schema, dataTypeFactory); + case LOGICAL_TYPE: + Schema.LogicalType logicalType = fieldType.getLogicalType(); + RelDataType relDataType; + if (logicalType instanceof PassThroughLogicalType) { + relDataType = + toRelDataType( + dataTypeFactory, logicalType.getBaseType().withNullable(fieldType.getNullable())); + } else { + relDataType = dataTypeFactory.createSqlType(toSqlTypeName(fieldType)); + } + // For backward-compatibility, exclude logical types registered in + // CALCITE_TO_BEAM_TYPE_MAPPING, + // e.g., primitive types, date time types, etc. + SqlTypeName typeName = relDataType.getSqlTypeName(); + if (typeName != null && !CALCITE_TO_BEAM_TYPE_MAPPING.containsKey(typeName)) { + // register both nullable and non-nullable variants. + boolean flipNullable = !relDataType.isNullable(); + LOGICAL_TYPE_REL_DATA_MAPPING.put(relDataType, fieldType); + LOGICAL_TYPE_REL_DATA_MAPPING.put( + dataTypeFactory.createTypeWithNullability(relDataType, flipNullable), + fieldType.withNullable(flipNullable)); + } + return relDataType; default: return dataTypeFactory.createSqlType(toSqlTypeName(fieldType)); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java index a5f78f715293..5ef081b92c3f 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.extensions.sql; +import static org.junit.Assert.assertEquals; + import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; @@ -300,7 +302,6 @@ public void testSelectInnerRowOfNestedRow() { pipeline.run().waitUntilFinish(Duration.standardMinutes(2)); } - @Ignore("https://github.com/apache/beam/issues/21024") @Test public void testNestedBytes() { byte[] bytes = new byte[] {-70, -83, -54, -2}; @@ -325,7 +326,6 @@ public void testNestedBytes() { pipeline.run(); } - @Ignore("https://github.com/apache/beam/issues/21024") @Test public void testNestedArrayOfBytes() { byte[] bytes = new byte[] {-70, -83, -54, -2}; @@ -773,4 +773,28 @@ public void testMapWithNullRowFields() { PAssert.that(outputRow).containsInAnyOrder(expectedRow); pipeline.run().waitUntilFinish(Duration.standardMinutes(1)); } + + @Test + public void testUnknownLogicalType() { + Schema.FieldType rowType = Schema.FieldType.row(innerRowSchema); + + Schema.LogicalType logicalType = + new org.apache.beam.sdk.schemas.logicaltypes.UnknownLogicalType( + "RowBackedLogicalType", new byte[] {}, Schema.FieldType.STRING, "", rowType); + + Schema inputSchema = Schema.builder().addLogicalTypeField("logical_field", logicalType).build(); + + Row nestedRow = Row.withSchema(innerRowSchema).addValue("abc").addValue(42L).build(); + Row inputRow = Row.withSchema(inputSchema).addValue(nestedRow).build(); + + PCollection outputRow = + pipeline + .apply(Create.of(inputRow)) + .setRowSchema(inputSchema) + .apply(SqlTransform.query("select * from PCOLLECTION")); + + PAssert.that(outputRow).containsInAnyOrder(inputRow); + assertEquals(inputRow.getSchema(), outputRow.getSchema()); + pipeline.run().waitUntilFinish(Duration.standardMinutes(1)); + } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java index 6c85c3582e95..481a700c0c99 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtilsTest.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.PassThroughLogicalType; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataType; import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.beam.vendor.calcite.v1_40_0.org.apache.calcite.rel.type.RelDataTypeSystem; @@ -178,4 +180,22 @@ public void testFieldTypeNotFound() { thrown.expectMessage("Cannot find a matching Beam FieldType for Calcite type: UNKNOWN"); CalciteUtils.toFieldType(relDataType); } + + @Test + public void testToRelDataTypeWithRowBackedLogicalType() { + Schema nestedSchema = Schema.builder().addField("nested_f1", Schema.FieldType.INT32).build(); + Schema.FieldType rowType = Schema.FieldType.row(nestedSchema); + + Schema.LogicalType logicalType = + new PassThroughLogicalType( + "RowBackedLogicalType", Schema.FieldType.STRING, "", rowType) {}; + + Schema.FieldType logicalFieldType = Schema.FieldType.logicalType(logicalType); + + RelDataType relDataType = CalciteUtils.toRelDataType(dataTypeFactory, logicalFieldType); + + assertEquals(SqlTypeName.ROW, relDataType.getSqlTypeName()); + assertEquals(1, relDataType.getFieldCount()); + assertEquals("nested_f1", relDataType.getFieldList().get(0).getName()); + } } diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd b/sdks/python/apache_beam/coders/coder_impl.pxd index 02d3f1fe8dbf..e64177e6fd34 100644 --- a/sdks/python/apache_beam/coders/coder_impl.pxd +++ b/sdks/python/apache_beam/coders/coder_impl.pxd @@ -117,6 +117,10 @@ cdef class BigEndianShortCoderImpl(StreamCoderImpl): pass +cdef class ByteCoderImpl(StreamCoderImpl): + pass + + cdef class SinglePrecisionFloatCoderImpl(StreamCoderImpl): pass diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 1270b98f9bc4..cd66a0c09e01 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -841,7 +841,7 @@ def encode_to_stream(self, value, out, nested): out.write_bigendian_int16(value) def decode_from_stream(self, in_stream, nested): - # type: (create_InputStream, bool) -> float + # type: (create_InputStream, bool) -> int return in_stream.read_bigendian_int16() def estimate_size(self, unused_value, nested=False): @@ -850,6 +850,22 @@ def estimate_size(self, unused_value, nested=False): return 2 +class ByteCoderImpl(StreamCoderImpl): + """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, out, nested): + # type: (int, create_OutputStream, bool) -> None + out.write_byte(value) + + def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> int + return in_stream.read_byte() + + def estimate_size(self, unused_value, nested=False): + # type: (Any, bool) -> int + # A byte is encoded as 1 byte, regardless of nesting. + return 1 + + class SinglePrecisionFloatCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def encode_to_stream(self, value, out, nested): diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index b9bee4585688..556b18043189 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -91,6 +91,7 @@ 'Coder', 'AvroGenericCoder', 'BooleanCoder', + 'ByteCoder', 'BytesCoder', 'CloudpickleCoder', 'DillCoder', @@ -698,6 +699,25 @@ def __hash__(self): return hash(type(self)) +class ByteCoder(FastCoder): + """A coder used for single byte values""" + def _create_impl(self): + return coder_impl.ByteCoderImpl() + + def is_deterministic(self): + # type: () -> bool + return True + + def to_type_hint(self): + return int + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + class SinglePrecisionFloatCoder(FastCoder): """A coder used for single-precision floating-point values.""" def _create_impl(self): diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 5b7f5f65a560..5e5cfc8a5b62 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -191,6 +191,7 @@ def tearDownClass(cls): coders.ProtoCoder, coders.ProtoPlusCoder, coders.BigEndianShortCoder, + coders.ByteCoder, coders.SinglePrecisionFloatCoder, coders.ToBytesCoder, coders.BigIntegerCoder, # tested in DecimalCoder @@ -1076,6 +1077,16 @@ def test_decimal_coder(self): test_encodings[idx], base64.b64encode(test_coder.encode(value)).decode().rstrip("=")) + def test_byte_coder(self): + test_coder = coders.ByteCoder() + test_values = [0, 80, 127, 128, 255] + test_encodings = ("AA", "UA", "fw", "gA", "/w") + self.check_coder(test_coder, *test_values) + for idx, value in enumerate(test_values): + self.assertEqual( + test_encodings[idx], + base64.b64encode(test_coder.encode(value)).decode().rstrip("=")) + def test_OrderedUnionCoder(self): test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()), (int, coders.VarIntCoder()), diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 1becf408cfbf..29f85ba9cbf4 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -22,6 +22,7 @@ from apache_beam.coders.coder_impl import RowCoderImpl from apache_beam.coders.coders import BigEndianShortCoder from apache_beam.coders.coders import BooleanCoder +from apache_beam.coders.coders import ByteCoder from apache_beam.coders.coders import BytesCoder from apache_beam.coders.coders import Coder from apache_beam.coders.coders import DecimalCoder @@ -147,8 +148,10 @@ def _nonnull_coder_from_type(field_type): return VarIntCoder() elif field_type.atomic_type == schema_pb2.INT32: return VarInt32Coder() - if field_type.atomic_type == schema_pb2.INT16: + elif field_type.atomic_type == schema_pb2.INT16: return BigEndianShortCoder() + elif field_type.atomic_type == schema_pb2.BYTE: + return ByteCoder() elif field_type.atomic_type == schema_pb2.FLOAT: return SinglePrecisionFloatCoder() elif field_type.atomic_type == schema_pb2.DOUBLE: diff --git a/sdks/python/apache_beam/transforms/sql_test.py b/sdks/python/apache_beam/transforms/sql_test.py index fc55320ba699..6649e210685a 100644 --- a/sdks/python/apache_beam/transforms/sql_test.py +++ b/sdks/python/apache_beam/transforms/sql_test.py @@ -45,6 +45,19 @@ coders.registry.register_coder(Shopper, coders.RowCoder) +class Aribitrary: + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + return self.obj == other.obj + + +UserTypeRow = typing.NamedTuple( + "UserTypeRow", [("id", int), ("arb", Aribitrary), ("complex", complex)]) +coders.registry.register_coder(UserTypeRow, coders.RowCoder) + + @pytest.mark.xlang_sql_expansion_service @unittest.skipIf( TestPipeline().get_pipeline_options().view_as(StandardOptions).runner @@ -149,6 +162,19 @@ def test_row(self): | SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION")) assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)])) + def test_row_user_type(self): + with TestPipeline() as p: + out = ( + p | beam.Create([ + UserTypeRow(1, Aribitrary(1.0), 1 + 2.5j), + UserTypeRow(1, Aribitrary("abc"), -1j), + ]) + | SqlTransform("SELECT arb, complex FROM PCOLLECTION") + | beam.Map(tuple)) + assert_that( + out, + equal_to([(Aribitrary(1.0), 1 + 2.5j), (Aribitrary("abc"), -1j)])) + def test_windowing_before_sql(self): with TestPipeline() as p: out = ( diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 9e337f080fbf..2e028bb37e17 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -106,6 +106,9 @@ from apache_beam.utils.timestamp import Timestamp PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1" +_PYTHON_ANY_FIELD_TYPE_BYTE = "_pythonsdk_any_type_byte" +_PYTHON_ANY_FIELD_PAYLOAD = "payload" +_SCHEMA_OPTION_STATIC_ENCODING = "beam:option:row:static_encoding" # Bi-directional mappings _PRIMITIVES = ( @@ -255,6 +258,37 @@ def schema_field( description=description) +def _python_any_schema_pb2(): + # A portable schema matches FastPrimitivesCoder encoded values + return schema_pb2.FieldType( + logical_type=schema_pb2.LogicalType( + urn=PYTHON_ANY_URN, + representation=schema_pb2.FieldType( + nullable=False, + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schema_pb2.Field( + name=_PYTHON_ANY_FIELD_TYPE_BYTE, + type=schema_pb2.FieldType( + atomic_type=schema_pb2.BYTE, nullable=False)), + schema_pb2.Field( + name=_PYTHON_ANY_FIELD_PAYLOAD, + type=schema_pb2.FieldType( + atomic_type=schema_pb2.BYTES, nullable=False)) + ], + options=[ + schema_pb2.Option( + name=_SCHEMA_OPTION_STATIC_ENCODING, + type=schema_pb2.FieldType( + atomic_type=schema_pb2.BOOLEAN), + value=schema_pb2.FieldValue( + atomic_value=schema_pb2.AtomicTypeValue( + boolean=True))) + ])))), + nullable=True) + + class SchemaTranslation(object): def __init__(self, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY): self.schema_registry = schema_registry @@ -361,9 +395,7 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: logical_type = LogicalType.from_typing(type_) except ValueError: # Unknown type, just treat it like Any - return schema_pb2.FieldType( - logical_type=schema_pb2.LogicalType(urn=PYTHON_ANY_URN), - nullable=True) + return _python_any_schema_pb2() else: argument_type = None argument = None diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index d70bf0c47d33..8032e4701c25 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -43,6 +43,7 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import schema_pb2 from apache_beam.typehints import row_type +from apache_beam.typehints import schemas from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schemas import SchemaTypeRegistry @@ -583,11 +584,7 @@ def test_proto_survives_typing_roundtrip(self, fieldtype_proto): def test_unknown_primitive_maps_to_any(self): self.assertEqual( - typing_to_runner_api(np.uint32), - schema_pb2.FieldType( - logical_type=schema_pb2.LogicalType( - urn="beam:logical:pythonsdk_any:v1"), - nullable=True)) + typing_to_runner_api(np.uint32), schemas._python_any_schema_pb2()) def test_unknown_atomic_raise_valueerror(self): self.assertRaises( diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 5d4f86ae4d97..e3f6a1786038 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -171,7 +171,7 @@ def cythonize(*args, **kwargs): # ml_test is pinned to versions that require protobuf<5 on Python 3.10. Those # cannot be installed together, so ADK deps stay out of ml_test (use ml_base). ml_base_core = [ - 'embeddings>=0.0.4', # 0.0.3 crashes setuptools + 'embeddings>=0.0.4', # 0.0.3 crashes setuptools 'onnxruntime', # onnx 1.12–1.13 cap protobuf in ways that trigger huge backtracking with # Beam[gcp]+ml_test; pip can fall back to onnx 1.11 sdist which needs cmake. @@ -303,8 +303,7 @@ def generate_external_transform_wrappers(): except subprocess.CalledProcessError as err: raise RuntimeError( 'Could not generate external transform wrappers due to ' - 'error: %s', - err.stderr) + 'error: {}'.format(err.stderr)) def get_portability_package_data():