Skip to content

Commit 0adbfa3

Browse files
timsaucerclaude
andcommitted
test: fold dot_product alias check into parametrized test
Generalize test_array_function_aliases to accept multi-column data so the dot_product/inner_product alias case fits, dropping the standalone test_dot_product_alias_matches_inner_product. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 40b788e commit 0adbfa3

1 file changed

Lines changed: 12 additions & 21 deletions

File tree

python/tests/test_functions.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -720,31 +720,22 @@ def test_array_function_obj_tests(stmt, py_expr):
720720
@pytest.mark.parametrize(
721721
("alias_fn", "primary_fn", "data"),
722722
[
723-
(f.list_compact, f.array_compact, [[1.0, None, 2.0, None, 3.0]]),
724-
(f.list_normalize, f.array_normalize, [[3.0, 4.0]]),
723+
(f.list_compact, f.array_compact, {"a": [[1.0, None, 2.0, None, 3.0]]}),
724+
(f.list_normalize, f.array_normalize, {"a": [[3.0, 4.0]]}),
725+
(
726+
f.dot_product,
727+
f.inner_product,
728+
{"a": [[1.0, 2.0, 3.0]], "b": [[4.0, 5.0, 6.0]]},
729+
),
725730
],
726731
)
727732
def test_array_function_aliases(alias_fn, primary_fn, data):
728-
"""list_* helpers should be exact aliases for their array_* counterparts."""
729-
ctx = SessionContext()
730-
df = ctx.from_pydict({"a": data})
731-
alias_result = df.select(alias_fn(column("a")).alias("r")).collect()
732-
primary_result = df.select(primary_fn(column("a")).alias("r")).collect()
733-
assert (
734-
alias_result[0].column(0).to_pylist() == primary_result[0].column(0).to_pylist()
735-
)
736-
737-
738-
def test_dot_product_alias_matches_inner_product():
739-
"""dot_product should be an exact alias for inner_product."""
733+
"""Alias helpers should be exact aliases for their primary counterparts."""
740734
ctx = SessionContext()
741-
df = ctx.from_pydict({"a": [[1.0, 2.0, 3.0]], "b": [[4.0, 5.0, 6.0]]})
742-
alias_result = df.select(
743-
f.dot_product(column("a"), column("b")).alias("r")
744-
).collect()
745-
primary_result = df.select(
746-
f.inner_product(column("a"), column("b")).alias("r")
747-
).collect()
735+
df = ctx.from_pydict(data)
736+
cols = [column(name) for name in data]
737+
alias_result = df.select(alias_fn(*cols).alias("r")).collect()
738+
primary_result = df.select(primary_fn(*cols).alias("r")).collect()
748739
assert (
749740
alias_result[0].column(0).to_pylist() == primary_result[0].column(0).to_pylist()
750741
)

0 commit comments

Comments
 (0)