diff --git a/.sampo/changesets/prompt-metadata.md b/.sampo/changesets/prompt-metadata.md new file mode 100644 index 00000000..d62dd76a --- /dev/null +++ b/.sampo/changesets/prompt-metadata.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: minor +--- + +`Prompts.get()` now accepts `with_metadata=True` and returns a `PromptResult` dataclass containing `source` (`api`, `cache`, `stale_cache`, or `code_fallback`), `name`, and `version` alongside the prompt text. The previous plain-string return is deprecated and will be removed in a future major version. diff --git a/posthog/ai/__init__.py b/posthog/ai/__init__.py index 19846e83..9492daca 100644 --- a/posthog/ai/__init__.py +++ b/posthog/ai/__init__.py @@ -1,3 +1,3 @@ -from posthog.ai.prompts import Prompts +from posthog.ai.prompts import PromptResult, PromptSource, Prompts -__all__ = ["Prompts"] +__all__ = ["PromptResult", "PromptSource", "Prompts"] diff --git a/posthog/ai/prompts.py b/posthog/ai/prompts.py index 35140e41..4e116cac 100644 --- a/posthog/ai/prompts.py +++ b/posthog/ai/prompts.py @@ -8,7 +8,9 @@ import re import time import urllib.parse -from typing import Any, Dict, Optional, Union +import warnings +from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional, Union, overload from posthog.request import USER_AGENT, _get_session from posthog.utils import remove_trailing_slash @@ -21,13 +23,27 @@ PromptVariables = Dict[str, Union[str, int, float, bool]] PromptCacheKey = tuple[str, Optional[int]] +PromptSource = Literal["api", "cache", "stale_cache", "code_fallback"] + + +@dataclass(frozen=True) +class PromptResult: + """Result of a prompt fetch with metadata about its source.""" + + source: PromptSource + prompt: str + name: Optional[str] = None + version: Optional[int] = None + class CachedPrompt: """Cached prompt with metadata.""" - def __init__(self, prompt: str, fetched_at: float): + def __init__(self, prompt: str, fetched_at: float, name: str, version: int): self.prompt = prompt self.fetched_at = fetched_at + self.name = name + self.version = version def _cache_key(name: str, version: Optional[int]) -> PromptCacheKey: @@ -50,8 +66,9 @@ def _is_prompt_api_response(data: Any) -> bool: """Check if the response is a valid prompt API response.""" return ( isinstance(data, dict) - and "prompt" in data and isinstance(data.get("prompt"), str) + and isinstance(data.get("name"), str) + and type(data.get("version")) is int ) @@ -114,6 +131,7 @@ def __init__( default_cache_ttl_seconds or DEFAULT_CACHE_TTL_SECONDS ) self._cache: Dict[PromptCacheKey, CachedPrompt] = {} + self._has_warned_deprecation = False if posthog is not None: self._personal_api_key = getattr(posthog, "personal_api_key", None) or "" @@ -126,36 +144,116 @@ def __init__( self._project_api_key = project_api_key or "" self._host = remove_trailing_slash(host or APP_ENDPOINT) + @overload + def get( + self, + name: str, + *, + with_metadata: Literal[True], + cache_ttl_seconds: Optional[int] = ..., + fallback: Optional[str] = ..., + version: Optional[int] = ..., + ) -> PromptResult: ... + + @overload + def get( + self, + name: str, + *, + with_metadata: Literal[False], + cache_ttl_seconds: Optional[int] = ..., + fallback: Optional[str] = ..., + version: Optional[int] = ..., + ) -> str: ... + + @overload def get( self, name: str, *, + cache_ttl_seconds: Optional[int] = ..., + fallback: Optional[str] = ..., + version: Optional[int] = ..., + ) -> str: ... + + def get( + self, + name: str, + *, + with_metadata: Optional[bool] = None, cache_ttl_seconds: Optional[int] = None, fallback: Optional[str] = None, version: Optional[int] = None, - ) -> str: + ) -> Union[str, PromptResult]: """ Fetch a prompt by name from the PostHog API. - Caching behavior: - 1. If cache is fresh, return cached value - 2. If fetch fails and cache exists (stale), return stale cache with warning - 3. If fetch fails and fallback provided, return fallback with warning - 4. If fetch fails with no cache/fallback, raise exception + When ``with_metadata`` is ``True``, returns a :class:`PromptResult` + with ``source``, ``name``, and ``version`` metadata. When omitted or + ``False``, returns a plain string (deprecated -- will be removed in a + future major version). Args: name: The name of the prompt to fetch + with_metadata: If True, returns a PromptResult with source info. + Omitting this parameter is deprecated. cache_ttl_seconds: Cache TTL in seconds (defaults to instance default) fallback: Fallback prompt to use if fetch fails and no cache available version: Specific prompt version to fetch. If None, fetches the latest version Returns: - The prompt string + str if with_metadata is False/omitted, PromptResult if True Raises: Exception: If the prompt cannot be fetched and no fallback is available """ + if with_metadata is None and not self._has_warned_deprecation: + self._has_warned_deprecation = True + warnings.warn( + "[PostHog Prompts] Calling get() without with_metadata=True is " + "deprecated and will be removed in a future major version. " + "Pass with_metadata=True to receive a PromptResult object with " + "source, name, and version metadata. You can pass " + "with_metadata=False to silence this warning, but the " + "plain-string return will still be removed in the next major " + "version.", + DeprecationWarning, + stacklevel=2, + ) + + try: + result = self._get_internal( + name, cache_ttl_seconds=cache_ttl_seconds, version=version + ) + if with_metadata is True: + return result + return result.prompt + except Exception as error: + prompt_reference = _prompt_reference(name, version) + if fallback is not None: + log.warning( + "[PostHog Prompts] Failed to fetch %s, using fallback: %s", + prompt_reference, + error, + ) + if with_metadata is True: + return PromptResult(source="code_fallback", prompt=fallback) + return fallback + raise + + def _get_internal( + self, + name: str, + *, + cache_ttl_seconds: Optional[int] = None, + version: Optional[int] = None, + ) -> PromptResult: + """ + Internal method that handles cache + fetch logic, returning full metadata. + + Does NOT handle the string ``fallback`` option -- the caller handles that. + """ ttl = ( cache_ttl_seconds if cache_ttl_seconds is not None @@ -171,40 +269,48 @@ def get( is_fresh = (now - cached.fetched_at) < ttl if is_fresh: - return cached.prompt + return PromptResult( + source="cache", + prompt=cached.prompt, + name=cached.name, + version=cached.version, + ) # Try to fetch from API try: - prompt = self._fetch_prompt_from_api(name, version) - fetched_at = time.time() + data = self._fetch_prompt_from_api(name, version) # Update cache - self._cache[cache_key] = CachedPrompt(prompt=prompt, fetched_at=fetched_at) + self._cache[cache_key] = CachedPrompt( + prompt=data["prompt"], + fetched_at=time.time(), + name=data["name"], + version=data["version"], + ) - return prompt + return PromptResult( + source="api", + prompt=data["prompt"], + name=data["name"], + version=data["version"], + ) except Exception as error: prompt_reference = _prompt_reference(name, version) - # Fallback order: - # 1. Return stale cache (with warning) + # Return stale cache (with warning) if cached is not None: log.warning( "[PostHog Prompts] Failed to fetch %s, using stale cache: %s", prompt_reference, error, ) - return cached.prompt - - # 2. Return fallback (with warning) - if fallback is not None: - log.warning( - "[PostHog Prompts] Failed to fetch %s, using fallback: %s", - prompt_reference, - error, + return PromptResult( + source="stale_cache", + prompt=cached.prompt, + name=cached.name, + version=cached.version, ) - return fallback - # 3. Raise error raise def compile(self, prompt: str, variables: PromptVariables) -> str: @@ -257,7 +363,9 @@ def clear_cache( for key in keys_to_clear: self._cache.pop(key, None) - def _fetch_prompt_from_api(self, name: str, version: Optional[int] = None) -> str: + def _fetch_prompt_from_api( + self, name: str, version: Optional[int] = None + ) -> Dict[str, Any]: """ Fetch prompt from PostHog API. @@ -271,7 +379,7 @@ def _fetch_prompt_from_api(self, name: str, version: Optional[int] = None) -> st version: Specific prompt version to fetch. If None, fetches the latest Returns: - The prompt string + The validated API response dict containing prompt, name, and version Raises: Exception: If the prompt cannot be fetched @@ -329,4 +437,4 @@ def _fetch_prompt_from_api(self, name: str, version: Optional[int] = None) -> st f"[PostHog Prompts] Invalid response format for {prompt_label}" ) - return data["prompt"] + return data diff --git a/posthog/test/ai/test_prompts.py b/posthog/test/ai/test_prompts.py index 4a50a2a7..a77cdd5f 100644 --- a/posthog/test/ai/test_prompts.py +++ b/posthog/test/ai/test_prompts.py @@ -1,9 +1,10 @@ import unittest +import warnings from unittest.mock import MagicMock, patch from parameterized import parameterized -from posthog.ai.prompts import Prompts +from posthog.ai.prompts import PromptResult, Prompts class MockResponse: @@ -497,6 +498,241 @@ def test_use_custom_default_cache_ttl_from_direct_options( self.assertEqual(mock_get.call_count, 2) +class TestPromptsGetWithMetadata(TestPrompts): + """Tests for Prompts.get() with with_metadata=True.""" + + @patch("posthog.ai.prompts._get_session") + def test_return_prompt_result_with_source_api_on_fresh_fetch( + self, mock_get_session + ): + """Should return a PromptResult with source='api' on a fresh fetch.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse(json_data=self.mock_prompt_response) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + result = prompts.get("test-prompt", with_metadata=True) + + self.assertEqual( + result, + PromptResult( + source="api", + prompt=self.mock_prompt_response["prompt"], + name="test-prompt", + version=1, + ), + ) + + @patch("posthog.ai.prompts._get_session") + def test_return_source_cache_on_fresh_cache_hit(self, mock_get_session): + """Should return source='cache' on a fresh cache hit.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse(json_data=self.mock_prompt_response) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + # First call populates cache + prompts.get("test-prompt", with_metadata=True) + + # Second call should hit cache + result = prompts.get("test-prompt", with_metadata=True, cache_ttl_seconds=300) + + self.assertEqual(result.source, "cache") + self.assertEqual(result.prompt, self.mock_prompt_response["prompt"]) + self.assertEqual(result.name, "test-prompt") + self.assertEqual(result.version, 1) + self.assertEqual(mock_get.call_count, 1) + + @patch("posthog.ai.prompts._get_session") + @patch("posthog.ai.prompts.time.time") + def test_return_source_stale_cache_on_fetch_failure( + self, mock_time, mock_get_session + ): + """Should return source='stale_cache' on fetch failure with stale cache.""" + mock_get = mock_get_session.return_value.get + mock_get.side_effect = [ + MockResponse(json_data=self.mock_prompt_response), + Exception("Network error"), + ] + mock_time.return_value = 1000.0 + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + # First call populates cache + prompts.get("test-prompt", with_metadata=True, cache_ttl_seconds=60) + + # Advance past TTL + mock_time.return_value = 1061.0 + + # Second call should use stale cache + result = prompts.get("test-prompt", with_metadata=True, cache_ttl_seconds=60) + + self.assertEqual(result.source, "stale_cache") + self.assertEqual(result.prompt, self.mock_prompt_response["prompt"]) + self.assertEqual(result.name, "test-prompt") + self.assertEqual(result.version, 1) + + @patch("posthog.ai.prompts._get_session") + def test_return_source_code_fallback_with_none_metadata(self, mock_get_session): + """Should return source='code_fallback' with name=None, version=None.""" + mock_get = mock_get_session.return_value.get + mock_get.side_effect = Exception("Network error") + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + result = prompts.get( + "test-prompt", with_metadata=True, fallback="Default system prompt." + ) + + self.assertEqual( + result, + PromptResult( + source="code_fallback", + prompt="Default system prompt.", + name=None, + version=None, + ), + ) + + @patch("posthog.ai.prompts._get_session") + def test_throw_when_no_cache_and_no_fallback(self, mock_get_session): + """Should throw when no cache and no fallback.""" + mock_get = mock_get_session.return_value.get + mock_get.side_effect = Exception("Network error") + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + with self.assertRaises(Exception) as context: + prompts.get("test-prompt", with_metadata=True) + + self.assertIn("Network error", str(context.exception)) + + @patch("posthog.ai.prompts._get_session") + def test_return_correct_version_metadata_for_versioned_fetch( + self, mock_get_session + ): + """Should return correct version metadata for versioned fetches.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse( + json_data={ + **self.mock_prompt_response, + "version": 3, + "prompt": "Version 3 prompt", + } + ) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + result = prompts.get("test-prompt", with_metadata=True, version=3) + + self.assertEqual( + result, + PromptResult( + source="api", + prompt="Version 3 prompt", + name="test-prompt", + version=3, + ), + ) + + @patch("posthog.ai.prompts._get_session") + def test_share_cache_with_non_metadata_calls(self, mock_get_session): + """Should share cache between with_metadata=True and with_metadata=False.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse(json_data=self.mock_prompt_response) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + # First call without metadata populates cache + string_result = prompts.get("test-prompt", with_metadata=False) + self.assertEqual(string_result, self.mock_prompt_response["prompt"]) + + # Second call with metadata should use cache + metadata_result = prompts.get("test-prompt", with_metadata=True) + self.assertEqual( + metadata_result, + PromptResult( + source="cache", + prompt=self.mock_prompt_response["prompt"], + name="test-prompt", + version=1, + ), + ) + self.assertEqual(mock_get.call_count, 1) + + +class TestPromptsGetDeprecationWarning(TestPrompts): + """Tests for the deprecation warning when with_metadata is not passed.""" + + @parameterized.expand( + [ + ("not_passed", None, 1), + ("explicit_false", False, 0), + ("explicit_true", True, 0), + ] + ) + @patch("posthog.ai.prompts._get_session") + def test_deprecation_warning_count( + self, _scenario, with_metadata, expected_warnings, mock_get_session + ): + """Should emit the correct number of deprecation warnings.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse(json_data=self.mock_prompt_response) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + kwargs = {} + if with_metadata is not None: + kwargs["with_metadata"] = with_metadata + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + prompts.get("test-prompt", **kwargs) + # Second call — should never warn again + prompts.get("test-prompt", **kwargs) + + deprecation_warnings = [ + w for w in caught if issubclass(w.category, DeprecationWarning) + ] + self.assertEqual(len(deprecation_warnings), expected_warnings) + + +class TestPromptsApiResponseValidation(TestPrompts): + """Tests for strengthened API response validation.""" + + @parameterized.expand( + [ + ("missing_name", {"prompt": "hello", "version": 1}), + ("missing_version", {"prompt": "hello", "name": "test"}), + ("name_not_string", {"prompt": "hello", "name": 123, "version": 1}), + ("version_not_int", {"prompt": "hello", "name": "test", "version": "1"}), + ] + ) + @patch("posthog.ai.prompts._get_session") + def test_reject_api_response_with_invalid_metadata( + self, _scenario, json_data, mock_get_session + ): + """Should reject API responses with missing or invalid name/version.""" + mock_get = mock_get_session.return_value.get + mock_get.return_value = MockResponse(json_data=json_data) + + posthog = self.create_mock_posthog() + prompts = Prompts(posthog) + + with self.assertRaises(Exception) as context: + prompts.get("test-prompt", with_metadata=True) + + self.assertIn("Invalid response format", str(context.exception)) + + class TestPromptsCompile(TestPrompts): """Tests for the Prompts.compile() method."""