diff --git a/sentry_sdk/integrations/ray.py b/sentry_sdk/integrations/ray.py index 92a35546ab..4399fd5a1b 100644 --- a/sentry_sdk/integrations/ray.py +++ b/sentry_sdk/integrations/ray.py @@ -36,6 +36,26 @@ def _check_sentry_initialized() -> None: ) +def _insert_sentry_tracing_in_signature(func: "Callable[..., Any]") -> None: + # Patching new_func signature to add the _sentry_tracing parameter to it + # Ray later inspects the signature and finds the unexpected parameter otherwise + signature = inspect.signature(func) + params = list(signature.parameters.values()) + sentry_tracing_param = inspect.Parameter( + "_sentry_tracing", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + ) + + # Keyword only arguments are penultimate if function has variadic keyword arguments + if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD: + params.insert(-1, sentry_tracing_param) + else: + params.append(sentry_tracing_param) + + func.__signature__ = signature.replace(parameters=params) # type: ignore[attr-defined] + + def _patch_ray_remote() -> None: old_remote = ray.remote @@ -86,18 +106,7 @@ def new_func( return result - # Patching new_func signature to add the _sentry_tracing parameter to it - # Ray later inspects the signature and finds the unexpected parameter otherwise - signature = inspect.signature(new_func) - params = list(signature.parameters.values()) - params.append( - inspect.Parameter( - "_sentry_tracing", - kind=inspect.Parameter.KEYWORD_ONLY, - default=None, - ) - ) - new_func.__signature__ = signature.replace(parameters=params) # type: ignore[attr-defined] + _insert_sentry_tracing_in_signature(new_func) if f: rv = old_remote(new_func) diff --git a/tests/integrations/ray/test_ray.py b/tests/integrations/ray/test_ray.py index dcbf8f456b..be7ebc9d05 100644 --- a/tests/integrations/ray/test_ray.py +++ b/tests/integrations/ray/test_ray.py @@ -74,10 +74,29 @@ def read_error_from_log(job_id, ray_temp_dir): return error +def example_task(): + with sentry_sdk.start_span(op="task", name="example task step"): + ... + + return sentry_sdk.get_client().transport.envelopes + + +# RayIntegration must leave variadic keyword arguments at the end +def example_task_with_kwargs(**kwargs): + with sentry_sdk.start_span(op="task", name="example task step"): + ... + + return sentry_sdk.get_client().transport.envelopes + + @pytest.mark.parametrize( "task_options", [{}, {"num_cpus": 0, "memory": 1024 * 1024 * 10}] ) -def test_tracing_in_ray_tasks(task_options): +@pytest.mark.parametrize( + "task", + [example_task, example_task_with_kwargs], +) +def test_tracing_in_ray_tasks(task_options, task): setup_sentry() ray.init( @@ -87,21 +106,18 @@ def test_tracing_in_ray_tasks(task_options): } ) - def example_task(): - with sentry_sdk.start_span(op="task", name="example task step"): - ... - - return sentry_sdk.get_client().transport.envelopes - # Setup ray task, calling decorator directly instead of @, # to accommodate for test parametrization if task_options: - example_task = ray.remote(**task_options)(example_task) + example_task = ray.remote(**task_options)(task) else: - example_task = ray.remote(example_task) + example_task = ray.remote(task) # Function name shouldn't be overwritten by Sentry wrapper - assert example_task._function_name == "tests.integrations.ray.test_ray.example_task" + assert ( + example_task._function_name + == f"tests.integrations.ray.test_ray.{task.__name__}" + ) with sentry_sdk.start_transaction(op="task", name="ray test transaction"): worker_envelopes = ray.get(example_task.remote()) @@ -115,17 +131,14 @@ def example_task(): worker_transaction = worker_envelope.get_transaction_event() assert ( worker_transaction["transaction"] - == "tests.integrations.ray.test_ray.test_tracing_in_ray_tasks..example_task" + == f"tests.integrations.ray.test_ray.{task.__name__}" ) assert worker_transaction["transaction_info"] == {"source": "task"} (span,) = client_transaction["spans"] assert span["op"] == "queue.submit.ray" assert span["origin"] == "auto.queue.ray" - assert ( - span["description"] - == "tests.integrations.ray.test_ray.test_tracing_in_ray_tasks..example_task" - ) + assert span["description"] == f"tests.integrations.ray.test_ray.{task.__name__}" assert span["parent_span_id"] == client_transaction["contexts"]["trace"]["span_id"] assert span["trace_id"] == client_transaction["contexts"]["trace"]["trace_id"]