Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import urlsplit, urlunsplit

import httpx
from openai import AsyncOpenAI

Expand All @@ -14,6 +16,29 @@
provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
DEFAULT_EMBEDDING_API_BASE = "https://api.openai.com/v1"

@staticmethod
def _normalize_embedding_api_base(api_base: str) -> str:
"""Normalize root-style embedding base URLs while avoiding path-specific ones.

Auto-append ``/v1`` only for host roots or single-segment paths such as
``https://example.com`` or ``https://example.com/openai``. More specific
paths (for example ``/v1-beta`` or ``/v1/embeddings``) are preserved as-is.
"""
parsed = urlsplit(api_base)
normalized_path = parsed.path.rstrip("/") if parsed.path else ""
path_segments = [segment for segment in normalized_path.split("/") if segment]
has_version_segment = any(
len(segment) > 1 and segment.startswith("v") and segment[1].isdigit()
for segment in path_segments
)
if has_version_segment or len(path_segments) > 1:
return urlunsplit(parsed._replace(path=normalized_path))

normalized_path = f"{normalized_path}/v1" if normalized_path else "/v1"
return urlunsplit(parsed._replace(path=normalized_path))

def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
Expand All @@ -25,8 +50,12 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
api_base = provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
"embedding_api_base", self.DEFAULT_EMBEDDING_API_BASE
).strip()
if api_base:
api_base = self._normalize_embedding_api_base(api_base)
Comment on lines +55 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While the current logic correctly normalizes a non-empty api_base, it doesn't handle cases where api_base becomes an empty string (e.g., if the configuration provides an empty or whitespace-only value). This will cause the AsyncOpenAI client to fail with an InvalidURL error. To make this more robust, we should ensure we fall back to the default URL if api_base is empty before normalizing.

api_base = self._normalize_embedding_api_base(api_base or "https://api.openai.com/v1")

else:
api_base = self.DEFAULT_EMBEDDING_API_BASE
logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
Expand Down
96 changes: 94 additions & 2 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from openai.types.chat.chat_completion import ChatCompletion

from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_embedding_source import (
OpenAIEmbeddingProvider,
)
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial


Expand Down Expand Up @@ -49,6 +52,20 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq:
)


def _make_embedding_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider:
provider_config = {
"id": "test-openai-embedding",
"type": "openai_embedding",
"embedding_api_key": "test-key",
}
if overrides:
provider_config.update(overrides)
return OpenAIEmbeddingProvider(
provider_config=provider_config,
provider_settings={},
)


@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_removes_images():
provider = _make_provider(
Comment on lines 69 to 71
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Cover edge case where embedding_api_base ends with a trailing slash but no /v1

Given the current normalization (rstrip('/') then conditionally appending /v1), please add a test for embedding_api_base="https://example.com/openai/" to assert it becomes https://example.com/openai/v1/ and not https://example.com/openai//v1/, so the trailing-slash handling is locked in.

Suggested implementation:

    finally:
        await provider.terminate()


def test_embedding_api_base_trailing_slash_normalized():
    provider = _make_provider(
        overrides={"embedding_api_base": "https://example.com/openai/"}
    )

    # The provider should normalize the embedding API base by removing any
    # trailing slash and then appending `/v1`, resulting in a single slash.
    # This asserts we do *not* end up with `https://example.com/openai//v1/`.
    base_url = str(provider.client._client.base_url)
    assert base_url == "https://example.com/openai/v1/"

Depending on how OpenAIEmbeddingProvider exposes its underlying OpenAI client, you may need to adjust the attribute chain used to read the base URL:

  • If the provider exposes the client as provider._client instead of provider.client, change provider.client._client.base_url to provider._client.base_url or provider._client._client.base_url.
  • If the base URL is stored on a different attribute (e.g. provider._client.base_url or provider.client.base_url), update the test accordingly while keeping the assertion value https://example.com/openai/v1/.

The key behavior to lock in is that "https://example.com/openai/" normalizes to "https://example.com/openai/v1/" without a double slash.

Expand Down Expand Up @@ -234,7 +251,9 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history():
provider._finally_convert_payload(payloads)

assistant_message = payloads["messages"][0]
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
assert assistant_message["content"] == [
{"type": "text", "text": "final answer"}
]
assert assistant_message["reasoning_content"] == "step 1"
finally:
await provider.terminate()
Expand All @@ -259,7 +278,9 @@ async def test_groq_payload_drops_reasoning_content_from_assistant_history():
provider._finally_convert_payload(payloads)

assistant_message = payloads["messages"][0]
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
assert assistant_message["content"] == [
{"type": "text", "text": "final answer"}
]
assert "reasoning_content" not in assistant_message
assert "reasoning" not in assistant_message
finally:
Expand Down Expand Up @@ -533,3 +554,74 @@ async def fake_create(**kwargs):
assert extra_body["temperature"] == 0.1
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_appends_v1_to_base_url_when_missing():
provider = _make_embedding_provider(
{"embedding_api_base": "https://example.com/openai"}
)
try:
assert str(provider.client.base_url) == "https://example.com/openai/v1/"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_preserves_existing_v1_suffix():
provider = _make_embedding_provider(
{"embedding_api_base": "https://example.com/openai/v1/"}
)
try:
assert str(provider.client.base_url) == "https://example.com/openai/v1/"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_normalizes_trailing_slash_without_double_slash():
provider = _make_embedding_provider(
{"embedding_api_base": "https://example.com/openai/"}
)
try:
assert str(provider.client.base_url) == "https://example.com/openai/v1/"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_falls_back_to_default_base_for_blank_config():
provider = _make_embedding_provider({"embedding_api_base": " "})
try:
assert str(provider.client.base_url) == "https://api.openai.com/v1/"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_preserves_versioned_or_specific_paths():
base_urls = [
"https://example.com/v1-beta",
"https://example.com/v1/embeddings",
]

for base_url in base_urls:
provider = _make_embedding_provider({"embedding_api_base": base_url})
try:
assert str(provider.client.base_url) == f"{base_url.rstrip('/')}/"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_preserves_query_and_fragment_when_normalizing_path():
provider = _make_embedding_provider(
{"embedding_api_base": "https://example.com/openai/?next=/foo/#frag/"}
)
try:
assert (
str(provider.client.base_url)
== "https://example.com/openai/v1?next=/foo/#frag/"
)
finally:
await provider.terminate()
Loading