diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index 1f9824fd43d..4b730c96d48 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -69,6 +69,7 @@ impl VTable for FoRVTable { } fn serialize(metadata: Self::Metadata) -> VortexResult>> { + // Note that we **only** serialize the optional scalar value (not including the dtype). Ok(Some(ScalarValue::to_proto_bytes(metadata.value()))) } diff --git a/vortex-array/src/arrays/constant/array.rs b/vortex-array/src/arrays/constant/array.rs index 5722baa7ab8..9193956e9df 100644 --- a/vortex-array/src/arrays/constant/array.rs +++ b/vortex-array/src/arrays/constant/array.rs @@ -5,16 +5,6 @@ use vortex_scalar::Scalar; use crate::stats::ArrayStats; -/// Protobuf-encoded metadata for [`ConstantArray`]. -/// -/// When the serialized scalar value is small enough (see `CONSTANT_INLINE_THRESHOLD`), -/// it is inlined directly in the metadata to avoid a device-to-host copy on GPU. -#[derive(Clone, prost::Message)] -pub struct ConstantMetadata { - #[prost(optional, bytes, tag = "1")] - pub(super) scalar_value: Option>, -} - #[derive(Clone, Debug)] pub struct ConstantArray { pub(super) scalar: Scalar, @@ -47,21 +37,69 @@ impl ConstantArray { #[cfg(test)] mod tests { - use vortex_scalar::ScalarValue; + use rstest::rstest; + use vortex_dtype::Nullability; + use vortex_error::VortexResult; + use vortex_scalar::Scalar; + use vortex_session::VortexSession; - use super::ConstantMetadata; - use crate::ProstMetadata; - use crate::test_harness::check_metadata; + use crate::arrays::ConstantArray; + use crate::arrays::constant::vtable::CONSTANT_INLINE_THRESHOLD; + use crate::arrays::constant::vtable::ConstantVTable; + use crate::vtable::VTable; - #[cfg_attr(miri, ignore)] - #[test] - fn test_constant_metadata() { - let scalar_bytes: Vec = ScalarValue::to_proto_bytes(Some(&ScalarValue::from(i32::MAX))); - check_metadata( - "constant.metadata", - ProstMetadata(ConstantMetadata { - scalar_value: Some(scalar_bytes), - }), + #[rstest] + #[case::below_threshold(CONSTANT_INLINE_THRESHOLD - 1, true)] + #[case::at_threshold(CONSTANT_INLINE_THRESHOLD, true)] + #[case::above_threshold(CONSTANT_INLINE_THRESHOLD + 1, false)] + fn test_metadata_inlining( + #[case] nbytes: usize, + #[case] should_inline: bool, + ) -> VortexResult<()> { + // UTF-8 scalar `nbytes` equals the string length. + let string = "x".repeat(nbytes); + let array = ConstantArray::new(Scalar::from(string.as_str()), 10); + let metadata = ConstantVTable::metadata(&array)?; + + assert_eq!( + metadata.is_some(), + should_inline, + "scalar of {nbytes} bytes: expected inlined={should_inline}" ); + Ok(()) + } + + #[test] + fn test_metadata_round_trips() -> VortexResult<()> { + let scalar = Scalar::from(42i64); + let array = ConstantArray::new(scalar.clone(), 5); + let metadata = ConstantVTable::metadata(&array)?; + + // Serialize and deserialize the metadata. + let bytes = + ConstantVTable::serialize(metadata)?.expect("serialize should produce Some bytes"); + let session = VortexSession::empty(); + let deserialized = ConstantVTable::deserialize( + &bytes, + &vortex_dtype::DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable), + 5, + &session, + )?; + + assert_eq!(deserialized.unwrap(), scalar); + Ok(()) + } + + #[test] + fn test_empty_bytes_deserializes_to_none() -> VortexResult<()> { + let session = VortexSession::empty(); + let metadata = ConstantVTable::deserialize( + &[], + &vortex_dtype::DType::Primitive(vortex_dtype::PType::I32, Nullability::NonNullable), + 10, + &session, + )?; + assert!(metadata.is_none(), "empty bytes should deserialize to None"); + Ok(()) } } diff --git a/vortex-array/src/arrays/constant/mod.rs b/vortex-array/src/arrays/constant/mod.rs index 6b0448b54ac..bb5bb519401 100644 --- a/vortex-array/src/arrays/constant/mod.rs +++ b/vortex-array/src/arrays/constant/mod.rs @@ -8,7 +8,6 @@ pub use arbitrary::ArbitraryConstantArray; mod array; pub use array::ConstantArray; -pub(crate) use array::ConstantMetadata; pub(crate) use vtable::canonical::constant_canonicalize; pub(crate) mod compute; diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index ca4a36d1fcb..66cbad88afc 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -5,20 +5,15 @@ use std::fmt::Debug; use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; use vortex_session::VortexSession; use crate::ArrayRef; -use crate::DeserializeMetadata; use crate::ExecutionCtx; use crate::IntoArray; -use crate::ProstMetadata; -use crate::SerializeMetadata; use crate::arrays::ConstantArray; -use crate::arrays::constant::ConstantMetadata; use crate::arrays::constant::compute::rules::PARENT_RULES; use crate::arrays::constant::vtable::canonical::constant_canonicalize; use crate::buffer::BufferHandle; @@ -44,12 +39,20 @@ impl ConstantVTable { /// Maximum size (in bytes) of a protobuf-encoded scalar value that will be inlined /// into the array metadata. Values larger than this are stored only in the buffer. -const CONSTANT_INLINE_THRESHOLD: usize = 1024; +pub(crate) const CONSTANT_INLINE_THRESHOLD: usize = 1024; impl VTable for ConstantVTable { type Array = ConstantArray; - type Metadata = ProstMetadata; + /// Optional inlined scalar constant. + /// + /// When the scalar value is small enough (<= `CONSTANT_INLINE_THRESHOLD` bytes), it is stored + /// directly in the metadata to avoid an extra buffer allocation and potential + /// device-to-host copy during deserialization. + /// + /// Currently, scalars are **always** stored in a separate buffer, regardless of if we inline a + /// small scalar into the metadata. + type Metadata = Option; type ArrayVTable = Self; type OperationsVTable = Self; @@ -61,28 +64,34 @@ impl VTable for ConstantVTable { } fn metadata(array: &ConstantArray) -> VortexResult { - let constant = &array.scalar(); - let proto_bytes: Vec = ScalarValue::to_proto_bytes(constant.value()); - let scalar_value = (proto_bytes.len() <= CONSTANT_INLINE_THRESHOLD).then_some(proto_bytes); - Ok(ProstMetadata(ConstantMetadata { scalar_value })) + let constant = array.scalar(); + + // If the scalar is small enough, we can simply carry it around as metadata. + Ok((constant.nbytes() <= CONSTANT_INLINE_THRESHOLD).then_some(constant.clone())) } fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(metadata.serialize())) + // If we do not have a scalar to serialize, just return empty bytes. + Ok(Some(metadata.map_or_else(Vec::new, |c| { + // Note that we **only** serialize the optional scalar value (not including the dtype). + ScalarValue::to_proto_bytes(c.value()) + }))) } fn deserialize( bytes: &[u8], - _dtype: &DType, + dtype: &DType, _len: usize, _session: &VortexSession, ) -> VortexResult { // Empty bytes indicates an old writer that didn't produce metadata. if bytes.is_empty() { - return Ok(ProstMetadata(ConstantMetadata { scalar_value: None })); + return Ok(None); } - let metadata = ::deserialize(bytes)?; - Ok(ProstMetadata(metadata)) + + // Otherwise, deserialize the constant scalar from the metadata. + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + Some(Scalar::try_new(dtype.clone(), scalar_value)).transpose() } fn build( @@ -93,22 +102,22 @@ impl VTable for ConstantVTable { _children: &dyn ArrayChildren, ) -> VortexResult { // Prefer reading the scalar from inlined metadata to avoid device-to-host copies. - let scalar = if let Some(proto_bytes) = &metadata.scalar_value { - let scalar_value = ScalarValue::from_proto_bytes(proto_bytes, dtype)?; - - Scalar::try_new(dtype.clone(), scalar_value) - } else { - if buffers.len() != 1 { - vortex_bail!("Expected 1 buffer, got {}", buffers.len()); - } + if let Some(constant) = metadata { + return Ok(ConstantArray::new(constant.clone(), len)); + } - let buffer = buffers[0].clone().try_to_host_sync()?; - let bytes: &[u8] = buffer.as_ref(); + // Otherwise, get the constant scalar from the buffers. + vortex_ensure!( + buffers.len() == 1, + "Expected 1 buffer, got {}", + buffers.len() + ); - let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + let buffer = buffers[0].clone().try_to_host_sync()?; + let bytes: &[u8] = buffer.as_ref(); - Scalar::try_new(dtype.clone(), scalar_value) - }?; + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + let scalar = Scalar::try_new(dtype.clone(), scalar_value)?; Ok(ConstantArray::new(scalar, len)) } diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 85575600b43..448007e9a1c 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -316,8 +316,10 @@ impl Scalar { ) } - /// Returns the size of the scalar in bytes, uncompressed. - #[cfg(test)] + /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed. + /// + /// Note that the protobuf serialization of scalars will likely have a different (but roughly + /// similar) length. pub fn nbytes(&self) -> usize { use vortex_dtype::NativeDecimalType; use vortex_dtype::i256; diff --git a/vortex-scalar/src/tests/round_trip.rs b/vortex-scalar/src/tests/round_trip.rs index 69c28ba0220..ce2fe873bc5 100644 --- a/vortex-scalar/src/tests/round_trip.rs +++ b/vortex-scalar/src/tests/round_trip.rs @@ -11,6 +11,7 @@ mod tests { use std::sync::Arc; + use rstest::rstest; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; use vortex_dtype::DecimalDType; @@ -21,6 +22,7 @@ mod tests { use crate::DecimalValue; use crate::Scalar; + use crate::ScalarValue; use crate::tests::SESSION; // Test that primitive scalars round-trip through ScalarValue @@ -292,4 +294,38 @@ mod tests { let bool_scalar = Scalar::bool(true, Nullability::NonNullable); assert!(bool_scalar.as_decimal_opt().is_none()); } + + /// Verifies that [`Scalar::nbytes`] matches the length of the proto-serialized scalar value. + #[rstest] + #[case::null_i32(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)))] + #[case::bool_true(Scalar::from(true))] + #[case::bool_false(Scalar::from(false))] + #[case::i8(Scalar::from(i8::MAX))] + #[case::i16(Scalar::from(i16::MAX))] + #[case::i32(Scalar::from(i32::MAX))] + #[case::i64(Scalar::from(i64::MAX))] + #[case::u8(Scalar::from(u8::MAX))] + #[case::u16(Scalar::from(u16::MAX))] + #[case::u32(Scalar::from(u32::MAX))] + #[case::u64(Scalar::from(u64::MAX))] + #[case::f32(Scalar::from(f32::MAX))] + #[case::f64(Scalar::from(f64::MAX))] + #[case::utf8_empty(Scalar::from(""))] + #[case::utf8_short(Scalar::from("hello"))] + #[case::utf8_long(Scalar::from("x".repeat(2048).as_str()))] + #[case::binary_empty(Scalar::binary(Vec::::new(), Nullability::NonNullable))] + #[case::binary_short(Scalar::binary(vec![1u8, 2, 3], Nullability::NonNullable))] + fn test_nbytes_approx_eq_to_proto_bytes(#[case] scalar: Scalar) { + let proto_bytes: Vec = ScalarValue::to_proto_bytes(scalar.value()); + let diff = (scalar.nbytes() as isize - proto_bytes.len() as isize).abs(); + + // NOTE: THE 4 HERE IS COMPLETELY ARBITRARY!!! + assert!( + diff <= 4, + "nbytes() should be within 4 of proto-serialized length for {:?}, got {} vs {}", + scalar, + scalar.nbytes(), + proto_bytes.len(), + ); + } }