Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ def test_timing_metrics(self):
def test_forwards_batch_args(self):
examples = list(range(100))
with TestPipeline('FnApiRunner') as pipeline:
pcoll = pipeline | 'start' >> beam.Create(examples)
pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False)
actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch())
assert_that(actual, equal_to(examples), label='assert:inferences')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,8 @@ def batch_validator_keyed_tensor_inference_fn(
min_batch_size=2,
max_batch_size=2)

pcoll = pipeline | 'start' >> beam.Create(KEYED_TORCH_EXAMPLES)
pcoll = pipeline | 'start' >> beam.Create(
KEYED_TORCH_EXAMPLES, reshuffle=False)
inference_args_side_input = (
pipeline | 'create side' >> beam.Create(inference_args))
predictions = pcoll | RunInference(
Expand Down Expand Up @@ -709,7 +710,7 @@ def batch_validator_tensor_inference_fn(
min_batch_size=2,
max_batch_size=2)

pcoll = pipeline | 'start' >> beam.Create(examples)
pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False)
predictions = pcoll | RunInference(model_handler)
assert_that(
predictions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def batch_validator_numpy_inference_fn(
with TestPipeline() as pipeline:
examples = [numpy.array([0, 0]), numpy.array([1, 1])]

pcoll = pipeline | 'start' >> beam.Create(examples)
pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False)
actual = pcoll | RunInference(
SklearnModelHandlerNumpy(
model_uri=temp_file_name,
Expand Down Expand Up @@ -457,7 +457,7 @@ def batch_validator_pandas_inference_fn(
with TestPipeline() as pipeline:
dataframe = pandas_dataframe()
splits = [dataframe.loc[[i]] for i in dataframe.index]
pcoll = pipeline | 'start' >> beam.Create(splits)
pcoll = pipeline | 'start' >> beam.Create(splits, reshuffle=False)
actual = pcoll | RunInference(
SklearnModelHandlerPandas(
model_uri=temp_file_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fake_batching_inference_fn(
examples, [tf.math.multiply(n, 2) for n in examples])
]

pcoll = pipeline | 'start' >> beam.Create(examples)
pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False)
predictions = pcoll | RunInference(model_handler)
assert_that(
predictions,
Expand Down Expand Up @@ -258,7 +258,7 @@ def fake_batching_inference_fn(
examples, [numpy.multiply(n, 2) for n in examples])
]

pcoll = pipeline | 'start' >> beam.Create(examples)
pcoll = pipeline | 'start' >> beam.Create(examples, reshuffle=False)
predictions = pcoll | RunInference(model_handler)
assert_that(
predictions,
Expand Down
11 changes: 7 additions & 4 deletions sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def test_pipeline_single_tensor_feature_built_engine(self):
max_batch_size=4,
engine_path=
'gs://apache-beam-ml/models/single_tensor_features_engine.trt')
pcoll = pipeline | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES)
pcoll = pipeline | 'start' >> beam.Create(
SINGLE_FEATURE_EXAMPLES, reshuffle=False)
predictions = pcoll | RunInference(engine_handler)
assert_that(
predictions,
Expand Down Expand Up @@ -423,7 +424,8 @@ def fake_inference_fn(batch, engine, inference_args=None):
'gs://apache-beam-ml/models/single_tensor_features_engine.trt',
inference_fn=fake_inference_fn,
large_model=True)
pcoll = pipeline | 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES)
pcoll = pipeline | 'start' >> beam.Create(
SINGLE_FEATURE_EXAMPLES, reshuffle=False)
predictions = pcoll | RunInference(engine_handler)
assert_that(
predictions,
Expand All @@ -443,7 +445,7 @@ def test_pipeline_sets_env_vars_correctly(self):
self.assertFalse('FOO' in os.environ)
_ = (
pipeline
| 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES)
| 'start' >> beam.Create(SINGLE_FEATURE_EXAMPLES, reshuffle=False)
| RunInference(engine_handler))
pipeline.run()
self.assertTrue('FOO' in os.environ)
Expand All @@ -457,7 +459,8 @@ def test_pipeline_multiple_tensor_feature_built_engine(self):
max_batch_size=4,
engine_path=
'gs://apache-beam-ml/models/multiple_tensor_features_engine.trt')
pcoll = pipeline | 'start' >> beam.Create(TWO_FEATURES_EXAMPLES)
pcoll = pipeline | 'start' >> beam.Create(
TWO_FEATURES_EXAMPLES, reshuffle=False)
predictions = pcoll | RunInference(engine_handler)
assert_that(
predictions,
Expand Down
49 changes: 49 additions & 0 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
from apache_beam.transforms.util import GcpHsmGeneratedSecret
from apache_beam.transforms.util import GcpSecret
from apache_beam.transforms.util import Secret
from apache_beam.transforms.util import _BatchSizeEstimator
from apache_beam.transforms.util import _GlobalWindowsBatchingDoFn
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
Expand Down Expand Up @@ -1258,6 +1260,53 @@ def check_batch_homogeneity(batch):
checks = batches | beam.Map(check_batch_homogeneity)
assert_that(checks, is_not_empty())

def test_global_batching_dofn_single_vs_multiple_bundles(self):
  """Checks how bundle boundaries shape _GlobalWindowsBatchingDoFn batches.

  The DoFn is driven by hand (start_bundle / process / finish_bundle) with
  min_batch_size = max_batch_size = 2 and an element_size_fn that sizes
  every element as 1. The same four elements are fed first as one bundle
  and then split across two bundles, and the emitted batch sizes are
  compared.
  """
  def emitted_batch_sizes(bundles):
    # Each scenario gets a fresh estimator and DoFn so that no batching
    # state carries over from the previous scenario.
    size_estimator = _BatchSizeEstimator(min_batch_size=2, max_batch_size=2)
    batching_fn = _GlobalWindowsBatchingDoFn(
        size_estimator, element_size_fn=lambda x: 1)

    emitted = []
    for bundle in bundles:
      batching_fn.start_bundle()
      for item in bundle:
        emitted.extend(batching_fn.process(item))
      emitted.extend(batching_fn.finish_bundle() or [])
    return [len(windowed.value) for windowed in emitted]

  # 1. Single bundle: all four elements share one start_bundle /
  # finish_bundle lifecycle, so exactly two batches of size 2 are expected.
  self.assertEqual(emitted_batch_sizes([[1, 2, 3, 4]]), [2, 2])

  # 2. Multiple bundles (simulating elements split by the runner, e.g.
  # after a Reshuffle/GroupByKey): bundle 1 gets elements 1, 2, 3 and
  # bundle 2 gets element 4. The expected sizes are [2, 1, 1] rather than
  # [2, 2] because of bundle flushes:
  #   - Bundle 1 emits a batch of 2, then its remaining element is flushed
  #     at finish_bundle (batch size 1).
  #   - Bundle 2 emits its single element at finish_bundle (batch size 1).
  self.assertEqual(emitted_batch_sizes([[1, 2, 3], [4]]), [2, 1, 1])


class SortAndBatchElementsTest(unittest.TestCase):
"""Tests for SortAndBatchElements transform."""
Expand Down
Loading