diff --git a/benchmarks/micro/bench_col_names_cache.py b/benchmarks/micro/bench_col_names_cache.py new file mode 100644 index 0000000000..0ce6ce0edd --- /dev/null +++ b/benchmarks/micro/bench_col_names_cache.py @@ -0,0 +1,66 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: column_names / column_types extraction from metadata. + +Measures the cost of building [c[2] for c in metadata] and [c[3] for c in metadata] +vs using pre-cached lists (as done for prepared statements with result_metadata). + +Run: + python benchmarks/bench_col_names_cache.py +""" + +import sys +import timeit + + +def make_column_metadata(ncols): + """Create fake column_metadata tuples like recv_results_metadata produces.""" + class FakeType: + pass + return [(f"ks_{i}", f"tbl_{i}", f"col_{i}", FakeType) for i in range(ncols)] + + +def bench(): + for ncols in (5, 10, 20, 50): + metadata = make_column_metadata(ncols) + + # Pre-cached (done once at prepare time) + cached_names = [c[2] for c in metadata] + cached_types = [c[3] for c in metadata] + + def extract_uncached(): + names = [c[2] for c in metadata] + types = [c[3] for c in metadata] + return names, types + + def extract_cached(): + return cached_names, cached_types + + n = 500_000 + t_uncached = timeit.timeit(extract_uncached, number=n) + t_cached = timeit.timeit(extract_cached, number=n) + + saving_ns = (t_uncached - t_cached) / n * 1e9 + speedup = t_uncached / t_cached if t_cached > 0 else float('inf') + print(f" {ncols} cols: uncached={t_uncached / n * 1e9:.1f} ns, " + f"cached={t_cached / n * 1e9:.1f} ns, " + f"saving={saving_ns:.1f} ns ({speedup:.1f}x)") + + +if __name__ == "__main__": + print(f"Python {sys.version}") + print("\n=== column_names / column_types extraction ===") + bench() diff --git a/benchmarks/micro/bench_result_kind_dispatch.py b/benchmarks/micro/bench_result_kind_dispatch.py new file mode 100644 index 0000000000..4ff6b790f6 --- /dev/null +++ b/benchmarks/micro/bench_result_kind_dispatch.py @@ -0,0 +1,115 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: RESULT_KIND dispatch ordering and getattr vs direct access. + +Measures the cost difference between: +1. Checking RESULT_KIND_ROWS first vs third in the if/elif chain +2. getattr(msg, 'continuous_paging_options', None) vs msg.continuous_paging_options + +Run: + python benchmarks/bench_result_kind_dispatch.py +""" + +import sys +import timeit + + +def bench(): + n = 2_000_000 + + # Simulate the result kind values + RESULT_KIND_SET_KEYSPACE = 0x0003 + RESULT_KIND_SCHEMA_CHANGE = 0x0005 + RESULT_KIND_ROWS = 0x0002 + RESULT_KIND_VOID = 0x0001 + + kind = RESULT_KIND_ROWS # the common case + + # Old order: SET_KEYSPACE, SCHEMA_CHANGE, ROWS, VOID + def old_dispatch(): + if kind == RESULT_KIND_SET_KEYSPACE: + return 'set_keyspace' + elif kind == RESULT_KIND_SCHEMA_CHANGE: + return 'schema_change' + elif kind == RESULT_KIND_ROWS: + return 'rows' + elif kind == RESULT_KIND_VOID: + return 'void' + + # New order: ROWS, VOID, SET_KEYSPACE, SCHEMA_CHANGE + def new_dispatch(): + if kind == RESULT_KIND_ROWS: + return 'rows' + elif kind == RESULT_KIND_VOID: + return 'void' + elif kind == RESULT_KIND_SET_KEYSPACE: + return 'set_keyspace' + elif kind == RESULT_KIND_SCHEMA_CHANGE: + return 'schema_change' + + print(f"=== RESULT_KIND dispatch order ({n:,} iters) ===\n") + + # Warmup + for _ in range(10000): + old_dispatch() + new_dispatch() + + t_old = timeit.timeit(old_dispatch, number=n) + t_new = timeit.timeit(new_dispatch, number=n) + ns_old = t_old / n * 1e9 + ns_new = t_new / n * 1e9 + saving = ns_old - ns_new + speedup = ns_old / ns_new if ns_new > 0 else float('inf') + print(f" Old (ROWS=3rd): {ns_old:.1f} ns") + print(f" New (ROWS=1st): {ns_new:.1f} ns") + print(f" Saving: {saving:.1f} ns ({speedup:.2f}x)") + + # getattr vs direct attribute access + print(f"\n=== getattr vs direct attribute access ({n:,} iters) ===\n") + + class OldMsg: + pass + + class NewMsg: + continuous_paging_options = None + + old_msg = OldMsg() + new_msg = NewMsg() + + def old_getattr(): + return getattr(old_msg, 'continuous_paging_options', None) + + def new_direct(): + return new_msg.continuous_paging_options + + for _ in range(10000): + old_getattr() + new_direct() + + t_old = timeit.timeit(old_getattr, number=n) + t_new = timeit.timeit(new_direct, number=n) + ns_old = t_old / n * 1e9 + ns_new = t_new / n * 1e9 + saving = ns_old - ns_new + speedup = ns_old / ns_new if ns_new > 0 else float('inf') + print(f" getattr(msg, 'continuous_paging_options', None): {ns_old:.1f} ns") + print(f" msg.continuous_paging_options: {ns_new:.1f} ns") + print(f" Saving: {saving:.1f} ns ({speedup:.2f}x)") + + +if __name__ == "__main__": + print(f"Python {sys.version}\n") + bench() diff --git a/benchmarks/micro/bench_session_cluster_cache.py b/benchmarks/micro/bench_session_cluster_cache.py new file mode 100644 index 0000000000..121341f8f4 --- /dev/null +++ b/benchmarks/micro/bench_session_cluster_cache.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +""" +Benchmark: caching self.session.cluster as a local variable. + +Measures the cost of repeated self.session.cluster double-lookups +vs. a single local assignment. +""" + +import sys +import time + +ITERS = 5_000_000 + + +class FakeCluster: + class control_connection: + _tablets_routing_v1 = True + protocol_version = 5 + class metadata: + class _tablets: + @staticmethod + def add_tablet(ks, tbl, tablet): + pass + + +class FakeSession: + cluster = FakeCluster() + + +class FakeResponseFuture: + def __init__(self): + self.session = FakeSession() + + +def bench_double_lookup(rf, n): + """Simulates 3 accesses to self.session.cluster (tablet routing block).""" + t0 = time.perf_counter_ns() + for _ in range(n): + _ = rf.session.cluster.control_connection + _ = rf.session.cluster.protocol_version + _ = rf.session.cluster.metadata + return (time.perf_counter_ns() - t0) / n + + +def bench_cached_local(rf, n): + """Simulates caching session.cluster in a local.""" + t0 = time.perf_counter_ns() + for _ in range(n): + cluster = rf.session.cluster + _ = cluster.control_connection + _ = cluster.protocol_version + _ = cluster.metadata + return (time.perf_counter_ns() - t0) / n + + +def main(): + print(f"Python {sys.version}\n") + rf = FakeResponseFuture() + + ns_old = bench_double_lookup(rf, ITERS) + ns_new = bench_cached_local(rf, ITERS) + saving = ns_old - ns_new + speedup = ns_old / ns_new if ns_new else float('inf') + + print(f"=== self.session.cluster caching ({ITERS:,} iters) ===\n") + print(f" 3x self.session.cluster (old): {ns_old:.1f} ns") + print(f" 1x local + 3x local (new): {ns_new:.1f} ns") + print(f" Saving: {saving:.1f} ns ({speedup:.2f}x)") + + +if __name__ == "__main__": + main() diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..af0cedbdcf 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4725,34 +4725,56 @@ def _set_result(self, host, connection, pool, response): if pool and not pool.is_shutdown: pool.return_connection(connection) - trace_id = getattr(response, 'trace_id', None) - if trace_id: - if not self._query_traces: - self._query_traces = [] - self._query_traces.append(QueryTrace(trace_id, self.session)) - - self._warnings = getattr(response, 'warnings', None) - self._custom_payload = getattr(response, 'custom_payload', None) - - if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload: - protocol = self.session.cluster.protocol_version - info = self._custom_payload.get('tablets-routing-v1') - ctype = ResponseFuture._TABLET_ROUTING_CTYPE - if ctype is None: - ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') - ResponseFuture._TABLET_ROUTING_CTYPE = ctype - tablet_routing_info = ctype.from_binary(info, protocol) - first_token = tablet_routing_info[0] - last_token = tablet_routing_info[1] - tablet_replicas = tablet_routing_info[2] - tablet = Tablet.from_row(first_token, last_token, tablet_replicas) - keyspace = self.query.keyspace - table = self.query.table - self.session.cluster.metadata._tablets.add_tablet(keyspace, table, tablet) - if isinstance(response, ResultMessage): - if response.kind == RESULT_KIND_SET_KEYSPACE: - session = getattr(self, 'session', None) + # Hot path: ResultMessage has trace_id, warnings, and + # custom_payload in __slots__, always initialised in __init__, + # so direct attribute access is safe and faster than getattr(). + trace_id = response.trace_id + session = self.session + if trace_id: + if not self._query_traces: + self._query_traces = [] + self._query_traces.append(QueryTrace(trace_id, session)) + + self._warnings = response.warnings + custom_payload = response.custom_payload + self._custom_payload = custom_payload + + # Cache session.cluster to avoid repeated double-lookup in the + # tablet routing block (3 accesses) and schema-change path. + cluster = session.cluster + if custom_payload and cluster.control_connection._tablets_routing_v1: + info = custom_payload.get('tablets-routing-v1') + if info is not None: + ctype = ResponseFuture._TABLET_ROUTING_CTYPE + if ctype is None: + ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') + ResponseFuture._TABLET_ROUTING_CTYPE = ctype + first_token, last_token, tablet_replicas = ctype.from_binary(info, cluster.protocol_version) + tablet = Tablet.from_row(first_token, last_token, tablet_replicas) + if tablet is not None: + cluster.metadata._tablets.add_tablet(self.query.keyspace, self.query.table, tablet) + + if response.kind == RESULT_KIND_ROWS: + self._paging_state = response.paging_state + # Use pre-cached column names/types from PreparedStatement + # when available to avoid rebuilding lists from metadata. + ps = self.prepared_statement + if ps is not None and ps._result_col_names is not None: + col_names = ps._result_col_names + col_types = ps._result_col_types + else: + col_names = response.column_names + col_types = response.column_types + self._col_names = col_names + self._col_types = col_types + if self.message.continuous_paging_options: + self._handle_continuous_paging_first_response(connection, response) + else: + self._set_final_result(self.row_factory(col_names, response.parsed_rows)) + elif response.kind == RESULT_KIND_VOID: + self._set_final_result(None) + elif response.kind == RESULT_KIND_SET_KEYSPACE: # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event @@ -4767,23 +4789,25 @@ def _set_result(self, host, connection, pool, response): # refresh the schema before responding, but do it in another # thread instead of the event loop thread self.is_schema_agreed = False - self.session.submit( + session.submit( refresh_schema_and_set_result, - self.session.cluster.control_connection, + cluster.control_connection, self, connection, **response.schema_change_event) - elif response.kind == RESULT_KIND_ROWS: - self._paging_state = response.paging_state - self._col_names = response.column_names - self._col_types = response.column_types - if getattr(self.message, 'continuous_paging_options', None): - self._handle_continuous_paging_first_response(connection, response) - else: - self._set_final_result(self.row_factory(response.column_names, response.parsed_rows)) - elif response.kind == RESULT_KIND_VOID: - self._set_final_result(None) else: self._set_final_result(response) elif isinstance(response, ErrorMessage): + # Cold path: ErrorMessage inherits from _MessageType which + # defines warnings/custom_payload as class-level defaults but + # does NOT have trace_id -- getattr is required here. + trace_id = getattr(response, 'trace_id', None) + if trace_id: + if not self._query_traces: + self._query_traces = [] + self._query_traces.append(QueryTrace(trace_id, self.session)) + + self._warnings = getattr(response, 'warnings', None) + self._custom_payload = getattr(response, 'custom_payload', None) + retry_policy = self._retry_policy if isinstance(response, ReadTimeoutErrorMessage): @@ -4862,6 +4886,10 @@ def _set_result(self, host, connection, pool, response): self._handle_retry_decision(retry, response, host) elif isinstance(response, ConnectionException): + # ConnectionException has no trace_id/warnings/custom_payload; + # clear any stale values from a previous retry attempt. + self._warnings = None + self._custom_payload = None if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): @@ -4871,6 +4899,8 @@ def _set_result(self, host, connection, pool, response): self.query, cl, error=response, retry_num=self._query_retries) self._handle_retry_decision(retry, response, host) elif isinstance(response, Exception): + self._warnings = None + self._custom_payload = None if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) else: @@ -4926,6 +4956,7 @@ def _execute_after_prepare(self, host, connection, pool, response): ) )) self.prepared_statement.result_metadata = response.column_metadata + self.prepared_statement._cache_result_metadata_columns(response.column_metadata) new_metadata_id = response.result_metadata_id if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id @@ -5076,9 +5107,12 @@ def result(self): ... log.exception("Operation failed:") """ + return ResultSet(self, self._wait_for_result()) + + def _wait_for_result(self): self._event.wait() if self._final_result is not _NOT_SET: - return ResultSet(self, self._final_result) + return self._final_result else: raise self._final_exception @@ -5264,6 +5298,9 @@ class ResultSet(object): like you might see on a normal call to ``session.execute()``. """ + __slots__ = ('response_future', 'column_names', 'column_types', + '_current_rows', '_page_iter', '_list_mode') + def __init__(self, response_future, initial_response): self.response_future = response_future self.column_names = response_future._col_names @@ -5349,8 +5386,7 @@ def fetch_next_page(self): """ if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() - result = self.response_future.result() - self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form + self._set_current_rows(self.response_future._wait_for_result()) else: self._current_rows = [] diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4628c7ee0e..b063f87baf 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -86,6 +86,8 @@ def __init__(cls, name, bases, dct): class _MessageType(object, metaclass=_RegisterMessageType): + __slots__ = () + tracing = False custom_payload = None warnings = None @@ -105,7 +107,7 @@ def __repr__(self): def _get_params(message_obj): base_attrs = dir(_MessageType) return ( - (n, a) for n, a in message_obj.__dict__.items() + (n, a) for n, a in getattr(message_obj, '__dict__', {}).items() if n not in base_attrs and not n.startswith('_') and not callable(a) ) @@ -542,6 +544,11 @@ def recv_body(cls, f, *args): class _QueryMessage(_MessageType): + # DSE continuous paging: stored when the feature is active, otherwise None. + # Declared as a class attribute so that callers can use direct attribute + # access instead of getattr(msg, 'continuous_paging_options', None). + continuous_paging_options = None + def __init__(self, query_params, consistency_level, serial_consistency_level=None, fetch_size=None, paging_state=None, timestamp=None, skip_meta=False, @@ -658,9 +665,12 @@ class ResultMessage(_MessageType): opcode = 0x08 name = 'RESULT' - kind = None - results = None - paging_state = None + __slots__ = ('kind', 'column_names', 'column_types', 'parsed_rows', + 'paging_state', 'continuous_paging_seq', 'continuous_paging_last', + 'new_keyspace', 'column_metadata', 'query_id', 'bind_metadata', + 'pk_indexes', 'schema_change_event', 'is_lwt', + 'result_metadata_id', 'stream_id', 'trace_id', + 'custom_payload', 'warnings') # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE) type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_')) @@ -672,25 +682,26 @@ class ResultMessage(_MessageType): _CONTINUOUS_PAGING_LAST_FLAG = 0x80000000 _METADATA_ID_FLAG = 0x0008 - kind = None - - # These are all the things a result message might contain. They are populated according to 'kind' - column_names = None - column_types = None - parsed_rows = None - paging_state = None - continuous_paging_seq = None - continuous_paging_last = None - new_keyspace = None - column_metadata = None - query_id = None - bind_metadata = None - pk_indexes = None - schema_change_event = None - is_lwt = False - def __init__(self, kind): self.kind = kind + self.column_names = None + self.column_types = None + self.parsed_rows = None + self.paging_state = None + self.continuous_paging_seq = None + self.continuous_paging_last = None + self.new_keyspace = None + self.column_metadata = None + self.query_id = None + self.bind_metadata = None + self.pk_indexes = None + self.schema_change_event = None + self.is_lwt = False + self.result_metadata_id = None + self.stream_id = None + self.trace_id = None + self.custom_payload = None + self.warnings = None def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): if self.kind == RESULT_KIND_VOID: @@ -901,6 +912,11 @@ class BatchMessage(_MessageType): opcode = 0x0D name = 'BATCH' + # Batch messages never use continuous paging, but callers access this + # attribute directly (instead of getattr) for speed. Declare it here so + # that BatchMessage matches the same interface as _QueryMessage. + continuous_paging_options = None + def __init__(self, batch_type, queries, consistency_level, serial_consistency_level=None, timestamp=None, keyspace=None): @@ -1182,11 +1198,14 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre msg_class = cls.message_types_by_opcode[opcode] msg = msg_class.recv_body(body, protocol_version, protocol_features, user_type_map, result_metadata, cls.column_encryption_policy) msg.stream_id = stream_id - msg.trace_id = trace_id - msg.custom_payload = custom_payload - msg.warnings = warnings - - if msg.warnings: + if trace_id is not None: + msg.trace_id = trace_id + if custom_payload is not None: + msg.custom_payload = custom_payload + if warnings is not None: + msg.warnings = warnings + + if warnings: for w in msg.warnings: log.warning("Server warning: %s", w) @@ -1218,6 +1237,7 @@ class FastResultMessage(ResultMessage): Cython version of Result Message that has a faster implementation of recv_results_row. """ + __slots__ = () # type_codes = ResultMessage.type_codes.copy() code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) recv_results_rows = make_recv_results_rows(colparser) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..cfb8e9a2e8 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -459,6 +459,11 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False + # Cached column names/types derived from result_metadata, to avoid + # rebuilding [c[2] for c in result_metadata] on every result set. + _result_col_names = None + _result_col_types = None + def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, result_metadata, result_metadata_id, is_lwt=False, column_encryption_policy=None): @@ -469,11 +474,21 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self.keyspace = keyspace self.protocol_version = protocol_version self.result_metadata = result_metadata + self._cache_result_metadata_columns(result_metadata) self.result_metadata_id = result_metadata_id self.column_encryption_policy = column_encryption_policy self.is_idempotent = False self._is_lwt = is_lwt + def _cache_result_metadata_columns(self, result_metadata): + """Pre-compute column names and types from result_metadata.""" + if result_metadata: + self._result_col_names = [c[2] for c in result_metadata] + self._result_col_types = [c[3] for c in result_metadata] + else: + self._result_col_names = None + self._result_col_types = None + @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata, diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 7168ad2940..461e5ebdef 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -61,7 +61,8 @@ def make_response_future(self, session): return ResponseFuture(session, message, query, 1) def make_mock_response(self, col_names, rows): - return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, column_names=col_names, parsed_rows=rows, paging_state=None, col_types=None) + return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, column_names=col_names, parsed_rows=rows, + paging_state=None, col_types=None, trace_id=None, warnings=None, custom_payload=None) def test_result_message(self): session = self.make_basic_session() @@ -104,7 +105,8 @@ def test_set_keyspace_result(self): result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, - results="keyspace1") + results="keyspace1", + trace_id=None, warnings=None, custom_payload=None) rf._set_result(None, None, None, result) rf._set_keyspace_completed({}) assert not rf.result() @@ -118,7 +120,8 @@ def test_schema_change_result(self): 'keyspace': "keyspace1", "table": "table1"} result = Mock(spec=ResultMessage, kind=RESULT_KIND_SCHEMA_CHANGE, - schema_change_event=event_results) + schema_change_event=event_results, + trace_id=None, warnings=None, custom_payload=None) connection = Mock() rf._set_result(None, connection, None, result) session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results) @@ -127,7 +130,8 @@ def test_other_result_message_kind(self): session = self.make_session() rf = self.make_response_future(session) rf.send_request() - result = Mock(spec=ResultMessage, kind=999, results=[1, 2, 3]) + result = Mock(spec=ResultMessage, kind=999, results=[1, 2, 3], + trace_id=None, warnings=None, custom_payload=None) rf._set_result(None, None, None, result) assert rf.result()[0] == result @@ -629,7 +633,6 @@ def test_repeat_orig_query_after_succesful_reprepare(self): response = Mock(spec=ResultMessage, kind=RESULT_KIND_PREPARED, result_metadata_id='foo') - response.results = (None, None, None, None, None) response.query_id = query_id rf._query = Mock(return_value=True) diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index 80e9c21ff9..6d87fa229d 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -33,7 +33,7 @@ def test_iter_non_paged(self): def test_iter_paged(self): expected = list(range(10)) response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + response_future._wait_for_result.side_effect = (expected[-5:], ) rs = ResultSet(response_future, expected[:5]) itr = iter(rs) # this is brittle, depends on internal impl details. Would like to find a better way @@ -43,11 +43,11 @@ def test_iter_paged(self): def test_iter_paged_with_empty_pages(self): expected = list(range(10)) response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = [ - ResultSet(Mock(), []), - ResultSet(Mock(), [0, 1, 2, 3, 4]), - ResultSet(Mock(), []), - ResultSet(Mock(), [5, 6, 7, 8, 9]), + response_future._wait_for_result.side_effect = [ + [], + [0, 1, 2, 3, 4], + [], + [5, 6, 7, 8, 9], ] rs = ResultSet(response_future, []) itr = iter(rs) @@ -65,7 +65,7 @@ def test_list_paged(self): # list access on RS for backwards-compatibility expected = list(range(10)) response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + response_future._wait_for_result.side_effect = (expected[-5:], ) rs = ResultSet(response_future, expected[:5]) # this is brittle, depends on internal impl details. Would like to find a better way type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode @@ -98,7 +98,7 @@ def test_iterate_then_index(self): # RuntimeError if indexing during or after pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + response_future._wait_for_result.side_effect = (expected[-5:], ) rs = ResultSet(response_future, expected[:5]) type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) itr = iter(rs) @@ -131,7 +131,7 @@ def test_index_list_mode(self): # pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + response_future._wait_for_result.side_effect = (expected[-5:], ) rs = ResultSet(response_future, expected[:5]) # this is brittle, depends on internal impl details. Would like to find a better way type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode @@ -159,7 +159,7 @@ def test_eq(self): # pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) - response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + response_future._wait_for_result.side_effect = (expected[-5:], ) rs = ResultSet(response_future, expected[:5]) type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # eq before iteration causes list to be materialized