Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions graphistry/PlotterBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphistry.plugins_types.hypergraph import HypergraphResult
from graphistry.render.resolve_render_mode import resolve_render_mode
from graphistry.Engine import EngineAbstractType
import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, sys, uuid
import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, requests, sys, uuid
from functools import lru_cache
from weakref import WeakValueDictionary

Expand All @@ -31,7 +31,7 @@
error, hash_pdf, in_ipython, in_databricks, make_iframe, random_string, warn,
cache_coercion, cache_coercion_helper, WeakValueWrapper
)
from graphistry.otel import otel_traced, otel_detail_enabled
from graphistry.otel import otel_traced, otel_detail_enabled, inject_trace_headers

from .bolt_util import (
bolt_graph_to_edges_dataframe,
Expand Down Expand Up @@ -2238,8 +2238,22 @@ def plot(
'type': 'arrow',
'viztoken': str(uuid.uuid4())
}
url_params = dict(self._url_params)
token = self.session.api_token
if token:
try:
server_base = '%s://%s' % (self.session.protocol, self.session.hostname)
resp = requests.post(
'%s/api/v1/auth/jwt/ott/' % server_base,
headers=inject_trace_headers({'Authorization': 'Bearer %s' % token}),
verify=self.session.certificate_validation,
)
resp.raise_for_status()
url_params['token'] = resp.json()['ott']
except Exception as e:
logger.warning("Failed to exchange JWT for OTT: %s", e)

viz_url = self._pygraphistry._viz_url(info, self._url_params)
viz_url = self._pygraphistry._viz_url(info, url_params)
cfg_client_protocol_hostname = self.session.client_protocol_hostname
full_url = ('%s:%s' % (self.session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url

Expand Down
41 changes: 25 additions & 16 deletions graphistry/tests/test_trace_headers_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import pandas as pd

# Import the ArrowFileUploader MODULE before graphistry shadows it with the class
# This ensures sys.modules has the module, allowing proper mock patching
# Import modules before graphistry shadows them with classes/symbols.
# This ensures sys.modules has the modules, allowing proper mock patching.
import graphistry.ArrowFileUploader as _arrow_file_uploader_module # noqa: F401
import graphistry.PlotterBase as _plotter_base_module # noqa: F401

import graphistry
from graphistry.compute.ast import n, e_forward
Expand Down Expand Up @@ -55,13 +56,13 @@ def _post_response_for_plot(url: str):
return _mock_response({"is_valid": True, "is_uploaded": True})
if "/api/v2/share/link/" in url:
return _mock_response({"success": True})
if "/api/v1/auth/jwt/ott/" in url:
return _mock_response({"ott": "test-ott-token"})
raise AssertionError(f"Unexpected POST url: {url}")


@mock.patch("graphistry.arrow_uploader.inject_trace_headers")
@mock.patch("requests.post")
def test_plot_injects_traceparent(mock_post, mock_inject):
mock_inject.side_effect = _inject_trace
def test_plot_injects_traceparent(mock_post):
headers_seen = []

def _fake_post(url, **kwargs):
Expand All @@ -70,22 +71,20 @@ def _fake_post(url, **kwargs):

mock_post.side_effect = _fake_post

g = _make_graph()
g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False)
plotter_base_module = sys.modules["graphistry.PlotterBase"]
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]

with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
g = _make_graph()
g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False)

assert headers_seen
assert all(h.get("traceparent") == TRACEPARENT for h in headers_seen)


@mock.patch("graphistry.arrow_uploader.inject_trace_headers")
@mock.patch("requests.post")
def test_upload_injects_traceparent(mock_post, mock_inject_uploader):
# Patch ArrowFileUploader module's inject_trace_headers via sys.modules
# This is needed because graphistry.ArrowFileUploader resolves to the class,
# not the module (due to re-exports in graphistry/__init__.py)
arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"]

mock_inject_uploader.side_effect = _inject_trace
def test_upload_injects_traceparent(mock_post):
headers_seen = []

def _fake_post(url, **kwargs):
Expand All @@ -94,7 +93,17 @@ def _fake_post(url, **kwargs):

mock_post.side_effect = _fake_post

with mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace):
# Patch inject_trace_headers in all three modules that make POST requests:
# arrow_uploader.py, ArrowFileUploader.py, and PlotterBase.py (OTT exchange).
# Use sys.modules because graphistry/__init__.py re-exports some names as classes,
# shadowing the module attributes on the graphistry package.
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]
arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"]
plotter_base_module = sys.modules["graphistry.PlotterBase"]

with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
g = _make_graph()
g.upload(validate=False, warn=False, memoize=False, erase_files_on_fail=False)

Expand Down
Loading