diff --git a/benchmarks/bench_execute_write_params.py b/benchmarks/bench_execute_write_params.py new file mode 100644 index 0000000000..cf02a438a0 --- /dev/null +++ b/benchmarks/bench_execute_write_params.py @@ -0,0 +1,538 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmark: ExecuteMessage._write_query_params() and send_body() for vector +INSERT workloads. + +Compares five approaches for the parameter serialization hot loop: + + 1. baseline – pre-optimization code (calling write_value() per param) + 2. pr788_inline – PR #788 inlining (local aliases, inline write_value) + 3. buf_accum – buffer accumulation (collect parts in list, single join) + 4. combined – inlining + buffer accumulation + 5. module_current – whatever the loaded module provides (.so or .py) + +Variants 1-4 are standalone pure-Python functions that call into +Cython-compiled helpers (write_value, write_string, etc.) when the .so is +loaded. Variant 5 calls the actual module method directly. 
+ +NOTE: To compare Cython vs pure-Python for variant 5, move the .so aside: + mv cassandra/protocol.cpython-*-linux-gnu.so{,.bak} + +Usage: + python benchmarks/bench_execute_write_params.py +""" + +import io +import struct +import time +import timeit +import sys +import os + +# Ensure the repo root is importable +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import cassandra.protocol +from cassandra.protocol import ( + ExecuteMessage, + _QueryMessage, + ProtocolHandler, + write_consistency_level, + write_byte, + write_uint, + write_short, + write_int, + write_long, + write_string, + write_value, + _UNSET_VALUE, + _VALUES_FLAG, + _WITH_SERIAL_CONSISTENCY_FLAG, + _PAGE_SIZE_FLAG, + _WITH_PAGING_STATE_FLAG, + _PROTOCOL_TIMESTAMP_FLAG, + _WITH_KEYSPACE_FLAG, +) +from cassandra import ProtocolVersion +from cassandra.marshal import int32_pack, uint16_pack, uint8_pack, uint64_pack + +# --------------------------------------------------------------------------- +# Pre-computed constants (as in PR #788) +# --------------------------------------------------------------------------- +_INT32_NEG1 = int32_pack(-1) # NULL marker +_INT32_NEG2 = int32_pack(-2) # UNSET marker + + +# =================================================================== +# Variant 1: baseline – mirrors the pre-optimization _write_query_params exactly +# =================================================================== + + +def baseline_write_query_params(msg, f, protocol_version): + write_consistency_level(f, msg.consistency_level) + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + 
if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + if msg.query_params is not None: + write_short(f, len(msg.query_params)) + for param in msg.query_params: + write_value(f, param) + if msg.fetch_size: + write_int(f, msg.fetch_size) + if msg.paging_state: + write_string(f, msg.paging_state) + if msg.serial_consistency_level: + write_consistency_level(f, msg.serial_consistency_level) + if msg.timestamp is not None: + write_long(f, msg.timestamp) + + +def baseline_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + baseline_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 2: pr788_inline – inline write_value with local aliases +# =================================================================== + + +def pr788_write_query_params(msg, f, protocol_version): + write_consistency_level(f, msg.consistency_level) + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + if msg.query_params is not None: + write_short(f, len(msg.query_params)) + _fw = f.write + _i32 = int32_pack + for param in msg.query_params: + if param is None: + _fw(_INT32_NEG1) + elif param is _UNSET_VALUE: + _fw(_INT32_NEG2) + else: + _fw(_i32(len(param))) + _fw(param) + if msg.fetch_size: + write_int(f, msg.fetch_size) + if 
msg.paging_state: + write_string(f, msg.paging_state) + if msg.serial_consistency_level: + write_consistency_level(f, msg.serial_consistency_level) + if msg.timestamp is not None: + write_long(f, msg.timestamp) + + +def pr788_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + pr788_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 3: buf_accum – collect all writes in a list, single join +# =================================================================== + + +def bufaccum_write_query_params(msg, f, protocol_version): + parts = [] + _p = parts.append + _i32 = int32_pack + _u16 = uint16_pack + _u8 = uint8_pack + _u64 = uint64_pack + + _p(_u16(msg.consistency_level)) + + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + + if ProtocolVersion.uses_int_query_flags(protocol_version): + from cassandra.marshal import uint32_pack + + _p(uint32_pack(flags)) + else: + _p(_u8(flags)) + + if msg.query_params is not None: + _p(_u16(len(msg.query_params))) + for param in msg.query_params: + if param is None: + _p(_INT32_NEG1) + elif param is _UNSET_VALUE: + _p(_INT32_NEG2) + else: + _p(_i32(len(param))) + _p(param) + + if msg.fetch_size: + _p(_i32(msg.fetch_size)) + if msg.paging_state: + ps = msg.paging_state + if isinstance(ps, str): + ps = ps.encode("utf8") + _p(_u16(len(ps))) + _p(ps) + if msg.serial_consistency_level: + _p(_u16(msg.serial_consistency_level)) + if 
msg.timestamp is not None: + _p(_u64(msg.timestamp)) + + f.write(b"".join(parts)) + + +def bufaccum_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + bufaccum_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Variant 4: combined – inline write_value + buffer accumulation +# (single len+data concat per param, then single join) +# =================================================================== + + +def combined_write_query_params(msg, f, protocol_version): + parts = [] + _p = parts.append + _i32 = int32_pack + _u16 = uint16_pack + _u8 = uint8_pack + _u64 = uint64_pack + + _p(_u16(msg.consistency_level)) + + flags = 0x00 + if msg.query_params is not None: + flags |= _VALUES_FLAG + if msg.serial_consistency_level: + flags |= _WITH_SERIAL_CONSISTENCY_FLAG + if msg.fetch_size: + flags |= _PAGE_SIZE_FLAG + if msg.paging_state: + flags |= _WITH_PAGING_STATE_FLAG + if msg.timestamp is not None: + flags |= _PROTOCOL_TIMESTAMP_FLAG + if getattr(msg, "keyspace", None) is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + + if ProtocolVersion.uses_int_query_flags(protocol_version): + from cassandra.marshal import uint32_pack + + _p(uint32_pack(flags)) + else: + _p(_u8(flags)) + + if msg.query_params is not None: + _p(_u16(len(msg.query_params))) + for param in msg.query_params: + if param is None: + _p(_INT32_NEG1) + elif param is _UNSET_VALUE: + _p(_INT32_NEG2) + else: + _p(_i32(len(param)) + param) # single concat per param + + if msg.fetch_size: + _p(_i32(msg.fetch_size)) + if msg.paging_state: + ps = msg.paging_state + if isinstance(ps, str): + ps = ps.encode("utf8") + _p(_u16(len(ps))) + _p(ps) + if msg.serial_consistency_level: + _p(_u16(msg.serial_consistency_level)) + if msg.timestamp is not None: + 
_p(_u64(msg.timestamp)) + + f.write(b"".join(parts)) + + +def combined_send_body(msg, f, protocol_version): + write_string(f, msg.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, msg.result_metadata_id) + combined_write_query_params(msg, f, protocol_version) + + +# =================================================================== +# Test scenarios +# =================================================================== + + +def make_vector_params(dim): + """Simulate a prepared INSERT with (int32_key, float_vector) params. + + Returns a list of pre-serialized bytes, as BoundStatement.bind() would + produce *after* calling col_type.serialize() on each value. + """ + int_key = int32_pack(42) # 4 bytes + vector_bytes = struct.pack(f">{dim}f", *([0.1] * dim)) # dim * 4 bytes + return [int_key, vector_bytes] + + +def make_scalar_params(n, size=20): + """Simulate n text columns of `size` bytes each.""" + return [b"\x41" * size for _ in range(n)] + + +PROTO_VERSION = 4 +ITERATIONS = 500_000 +REPEATS = 5 + +SCENARIOS = [ + ("128D vector INSERT (2 params)", make_vector_params(128)), + ("768D vector INSERT (2 params)", make_vector_params(768)), + ("1536D vector INSERT (2 params)", make_vector_params(1536)), + ("scalar 10 text cols (10 params)", make_scalar_params(10, 20)), +] + +# _write_query_params variants (the core hot path) +WQP_VARIANTS = [ + ("1_baseline", baseline_write_query_params), + ("2_pr788_inline", pr788_write_query_params), + ("3_buf_accum", bufaccum_write_query_params), + ("4_combined", combined_write_query_params), +] + +# send_body variants (includes query_id framing) +SB_VARIANTS = [ + ("1_baseline", baseline_send_body), + ("2_pr788_inline", pr788_send_body), + ("3_buf_accum", bufaccum_send_body), + ("4_combined", combined_send_body), +] + + +# =================================================================== +# Benchmark helpers +# =================================================================== + + +def 
verify_output(ref_fn, test_fn, msg, pv): + """Verify two functions produce byte-identical output.""" + f1 = io.BytesIO() + ref_fn(msg, f1, pv) + ref_bytes = f1.getvalue() + + f2 = io.BytesIO() + test_fn(msg, f2, pv) + test_bytes = f2.getvalue() + + if ref_bytes != test_bytes: + for i, (a, b) in enumerate(zip(ref_bytes, test_bytes)): + if a != b: + return False, f"diff at byte {i}: ref=0x{a:02x}, test=0x{b:02x}" + if len(ref_bytes) != len(test_bytes): + return False, f"len diff: ref={len(ref_bytes)}, test={len(test_bytes)}" + return True, "" + + +def bench_fn(fn, msg, pv, iterations, repeats): + """Benchmark a single function, return best ns/call.""" + f = io.BytesIO() + + def run(): + f.seek(0) + f.truncate() + fn(msg, f, pv) + + t = timeit.repeat(run, number=iterations, repeat=repeats, timer=time.process_time) + return min(t) / iterations * 1e9 + + +def make_execute_msg(params): + """Create a realistic ExecuteMessage for a prepared INSERT.""" + return ExecuteMessage( + query_id=b"\x01\x02\x03\x04\x05\x06\x07\x08", # 8-byte prepared query ID + query_params=params, + consistency_level=1, # ONE + timestamp=1234567890123456, # typical microsecond timestamp + # No serial CL, no fetch_size, no paging — typical INSERT + ) + + +# =================================================================== +# Main +# =================================================================== + + +def main(): + is_cython = cassandra.protocol.__file__.endswith(".so") + print(f"Python: {sys.version.split()[0]}") + print(f"Module: {cassandra.protocol.__file__}") + print(f"Cython: {'YES (.so loaded)' if is_cython else 'NO (pure Python .py)'}") + print(f"Config: proto v{PROTO_VERSION}, {ITERATIONS:,} iters, best of {REPEATS}") + print() + print("NOTE: Variants 1-4 are standalone pure-Python functions.") + print( + " They call Cython-compiled helpers (write_value, etc.) when .so is loaded." 
+ ) + print(" 'module' calls the actual loaded module method directly.") + print() + + # Grab the base-class _write_query_params to bypass ExecuteMessage's + # super() overhead — gives a fair comparison with standalone functions. + _module_wqp = _QueryMessage._write_query_params + + for scenario_label, params in SCENARIOS: + msg = make_execute_msg(params) + total_param_bytes = sum(len(p) for p in params) + print(f"=== {scenario_label} (payload: {total_param_bytes:,} bytes) ===") + print() + + # ---- _write_query_params benchmarks ---- + print(" _write_query_params() [core hot path]:") + print(f" {'variant':20s} {'ns/call':>10s} {'vs baseline':>11s}") + print(f" {'-------':20s} {'-------':>10s} {'-----------':>11s}") + + baseline_wqp_ns = None + for var_label, var_fn in WQP_VARIANTS: + ok, err = verify_output( + baseline_write_query_params, var_fn, msg, PROTO_VERSION + ) + if not ok: + print(f" {var_label:20s} MISMATCH: {err}") + continue + ns = bench_fn(var_fn, msg, PROTO_VERSION, ITERATIONS, REPEATS) + if baseline_wqp_ns is None: + baseline_wqp_ns = ns + speedup = baseline_wqp_ns / ns + print(f" {var_label:20s} {ns:8.1f} {speedup:5.2f}x") + + # Module variant (bypass super() for fair comparison) + def module_wqp(m, f, pv): + _module_wqp(m, f, pv) + + ok, err = verify_output( + baseline_write_query_params, module_wqp, msg, PROTO_VERSION + ) + if ok: + ns = bench_fn(module_wqp, msg, PROTO_VERSION, ITERATIONS, REPEATS) + speedup = baseline_wqp_ns / ns if baseline_wqp_ns else 0 + label = "5_module" + (" (cython)" if is_cython else " (py)") + print(f" {label:20s} {ns:8.1f} {speedup:5.2f}x") + else: + print(f" 5_module MISMATCH: {err}") + + print() + + # ---- send_body benchmarks ---- + print(" send_body() [includes query_id framing]:") + print(f" {'variant':20s} {'ns/call':>10s} {'vs baseline':>11s}") + print(f" {'-------':20s} {'-------':>10s} {'-----------':>11s}") + + baseline_sb_ns = None + for var_label, var_fn in SB_VARIANTS: + ok, err = 
verify_output(baseline_send_body, var_fn, msg, PROTO_VERSION) + if not ok: + print(f" {var_label:20s} MISMATCH: {err}") + continue + ns = bench_fn(var_fn, msg, PROTO_VERSION, ITERATIONS, REPEATS) + if baseline_sb_ns is None: + baseline_sb_ns = ns + speedup = baseline_sb_ns / ns + print(f" {var_label:20s} {ns:8.1f} {speedup:5.2f}x") + + # Module send_body (direct method call, no lambda) + def module_sb(m, f, pv): + m.send_body(f, pv) + + ok, err = verify_output(baseline_send_body, module_sb, msg, PROTO_VERSION) + if ok: + ns = bench_fn(module_sb, msg, PROTO_VERSION, ITERATIONS, REPEATS) + speedup = baseline_sb_ns / ns if baseline_sb_ns else 0 + label = "5_module" + (" (cython)" if is_cython else " (py)") + print(f" {label:20s} {ns:8.1f} {speedup:5.2f}x") + else: + print(f" 5_module MISMATCH: {err}") + + print() + + # ---- encode_message benchmark (full wire frame) ---- + print(" encode_message() [full wire frame]:") + + def run_encode(): + return ProtocolHandler.encode_message( + msg, + stream_id=1, + protocol_version=PROTO_VERSION, + compressor=None, + allow_beta_protocol_version=False, + ) + + ref_frame = run_encode() + t = timeit.repeat( + run_encode, number=ITERATIONS, repeat=REPEATS, timer=time.process_time + ) + enc_ns = min(t) / ITERATIONS * 1e9 + print(f" {'current':20s} {enc_ns:8.1f} (frame: {len(ref_frame)} bytes)") + print() + print() + + +if __name__ == "__main__": + main() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..ab27c89ead 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -587,9 +587,20 @@ def _write_query_params(self, f, protocol_version): write_byte(f, flags) if self.query_params is not None: - write_short(f, len(self.query_params)) + # Accumulate param bytes in a list and write once instead of + # 2*N+1 separate f.write() calls via write_value(). 
+ _int32_pack = int32_pack + parts = [uint16_pack(len(self.query_params))] + _parts_append = parts.append for param in self.query_params: - write_value(f, param) + if param is None: + _parts_append(_int32_pack(-1)) + elif param is _UNSET_VALUE: + _parts_append(_int32_pack(-2)) + else: + _parts_append(_int32_pack(len(param))) + _parts_append(param) + f.write(b"".join(parts)) if self.fetch_size: write_int(f, self.fetch_size) if self.paging_state: @@ -635,8 +646,8 @@ def __init__(self, query_id, query_params, consistency_level, super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, skip_meta, continuous_paging_options) - def _write_query_params(self, f, protocol_version): - super(ExecuteMessage, self)._write_query_params(f, protocol_version) + # _write_query_params inherited from _QueryMessage; removed redundant + # pass-through override to avoid extra MRO lookup per call. def send_body(self, f, protocol_version): write_string(f, self.query_id) diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..b165153b3b 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -15,16 +15,19 @@ import unittest from unittest.mock import Mock +import io +import struct from cassandra import ProtocolVersion, UnsupportedOperation from cassandra.protocol import ( PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, + _UNSET_VALUE, write_value, ProtocolHandler ) from cassandra.query import BatchType -from cassandra.marshal import uint32_unpack +from cassandra.marshal import uint32_unpack, int32_pack, uint16_pack from cassandra.cluster import ContinuousPagingOptions import pytest @@ -189,3 +192,184 @@ def test_batch_message_with_keyspace(self): (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) ) + +class 
WriteQueryParamsBufferAccumulationTest(unittest.TestCase): + """ + Tests for the buffer accumulation optimization in + _QueryMessage._write_query_params(). + + The optimization replaces per-parameter write_value(f, param) calls with + list.append + b"".join + single f.write(). These tests verify the + serialized bytes are identical to the original write_value() behaviour. + """ + + # -- helpers ---------------------------------------------------------- + + @staticmethod + def _reference_write_value_bytes(params): + """Build expected bytes using the original write_value() function.""" + buf = io.BytesIO() + buf.write(uint16_pack(len(params))) + for p in params: + write_value(buf, p) + return buf.getvalue() + + @staticmethod + def _execute_msg_bytes(msg, protocol_version): + """Serialize an ExecuteMessage and return the raw bytes.""" + buf = io.BytesIO() + msg.send_body(buf, protocol_version) + return buf.getvalue() + + # -- basic write_value parity ----------------------------------------- + + def test_normal_params(self): + """Normal (non-NULL, non-UNSET) byte-string parameters.""" + params = [b'hello', b'world', b'\x00\x01\x02'] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_null_params(self): + """NULL parameters must serialize as int32(-1).""" + params = [None, None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_unset_params(self): + """UNSET parameters must serialize as int32(-2).""" + params = [_UNSET_VALUE, _UNSET_VALUE] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=4) + raw = 
self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_mixed_params(self): + """Mix of normal, NULL and UNSET params in one message.""" + params = [b'data', None, _UNSET_VALUE, b'more', None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_empty_bytes_param(self): + """An empty bytes value (length 0) must differ from NULL (length -1).""" + params = [b''] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + # Verify it's NOT serialized as NULL + null_bytes = int32_pack(-1) + param_section_start = raw.find(expected) + param_section = raw[param_section_start:param_section_start + len(expected)] + self.assertNotIn(null_bytes, param_section[2:]) # skip the uint16 count + + def test_empty_query_params_list(self): + """An empty params list should write count=0 and nothing else.""" + params = [] + expected = self._reference_write_value_bytes(params) + self.assertEqual(expected, uint16_pack(0)) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_none_query_params(self): + """When query_params is None, no param block should be written.""" + msg1 = ExecuteMessage(query_id=b'qid', query_params=None, + consistency_level=1) + msg2 = ExecuteMessage(query_id=b'qid', query_params=[b'x'], + consistency_level=1) + raw1 = self._execute_msg_bytes(msg1, protocol_version=4) + raw2 = self._execute_msg_bytes(msg2, protocol_version=4) + # raw1 should be shorter (no param section) + self.assertLess(len(raw1), len(raw2)) + + def test_large_vector_param(self): + 
"""Large parameter simulating a high-dimensional vector embedding.""" + # 768-dimensional float32 vector = 3072 bytes + vector_bytes = struct.pack('768f', *([0.123456] * 768)) + params = [vector_bytes] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_query_message_with_params(self): + """QueryMessage (not just ExecuteMessage) uses the same code path.""" + params = [b'val1', None, b'val2'] + expected = self._reference_write_value_bytes(params) + msg = QueryMessage(query='SELECT * FROM t WHERE k=? AND v=? AND w=?', + consistency_level=1, + query_params=params) + raw = io.BytesIO() + msg.send_body(raw, protocol_version=4) + self.assertIn(expected, raw.getvalue()) + + def test_proto_v3_vs_v4_params(self): + """The param encoding should be identical across protocol versions.""" + params = [b'abc', None, b'xyz'] + msg_v3 = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + msg_v4 = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw_v3 = self._execute_msg_bytes(msg_v3, protocol_version=3) + raw_v4 = self._execute_msg_bytes(msg_v4, protocol_version=4) + expected = self._reference_write_value_bytes(params) + self.assertIn(expected, raw_v3) + self.assertIn(expected, raw_v4) + + def test_encode_message_roundtrip(self): + """Full encode_message path exercises header + body framing.""" + params = [b'roundtrip'] + msg = QueryMessage(query='SELECT 1', + consistency_level=1, + query_params=params) + # encode_message returns the full on-wire frame + frame = ProtocolHandler.encode_message(msg, stream_id=1, + protocol_version=4, + compressor=None, + allow_beta_protocol_version=False) + # The frame should contain the param bytes somewhere inside + expected_param_bytes = self._reference_write_value_bytes(params) + # frame may be 
memoryview/bytearray; convert to bytes for assertIn + frame_bytes = bytes(frame) + self.assertIn(expected_param_bytes, frame_bytes) + + def test_many_params(self): + """50 parameters to exercise the accumulation loop at scale.""" + params = [b'param_%03d' % i for i in range(50)] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_single_null_param(self): + """Regression: a single NULL param should serialize correctly.""" + params = [None] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=1) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) + + def test_single_unset_param(self): + """Regression: a single UNSET param should serialize correctly.""" + params = [_UNSET_VALUE] + expected = self._reference_write_value_bytes(params) + msg = ExecuteMessage(query_id=b'qid', query_params=params, + consistency_level=4) + raw = self._execute_msg_bytes(msg, protocol_version=4) + self.assertIn(expected, raw) +