Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
api_base = provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
).strip()
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v1/"):
api_base = api_base.rstrip("/") + "/v1"
Comment on lines +30 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for handling api_base does not correctly fall back to the default https://api.openai.com/v1 when embedding_api_base is explicitly provided as an empty string. If provider_config.get("embedding_api_base") returns an empty string, the if api_base condition evaluates to False, skipping the logic to append /v1. This results in self.client.base_url being set to an empty string, which is likely an invalid URL for AsyncOpenAI and would cause connection errors.

The unit test test_empty_api_base_uses_default in tests/unit/test_openai_embedding_source.py correctly identifies the desired behavior (falling back to the default), but the current implementation will cause that test to fail. The proposed change ensures that an empty api_base (after stripping whitespace) is treated as if it were not provided, and thus defaults to https://api.openai.com/v1.

        if not api_base:
            api_base = "https://api.openai.com/v1"
        elif not api_base.endswith("/v1") and not api_base.endswith("/v1/"):
            api_base = api_base.rstrip("/") + "/v1"

logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/test_openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from unittest.mock import AsyncMock, patch

import pytest

from astrbot.core.provider.sources.openai_embedding_source import (
OpenAIEmbeddingProvider,
)


def _make_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider:
    """Construct an ``OpenAIEmbeddingProvider`` for use in tests.

    Starts from a minimal config (id, API key, model) and layers
    ``overrides`` on top, so each test only spells out the keys it cares
    about.

    Args:
        overrides: Optional config entries that replace or extend the
            defaults. A falsy value leaves the defaults untouched.

    Returns:
        A provider built from the merged config and empty provider settings.
    """
    config = {
        "id": "test-openai-embedding",
        "embedding_api_key": "test-key",
        "embedding_model": "text-embedding-3-small",
    }
    # ``or {}`` makes the update a no-op when no overrides were supplied.
    config.update(overrides or {})
    return OpenAIEmbeddingProvider(provider_config=config, provider_settings={})


class TestOpenAIEmbeddingProviderApiBaseV1Suffix:
"""Test that /v1 suffix is auto-appended to embedding_api_base.

Regression test for: https://github.com/AstrBotDevs/AstrBot/issues/6887
PR #6669 removed automatic /v1 suffix because some providers don't use
standard /v1/embeddings endpoint, but this broke OpenAI-compatible
providers. PR #6863 reintroduces the auto-append logic.
"""

def test_api_base_without_v1_gets_v1_appended(self) -> None:
    """A bare host like 'https://api.openai.com' should resolve to '.../v1'."""
    provider = _make_provider({"embedding_api_base": "https://api.openai.com"})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form
    # rather than relying on URL == str equality.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://api.openai.com/v1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test case for embedding_api_base with surrounding whitespace

Since the constructor applies .strip() to embedding_api_base, it would be useful to add a test that passes values with leading/trailing spaces (e.g. ' https://api.openai.com ' or ' https://api.openai.com/ ') and asserts that provider.client.base_url is normalized to https://api.openai.com/v1.


def test_api_base_with_trailing_slash_gets_v1_appended(self) -> None:
    """A host with a trailing slash ('https://api.openai.com/') should resolve to '.../v1'."""
    provider = _make_provider({"embedding_api_base": "https://api.openai.com/"})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://api.openai.com/v1"

def test_api_base_already_with_v1_is_unchanged(self) -> None:
    """An api_base already ending in '/v1' must not get a second '/v1'."""
    provider = _make_provider({"embedding_api_base": "https://api.openai.com/v1"})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://api.openai.com/v1"

def test_api_base_with_v1_trailing_slash_is_unchanged(self) -> None:
    """An api_base already ending in '/v1/' must not get a second '/v1'."""
    configured = "https://api.openai.com/v1/"
    provider = _make_provider({"embedding_api_base": configured})
    assert provider.client.base_url == configured

def test_api_base_custom_endpoint_without_v1_gets_v1_appended(self) -> None:
    """A custom host like 'https://openai.example.com' should resolve to '.../v1'."""
    provider = _make_provider({"embedding_api_base": "https://openai.example.com"})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://openai.example.com/v1"

def test_api_base_custom_endpoint_already_with_v1_is_unchanged(self) -> None:
    """A custom api_base already ending in '/v1' must not get a second '/v1'."""
    provider = _make_provider({"embedding_api_base": "https://openai.example.com/v1"})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://openai.example.com/v1"

def test_empty_api_base_uses_default(self) -> None:
    """An empty api_base should fall back to the default OpenAI endpoint."""
    provider = _make_provider({"embedding_api_base": ""})
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://api.openai.com/v1"

def test_default_api_base_is_unchanged(self) -> None:
    """With no api_base configured, the standard OpenAI endpoint should be used."""
    provider = _make_provider()
    # client.base_url may be a URL object to which the OpenAI client has
    # appended a trailing slash, so compare the slash-normalized string form.
    actual = str(provider.client.base_url).rstrip("/")
    assert actual == "https://api.openai.com/v1"
Loading