-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
fix: append /v1 for OpenAI embedding api base #6910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
| from openai.types.chat.chat_completion import ChatCompletion | ||
|
|
||
| from astrbot.core.provider.sources.groq_source import ProviderGroq | ||
| from astrbot.core.provider.sources.openai_embedding_source import ( | ||
| OpenAIEmbeddingProvider, | ||
| ) | ||
| from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial | ||
|
|
||
|
|
||
|
|
@@ -49,6 +52,20 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq: | |
| ) | ||
|
|
||
|
|
||
| def _make_embedding_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider: | ||
| provider_config = { | ||
| "id": "test-openai-embedding", | ||
| "type": "openai_embedding", | ||
| "embedding_api_key": "test-key", | ||
| } | ||
| if overrides: | ||
| provider_config.update(overrides) | ||
| return OpenAIEmbeddingProvider( | ||
| provider_config=provider_config, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_handle_api_error_content_moderated_removes_images(): | ||
| provider = _make_provider( | ||
|
Comment on lines
69
to
71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Cover edge case where Given the current normalization ( Suggested implementation: finally:
await provider.terminate()
def test_embedding_api_base_trailing_slash_normalized():
provider = _make_provider(
overrides={"embedding_api_base": "https://example.com/openai/"}
)
# The provider should normalize the embedding API base by removing any
# trailing slash and then appending `/v1`, resulting in a single slash.
# This asserts we do *not* end up with `https://example.com/openai//v1/`.
base_url = str(provider.client._client.base_url)
assert base_url == "https://example.com/openai/v1/"Depending on how
The key behavior to lock in is that |
||
|
|
@@ -234,7 +251,9 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history(): | |
| provider._finally_convert_payload(payloads) | ||
|
|
||
| assistant_message = payloads["messages"][0] | ||
| assert assistant_message["content"] == [{"type": "text", "text": "final answer"}] | ||
| assert assistant_message["content"] == [ | ||
| {"type": "text", "text": "final answer"} | ||
| ] | ||
| assert assistant_message["reasoning_content"] == "step 1" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
@@ -259,7 +278,9 @@ async def test_groq_payload_drops_reasoning_content_from_assistant_history(): | |
| provider._finally_convert_payload(payloads) | ||
|
|
||
| assistant_message = payloads["messages"][0] | ||
| assert assistant_message["content"] == [{"type": "text", "text": "final answer"}] | ||
| assert assistant_message["content"] == [ | ||
| {"type": "text", "text": "final answer"} | ||
| ] | ||
| assert "reasoning_content" not in assistant_message | ||
| assert "reasoning" not in assistant_message | ||
| finally: | ||
|
|
@@ -533,3 +554,74 @@ async def fake_create(**kwargs): | |
| assert extra_body["temperature"] == 0.1 | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_appends_v1_to_base_url_when_missing(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_preserves_existing_v1_suffix(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai/v1/"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_normalizes_trailing_slash_without_double_slash(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai/"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_falls_back_to_default_base_for_blank_config(): | ||
| provider = _make_embedding_provider({"embedding_api_base": " "}) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://api.openai.com/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_preserves_versioned_or_specific_paths(): | ||
| base_urls = [ | ||
| "https://example.com/v1-beta", | ||
| "https://example.com/v1/embeddings", | ||
| ] | ||
|
|
||
| for base_url in base_urls: | ||
| provider = _make_embedding_provider({"embedding_api_base": base_url}) | ||
| try: | ||
| assert str(provider.client.base_url) == f"{base_url.rstrip('/')}/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_preserves_query_and_fragment_when_normalizing_path(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai/?next=/foo/#frag/"} | ||
| ) | ||
| try: | ||
| assert ( | ||
| str(provider.client.base_url) | ||
| == "https://example.com/openai/v1?next=/foo/#frag/" | ||
| ) | ||
| finally: | ||
| await provider.terminate() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While the current logic correctly normalizes a non-empty
api_base, it doesn't handle cases whereapi_basebecomes an empty string (e.g., if the configuration provides an empty or whitespace-only value). This will cause theAsyncOpenAIclient to fail with anInvalidURLerror. To make this more robust, we should ensure we fall back to the default URL ifapi_baseis empty before normalizing.