Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ test = [
"pytest-cov>=7.1.0",
"pytest-asyncio>=1.3.0",
"python-dotenv>=1.2.2",
"vcrpy>=8.1.1",
]
release = ["build>=1.4.3", "jinja2>=3.1.6", "sphinx>=9.1.0", "twine>=6.2.0"]
lint = ["basedpyright>=1.39.0", "ruff>=0.15.10"]
Expand Down
10 changes: 7 additions & 3 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@
_testing_force_tool_strategy = False


def _thread_id_new_uuid() -> str:
    """Return a new random UUID (version 4) as a string.

    Thread-id generation is funneled through this module-level helper so
    tests can patch it with a deterministic replacement (see the
    ``deterministic_thread_ids`` decorator in the test library).
    """
    return str(uuid.uuid4())


def _supports_provider_strategy(model: BaseChatModel) -> bool:
return (
model.profile is not None
Expand Down Expand Up @@ -365,16 +369,16 @@ async def awrap_model_call(
# LLM hallucinated a thread_id, start a new conversation instead.
# This should not happen, since we provide an enum above, but just
# in case.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

if args.thread_id and args.thread_id in called_thread_ids:
# LLM ignored the instruction not to issue multiple calls to the
# same thread_id, start a new conversation instead.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

if not args.thread_id:
# Generate thread_id for a new conversation.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

called_thread_ids.add(args.thread_id)
call["args"] = asdict(args)
Expand Down
2 changes: 2 additions & 0 deletions tests/ai_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ async def _buildInternalAIModel(
auth=(client_id, client_secret),
)

response.raise_for_status()

token = _TokenResponse.model_validate_json(response.text).access_token

auth_handler = _InternalAIAuth(token)
Expand Down
160 changes: 159 additions & 1 deletion tests/ai_testlib.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from typing import override
import functools
import inspect
import json
import os
from collections.abc import Callable, Coroutine
from typing import Any, override
from unittest.mock import patch
from urllib import parse

import vcr
from vcr.config import RecordMode
from vcr.request import Request

from splunklib.ai.model import PredefinedModel
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
from tests.testlib import SDKTestCase

REDACTED_APP_KEY = "[[[--APPKEY-REDACTED-]]]"


class AITestCase(SDKTestCase):
_model: PredefinedModel | None = None
Expand Down Expand Up @@ -42,3 +56,147 @@ async def model(self) -> PredefinedModel:
model = await create_model(self.test_llm_settings)
self._model = model
return model


def ai_snapshot_test() -> Callable[
    [Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
    """Decorator for async ``AITestCase`` tests that records/replays LLM HTTP traffic.

    Uses vcrpy to persist each test's HTTP interactions as a JSON snapshot
    under ``snapshots/<test_module>/<test_qualname>.json`` next to the test
    source file. Only requests to the internal-AI host are recorded, all
    headers are stripped, and the internal hostname and app key are replaced
    with placeholders on serialization (and restored on deserialization) so
    no secrets leak into committed snapshots.
    """

    def decorator(
        fn: Callable[..., Coroutine[Any, Any, None]],
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        # Derive the snapshot location from the module that defines the test.
        source_file = inspect.getfile(fn)
        test_dir = os.path.dirname(source_file)
        test_file = os.path.splitext(os.path.basename(source_file))[0]

        snapshot_dir = os.path.join(test_dir, "snapshots", test_file)
        snapshot_filename = f"{fn.__qualname__}.json"

        @functools.wraps(fn)
        async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
            settings = self.test_llm_settings
            assert settings.internal_ai is not None

            internal_ai_hostname = parse.urlparse(
                settings.internal_ai.base_url
            ).hostname
            assert internal_ai_hostname is not None

            class _JSONFriendlySerializer:
                """vcrpy serializer that stores cassettes as readable JSON.

                Bodies are stored as parsed JSON objects (instead of escaped
                strings), and secrets are swapped for placeholders on the way
                out and restored on the way in.
                """

                def deserialize(self, serialized: str) -> Any:
                    assert settings.internal_ai is not None
                    # Restore the app key that was redacted during recording.
                    serialized = serialized.replace(
                        REDACTED_APP_KEY, settings.internal_ai.app_key
                    )

                    data = json.loads(serialized)
                    for interaction in data.get("interactions", []):
                        # Restore the real hostname behind the placeholder.
                        interaction["request"]["uri"] = interaction["request"][
                            "uri"
                        ].replace("internal-ai-host", internal_ai_hostname, 1)

                        # vcrpy expects string bodies; re-encode the stored
                        # JSON objects back into strings.
                        interaction["request"]["body"] = json.dumps(
                            interaction["request"]["body"]
                        )
                        body = interaction["response"]["body"]
                        interaction["response"]["body"] = {}
                        interaction["response"]["body"]["string"] = json.dumps(body)

                    return data

                def serialize(self, cassette_dict: Any) -> str:
                    # Parameter renamed from ``dict`` to avoid shadowing the
                    # builtin; vcrpy passes the cassette dict positionally.
                    for interaction in cassette_dict.get("interactions", []):
                        interaction["request"]["uri"] = interaction["request"][
                            "uri"
                        ].replace(internal_ai_hostname, "internal-ai-host", 1)

                        body = interaction["request"]["body"]
                        interaction["request"]["body"] = json.loads(body)

                        resp_body = interaction["response"]["body"]["string"]
                        interaction["response"]["body"] = json.loads(resp_body)

                    out = json.dumps(cassette_dict, indent=4) + "\n"
                    assert settings.internal_ai is not None
                    out = out.replace(settings.internal_ai.app_key, REDACTED_APP_KEY)

                    # Assert that nothing is leaking into the public snapshots.
                    assert internal_ai_hostname not in out.lower()
                    assert settings.internal_ai.app_key.lower() not in out.lower()
                    assert settings.internal_ai.base_url.lower() not in out.lower()
                    assert settings.internal_ai.token_url.lower() not in out.lower()
                    assert settings.internal_ai.client_id.lower() not in out.lower()
                    assert settings.internal_ai.client_secret.lower() not in out.lower()

                    return out

            def _before_record_request(request: Request) -> Request | None:
                # Record only traffic to the internal AI host, and drop all
                # headers so credentials never reach the cassette. Returning
                # None tells vcrpy to skip the request entirely.
                url = parse.urlparse(request.uri)  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
                if url.hostname == internal_ai_hostname:
                    request.headers = {}
                    return request
                return None

            def _before_record_response(response: Any) -> Any:
                response["headers"] = {}
                return response

            def _json_body_matcher(r1: Any, r2: Any) -> None:
                # vcrpy matchers signal a mismatch by raising; compare bodies
                # as parsed JSON so key ordering/whitespace do not matter.
                b1 = json.loads(r1.body)
                b2 = json.loads(r2.body)
                if b1 != b2:
                    raise AssertionError(f"Body mismatch:\n{b1}\n!=\n{b2}")

            my_vcr = vcr.VCR(
                cassette_library_dir=snapshot_dir,
                serializer="json-friendly",
                record_mode=RecordMode.ONCE,
                match_on=[
                    "method",
                    "scheme",
                    "host",
                    "port",
                    "path",
                    "query",
                    "jsonbody",
                ],
                before_record_request=_before_record_request,
                before_record_response=_before_record_response,
                # Useful toggles when adding tests or re-recording snapshots:
                # record_on_exception=False,
                # drop_unused_requests=True,
            )
            my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
            my_vcr.register_matcher("jsonbody", _json_body_matcher)

            with my_vcr.use_cassette(snapshot_filename):  # pyright: ignore[reportGeneralTypeIssues]
                await fn(self, *args, **kwargs)

        return wrapper

    return decorator


def deterministic_thread_ids() -> Callable[
    [Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
    """Decorator for async ``AITestCase`` tests that makes thread ids stable.

    Patches the engine's ``_thread_id_new_uuid`` hook for the duration of the
    test so that, instead of random UUID4s, it yields a predictable sequence
    (``00000000-0000-0000-0000-000000000000``, ``...-000000000001``, ...).
    The counter resets on every test invocation.
    """

    def decorator(
        fn: Callable[..., Coroutine[Any, Any, None]],
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        @functools.wraps(fn)
        async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
            sequence = 0

            def _sequential_uuid() -> str:
                nonlocal sequence
                uuid_str = f"00000000-0000-0000-0000-{sequence:012d}"
                sequence += 1
                return uuid_str

            patch_target = "splunklib.ai.engines.langchain._thread_id_new_uuid"
            with patch(patch_target, side_effect=_sequential_uuid):
                await fn(self, *args, **kwargs)

        return wrapper

    return decorator
Loading