Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions encodings/fastlanes/src/for/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl VTable for FoRVTable {
}

fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
// Note that we **only** serialize the optional scalar value (not including the dtype).
Ok(Some(ScalarValue::to_proto_bytes(metadata.value())))
}

Expand Down
84 changes: 61 additions & 23 deletions vortex-array/src/arrays/constant/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@ use vortex_scalar::Scalar;

use crate::stats::ArrayStats;

/// Protobuf-encoded metadata for [`ConstantArray`].
///
/// When the serialized scalar value is small enough (see `CONSTANT_INLINE_THRESHOLD`),
/// it is inlined directly in the metadata to avoid a device-to-host copy on GPU.
#[derive(Clone, prost::Message)]
pub struct ConstantMetadata {
#[prost(optional, bytes, tag = "1")]
pub(super) scalar_value: Option<Vec<u8>>,
}

#[derive(Clone, Debug)]
pub struct ConstantArray {
pub(super) scalar: Scalar,
Expand Down Expand Up @@ -47,21 +37,69 @@ impl ConstantArray {

#[cfg(test)]
mod tests {
use vortex_scalar::ScalarValue;
use rstest::rstest;
use vortex_dtype::Nullability;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_session::VortexSession;

use super::ConstantMetadata;
use crate::ProstMetadata;
use crate::test_harness::check_metadata;
use crate::arrays::ConstantArray;
use crate::arrays::constant::vtable::CONSTANT_INLINE_THRESHOLD;
use crate::arrays::constant::vtable::ConstantVTable;
use crate::vtable::VTable;

#[cfg_attr(miri, ignore)]
#[test]
fn test_constant_metadata() {
let scalar_bytes: Vec<u8> = ScalarValue::to_proto_bytes(Some(&ScalarValue::from(i32::MAX)));
check_metadata(
"constant.metadata",
ProstMetadata(ConstantMetadata {
scalar_value: Some(scalar_bytes),
}),
#[rstest]
#[case::below_threshold(CONSTANT_INLINE_THRESHOLD - 1, true)]
#[case::at_threshold(CONSTANT_INLINE_THRESHOLD, true)]
#[case::above_threshold(CONSTANT_INLINE_THRESHOLD + 1, false)]
fn test_metadata_inlining(
#[case] nbytes: usize,
#[case] should_inline: bool,
) -> VortexResult<()> {
// UTF-8 scalar `nbytes` equals the string length.
let string = "x".repeat(nbytes);
let array = ConstantArray::new(Scalar::from(string.as_str()), 10);
let metadata = ConstantVTable::metadata(&array)?;

assert_eq!(
metadata.is_some(),
should_inline,
"scalar of {nbytes} bytes: expected inlined={should_inline}"
);
Ok(())
}

#[test]
fn test_metadata_round_trips() -> VortexResult<()> {
let scalar = Scalar::from(42i64);
let array = ConstantArray::new(scalar.clone(), 5);
let metadata = ConstantVTable::metadata(&array)?;

// Serialize and deserialize the metadata.
let bytes =
ConstantVTable::serialize(metadata)?.expect("serialize should produce Some bytes");
let session = VortexSession::empty();
let deserialized = ConstantVTable::deserialize(
&bytes,
&vortex_dtype::DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable),
5,
&session,
)?;

assert_eq!(deserialized.unwrap(), scalar);
Ok(())
}

#[test]
fn test_empty_bytes_deserializes_to_none() -> VortexResult<()> {
let session = VortexSession::empty();
let metadata = ConstantVTable::deserialize(
&[],
&vortex_dtype::DType::Primitive(vortex_dtype::PType::I32, Nullability::NonNullable),
10,
&session,
)?;
assert!(metadata.is_none(), "empty bytes should deserialize to None");
Ok(())
}
}
1 change: 0 additions & 1 deletion vortex-array/src/arrays/constant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub use arbitrary::ArbitraryConstantArray;

mod array;
pub use array::ConstantArray;
pub(crate) use array::ConstantMetadata;
pub(crate) use vtable::canonical::constant_canonicalize;

pub(crate) mod compute;
Expand Down
67 changes: 38 additions & 29 deletions vortex-array/src/arrays/constant/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,15 @@ use std::fmt::Debug;

use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;
use vortex_session::VortexSession;

use crate::ArrayRef;
use crate::DeserializeMetadata;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::ProstMetadata;
use crate::SerializeMetadata;
use crate::arrays::ConstantArray;
use crate::arrays::constant::ConstantMetadata;
use crate::arrays::constant::compute::rules::PARENT_RULES;
use crate::arrays::constant::vtable::canonical::constant_canonicalize;
use crate::buffer::BufferHandle;
Expand All @@ -44,12 +39,20 @@ impl ConstantVTable {

/// Maximum size (in bytes) of a protobuf-encoded scalar value that will be inlined
/// into the array metadata. Values larger than this are stored only in the buffer.
const CONSTANT_INLINE_THRESHOLD: usize = 1024;
pub(crate) const CONSTANT_INLINE_THRESHOLD: usize = 1024;

impl VTable for ConstantVTable {
type Array = ConstantArray;

type Metadata = ProstMetadata<ConstantMetadata>;
/// Optional inlined scalar constant.
///
/// When the scalar value is small enough (<= `CONSTANT_INLINE_THRESHOLD` bytes), it is stored
/// directly in the metadata to avoid an extra buffer allocation and potential
/// device-to-host copy during deserialization.
///
/// Currently, scalars are **always** stored in a separate buffer, regardless of if we inline a
/// small scalar into the metadata.
Comment on lines +53 to +54
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See PR description above also.

We should be able to modify the visitor to just check if the scalar size is too big? But that could potentially be error-prone?

type Metadata = Option<Scalar>;

type ArrayVTable = Self;
type OperationsVTable = Self;
Expand All @@ -61,28 +64,34 @@ impl VTable for ConstantVTable {
}

fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
let constant = &array.scalar();
let proto_bytes: Vec<u8> = ScalarValue::to_proto_bytes(constant.value());
let scalar_value = (proto_bytes.len() <= CONSTANT_INLINE_THRESHOLD).then_some(proto_bytes);
Ok(ProstMetadata(ConstantMetadata { scalar_value }))
let constant = array.scalar();

// If the scalar is small enough, we can simply carry it around as metadata.
Ok((constant.nbytes() <= CONSTANT_INLINE_THRESHOLD).then_some(constant.clone()))
}

fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(metadata.serialize()))
// If we do not have a scalar to serialize, just return empty bytes.
Ok(Some(metadata.map_or_else(Vec::new, |c| {
// Note that we **only** serialize the optional scalar value (not including the dtype).
ScalarValue::to_proto_bytes(c.value())
})))
}

fn deserialize(
bytes: &[u8],
_dtype: &DType,
dtype: &DType,
_len: usize,
_session: &VortexSession,
) -> VortexResult<Self::Metadata> {
// Empty bytes indicates an old writer that didn't produce metadata.
if bytes.is_empty() {
return Ok(ProstMetadata(ConstantMetadata { scalar_value: None }));
return Ok(None);
}
let metadata = <Self::Metadata as DeserializeMetadata>::deserialize(bytes)?;
Ok(ProstMetadata(metadata))

// Otherwise, deserialize the constant scalar from the metadata.
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
Some(Scalar::try_new(dtype.clone(), scalar_value)).transpose()
}

fn build(
Expand All @@ -93,22 +102,22 @@ impl VTable for ConstantVTable {
_children: &dyn ArrayChildren,
) -> VortexResult<ConstantArray> {
// Prefer reading the scalar from inlined metadata to avoid device-to-host copies.
let scalar = if let Some(proto_bytes) = &metadata.scalar_value {
let scalar_value = ScalarValue::from_proto_bytes(proto_bytes, dtype)?;

Scalar::try_new(dtype.clone(), scalar_value)
} else {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
if let Some(constant) = metadata {
return Ok(ConstantArray::new(constant.clone(), len));
}

let buffer = buffers[0].clone().try_to_host_sync()?;
let bytes: &[u8] = buffer.as_ref();
// Otherwise, get the constant scalar from the buffers.
vortex_ensure!(
buffers.len() == 1,
"Expected 1 buffer, got {}",
buffers.len()
);

let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
let buffer = buffers[0].clone().try_to_host_sync()?;
let bytes: &[u8] = buffer.as_ref();

Scalar::try_new(dtype.clone(), scalar_value)
}?;
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;

Ok(ConstantArray::new(scalar, len))
}
Expand Down
6 changes: 4 additions & 2 deletions vortex-scalar/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,10 @@ impl Scalar {
)
}

/// Returns the size of the scalar in bytes, uncompressed.
#[cfg(test)]
/// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
///
/// Note that the protobuf serialization of scalars will likely have a different (but roughly
/// similar) length.
pub fn nbytes(&self) -> usize {
use vortex_dtype::NativeDecimalType;
use vortex_dtype::i256;
Expand Down
36 changes: 36 additions & 0 deletions vortex-scalar/src/tests/round_trip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
mod tests {
use std::sync::Arc;

use rstest::rstest;
use vortex_buffer::ByteBuffer;
use vortex_dtype::DType;
use vortex_dtype::DecimalDType;
Expand All @@ -21,6 +22,7 @@ mod tests {

use crate::DecimalValue;
use crate::Scalar;
use crate::ScalarValue;
use crate::tests::SESSION;

// Test that primitive scalars round-trip through ScalarValue
Expand Down Expand Up @@ -292,4 +294,38 @@ mod tests {
let bool_scalar = Scalar::bool(true, Nullability::NonNullable);
assert!(bool_scalar.as_decimal_opt().is_none());
}

/// Verifies that [`Scalar::nbytes`] matches the length of the proto-serialized scalar value.
#[rstest]
#[case::null_i32(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)))]
#[case::bool_true(Scalar::from(true))]
#[case::bool_false(Scalar::from(false))]
#[case::i8(Scalar::from(i8::MAX))]
#[case::i16(Scalar::from(i16::MAX))]
#[case::i32(Scalar::from(i32::MAX))]
#[case::i64(Scalar::from(i64::MAX))]
#[case::u8(Scalar::from(u8::MAX))]
#[case::u16(Scalar::from(u16::MAX))]
#[case::u32(Scalar::from(u32::MAX))]
#[case::u64(Scalar::from(u64::MAX))]
#[case::f32(Scalar::from(f32::MAX))]
#[case::f64(Scalar::from(f64::MAX))]
#[case::utf8_empty(Scalar::from(""))]
#[case::utf8_short(Scalar::from("hello"))]
#[case::utf8_long(Scalar::from("x".repeat(2048).as_str()))]
#[case::binary_empty(Scalar::binary(Vec::<u8>::new(), Nullability::NonNullable))]
#[case::binary_short(Scalar::binary(vec![1u8, 2, 3], Nullability::NonNullable))]
fn test_nbytes_approx_eq_to_proto_bytes(#[case] scalar: Scalar) {
let proto_bytes: Vec<u8> = ScalarValue::to_proto_bytes(scalar.value());
let diff = (scalar.nbytes() as isize - proto_bytes.len() as isize).abs();

// NOTE: THE 4 HERE IS COMPLETELY ARBITRARY!!!
assert!(
diff <= 4,
"nbytes() should be within 4 of proto-serialized length for {:?}, got {} vs {}",
scalar,
scalar.nbytes(),
proto_bytes.len(),
);
}
}
Loading