diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 3f0bea7fbe1a..7bca1fc63385 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -1047,7 +1047,7 @@ def test_timing_metrics(self): def test_forwards_batch_args(self): examples = list(range(100)) with TestPipeline('FnApiRunner') as pipeline: - pcoll = pipeline | 'start' >> beam.Create(examples) + pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False) actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch()) assert_that(actual, equal_to(examples), label='assert:inferences') diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 50279820b267..8efec14c865f 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -635,7 +635,8 @@ def batch_validator_keyed_tensor_inference_fn( min_batch_size=2, max_batch_size=2) - pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES) + pcoll = pipeline | 'start' >> beam.Create( + KEYED_TORCH_EXAMPLES, reshuffle=False) inference_args_side_input = ( pipeline | 'create side' >> beam.Create(inference_args)) predictions = pcoll | RunInference( @@ -709,7 +710,7 @@ def batch_validator_tensor_inference_fn( min_batch_size=2, max_batch_size=2) - pcoll = pipeline | 'start' >> beam.Create(examples) + pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False) predictions = pcoll | RunInference(model_handler) assert_that( predictions, diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 400ac77cf498..76d6bc65729c 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ 
b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -299,7 +299,7 @@ def batch_validator_numpy_inference_fn( with TestPipeline() as pipeline: examples = [numpy.array([0, 0]), numpy.array([1, 1])] - pcoll = pipeline | 'start' >> beam.Create(examples) + pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False) actual = pcoll | RunInference( SklearnModelHandlerNumpy( model_uri=temp_file_name, @@ -457,7 +457,7 @@ def batch_validator_pandas_inference_fn( with TestPipeline() as pipeline: dataframe = pandas_dataframe() splits = [dataframe.loc[[i]] for i in dataframe.index] - pcoll = pipeline | 'start' >> beam.Create(splits) + pcoll = pipeline | 'start' >> beam.Create(splits, reshuffle=False) actual = pcoll | RunInference( SklearnModelHandlerPandas( model_uri=temp_file_name, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py index c884ee58b0a0..3a2e58e378eb 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py @@ -165,7 +165,7 @@ def fake_batching_inference_fn( examples, [tf.math.multiply(n, 2) for n in examples]) ] - pcoll = pipeline | 'start' >> beam.Create(examples) + pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False) predictions = pcoll | RunInference(model_handler) assert_that( predictions, @@ -258,7 +258,7 @@ def fake_batching_inference_fn( examples, [numpy.multiply(n, 2) for n in examples]) ] - pcoll = pipeline | 'start' >> beam.Create(examples) + pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False) predictions = pcoll | RunInference(model_handler) assert_that( predictions, diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py index 39e46c7f7c0d..80a01b8f4d4c 100644 --- 
a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py @@ -362,7 +362,8 @@ def test_pipeline_single_tensor_feature_built_engine(self): max_batch_size=4, engine_path= 'gs://apache-beam-ml/models/single_tensor_features_engine.trt') - pcoll = pipeline | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES) + pcoll = pipeline | 'start' >> beam.Create( + SINGLE_FEATURE_EXAMPLES, reshuffle=False) predictions = pcoll | RunInference(engine_handler) assert_that( predictions, @@ -423,7 +424,8 @@ def fake_inference_fn(batch, engine, inference_args=None): 'gs://apache-beam-ml/models/single_tensor_features_engine.trt', inference_fn=fake_inference_fn, large_model=True) - pcoll = pipeline | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES) + pcoll = pipeline | 'start' >> beam.Create( + SINGLE_FEATURE_EXAMPLES, reshuffle=False) predictions = pcoll | RunInference(engine_handler) assert_that( predictions, @@ -443,7 +445,7 @@ def test_pipeline_sets_env_vars_correctly(self): self.assertFalse('FOO' in os.environ) _ = ( pipeline - | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES) + | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES, reshuffle=False) | RunInference(engine_handler)) pipeline.run() self.assertTrue('FOO' in os.environ) @@ -457,7 +459,8 @@ def test_pipeline_multiple_tensor_feature_built_engine(self): max_batch_size=4, engine_path= 'gs://apache-beam-ml/models/multiple_tensor_features_engine.trt') - pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES) + pcoll = pipeline | 'start' >> beam.Create( + TWO_FEATURES_EXAMPLES, reshuffle=False) predictions = pcoll | RunInference(engine_handler) assert_that( predictions, diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index a965ff33d829..63ce42726c1f 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -75,6 +75,8 @@ from 
apache_beam.transforms.util import GcpHsmGeneratedSecret from apache_beam.transforms.util import GcpSecret from apache_beam.transforms.util import Secret +from apache_beam.transforms.util import _BatchSizeEstimator +from apache_beam.transforms.util import _GlobalWindowsBatchingDoFn from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import GlobalWindows @@ -1258,6 +1260,53 @@ def check_batch_homogeneity(batch): checks = batches | beam.Map(check_batch_homogeneity) assert_that(checks, is_not_empty()) + def test_global_batching_dofn_single_vs_multiple_bundles(self): + # Calls _GlobalWindowsBatchingDoFn (internal to BatchElements) directly to + # show how bundle boundaries change the batch sizes it emits. + + # 1. Single Bundle Scenario: + # Four elements go through one start_bundle / finish_bundle cycle. + # min_batch_size = max_batch_size = 2; each element counts as size 1. + estimator = _BatchSizeEstimator(min_batch_size=2, max_batch_size=2) + dofn = _GlobalWindowsBatchingDoFn(estimator, element_size_fn=lambda x: 1) + + dofn.start_bundle() + outputs = [] + for elem in [1, 2, 3, 4]: + outputs.extend(dofn.process(elem)) + outputs.extend(dofn.finish_bundle() or []) + + # We should get exactly two batches of size 2. + batch_sizes = [len(wv.value) for wv in outputs] + self.assertEqual(batch_sizes, [2, 2]) + + # 2. Multiple Bundles Scenario (as after a Reshuffle/GroupByKey split): + # The runner splits elements into multiple bundles: + # Bundle 1 gets elements 1, 2, 3. + # Bundle 2 gets element 4. 
+ estimator = _BatchSizeEstimator(min_batch_size=2, max_batch_size=2) + dofn = _GlobalWindowsBatchingDoFn(estimator, element_size_fn=lambda x: 1) + + outputs = [] + # Bundle 1 + dofn.start_bundle() + for elem in [1, 2, 3]: + outputs.extend(dofn.process(elem)) + outputs.extend(dofn.finish_bundle() or []) + + # Bundle 2 + dofn.start_bundle() + for elem in [4]: + outputs.extend(dofn.process(elem)) + outputs.extend(dofn.finish_bundle() or []) + + # Bundle flushes make the batch sizes [2, 1, 1] instead of [2, 2]. + # Specifically: + # - Bundle 1: a batch of 2, then finish_bundle flushes 1 leftover element. + # - Bundle 2 emits its 1 element at finish_bundle (batch size 1). + batch_sizes = [len(wv.value) for wv in outputs] + self.assertEqual(batch_sizes, [2, 1, 1]) + class SortAndBatchElementsTest(unittest.TestCase): """Tests for SortAndBatchElements transform."""