diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py index 28453952ae..5c755a098f 100644 --- a/graphistry/PlotterBase.py +++ b/graphistry/PlotterBase.py @@ -6,7 +6,7 @@ from graphistry.plugins_types.hypergraph import HypergraphResult from graphistry.render.resolve_render_mode import resolve_render_mode from graphistry.Engine import EngineAbstractType -import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, sys, uuid +import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, requests, sys, uuid from functools import lru_cache from weakref import WeakValueDictionary @@ -31,7 +31,7 @@ error, hash_pdf, in_ipython, in_databricks, make_iframe, random_string, warn, cache_coercion, cache_coercion_helper, WeakValueWrapper ) -from graphistry.otel import otel_traced, otel_detail_enabled +from graphistry.otel import otel_traced, otel_detail_enabled, inject_trace_headers from .bolt_util import ( bolt_graph_to_edges_dataframe, @@ -2238,8 +2238,22 @@ def plot( 'type': 'arrow', 'viztoken': str(uuid.uuid4()) } + url_params = dict(self._url_params) + token = self.session.api_token + if token: + try: + server_base = '%s://%s' % (self.session.protocol, self.session.hostname) + resp = requests.post( + '%s/api/v1/auth/jwt/ott/' % server_base, + headers=inject_trace_headers({'Authorization': 'Bearer %s' % token}), + verify=self.session.certificate_validation, + ) + resp.raise_for_status() + url_params['token'] = resp.json()['ott'] + except Exception as e: + logger.warning("Failed to exchange JWT for OTT: %s", e) - viz_url = self._pygraphistry._viz_url(info, self._url_params) + viz_url = self._pygraphistry._viz_url(info, url_params) cfg_client_protocol_hostname = self.session.client_protocol_hostname full_url = ('%s:%s' % (self.session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url diff --git a/graphistry/tests/test_trace_headers_behavior.py b/graphistry/tests/test_trace_headers_behavior.py index 96014e2c0a..e0633f0d46 100644 --- a/graphistry/tests/test_trace_headers_behavior.py +++ b/graphistry/tests/test_trace_headers_behavior.py @@ -4,9 +4,10 @@ import pandas as pd -# Import the ArrowFileUploader MODULE before graphistry shadows it with the class -# This ensures sys.modules has the module, allowing proper mock patching +# Import modules before graphistry shadows them with classes/symbols. +# This ensures sys.modules has the modules, allowing proper mock patching. import graphistry.ArrowFileUploader as _arrow_file_uploader_module # noqa: F401 +import graphistry.PlotterBase as _plotter_base_module # noqa: F401 import graphistry from graphistry.compute.ast import n, e_forward @@ -55,13 +56,13 @@ def _post_response_for_plot(url: str): return _mock_response({"is_valid": True, "is_uploaded": True}) if "/api/v2/share/link/" in url: return _mock_response({"success": True}) + if "/api/v1/auth/jwt/ott/" in url: + return _mock_response({"ott": "test-ott-token"}) raise AssertionError(f"Unexpected POST url: {url}") -@mock.patch("graphistry.arrow_uploader.inject_trace_headers") @mock.patch("requests.post") -def test_plot_injects_traceparent(mock_post, mock_inject): - mock_inject.side_effect = _inject_trace +def test_plot_injects_traceparent(mock_post): headers_seen = [] def _fake_post(url, **kwargs): @@ -70,22 +71,20 @@ def _fake_post(url, **kwargs): mock_post.side_effect = _fake_post - g = _make_graph() - g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False) + plotter_base_module = sys.modules["graphistry.PlotterBase"] + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): + g = _make_graph() + g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False) assert headers_seen assert all(h.get("traceparent") == TRACEPARENT for h in headers_seen) -@mock.patch("graphistry.arrow_uploader.inject_trace_headers") @mock.patch("requests.post") -def test_upload_injects_traceparent(mock_post, mock_inject_uploader): - # Patch ArrowFileUploader module's inject_trace_headers via sys.modules - # This is needed because graphistry.ArrowFileUploader resolves to the class, - # not the module (due to re-exports in graphistry/__init__.py) - arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"] - - mock_inject_uploader.side_effect = _inject_trace +def test_upload_injects_traceparent(mock_post): headers_seen = [] def _fake_post(url, **kwargs): @@ -94,7 +93,17 @@ def _fake_post(url, **kwargs): mock_post.side_effect = _fake_post - with mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace): + # Patch inject_trace_headers in all three modules that make POST requests: + # arrow_uploader.py, ArrowFileUploader.py, and PlotterBase.py (OTT exchange). + # Use sys.modules because graphistry/__init__.py re-exports some names as classes, + # shadowing the module attributes on the graphistry package. + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"] + plotter_base_module = sys.modules["graphistry.PlotterBase"] + + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): g = _make_graph() g.upload(validate=False, warn=False, memoize=False, erase_files_on_fail=False)