@staticmethod
def _normalize_embedding_response(response: Any) -> list[list[float]]:
    """Convert an embeddings API response into a list of float vectors.

    The response may be an object exposing a ``data`` attribute, a ``dict``
    with a ``"data"`` key, or a JSON document (``str``, ``bytes`` or
    ``bytearray``) that decodes to such a ``dict``. Each entry of ``data``
    may likewise be an object with an ``embedding`` attribute or a ``dict``
    with an ``"embedding"`` key.

    Args:
        response: The raw value returned by the embeddings client.

    Returns:
        One vector per entry of ``data``, in the same order, with every
        value coerced through ``float()``.

    Raises:
        TypeError: If ``data`` is missing or not a list, or if an entry's
            ``embedding`` is missing or not a list. ``float()`` may also
            raise it for values of an unsupported type.
        ValueError: If ``data`` is an empty list. Invalid JSON input and
            values ``float()`` cannot parse also surface as ``ValueError``
            (``json.JSONDecodeError`` is a subclass of it).
    """

    def _field(container: Any, name: str) -> Any:
        # Explicit lookup instead of ``container and (...)``: the
        # short-circuit form returns the container itself when it is falsy,
        # so a malformed item such as ``[]`` would be mistaken for an empty
        # embedding list, and a falsy wrapper object that does carry the
        # field would be rejected.
        if isinstance(container, dict):
            return container.get(name)
        return getattr(container, name, None)

    # json.loads accepts bytes/bytearray as well as str, so raw payloads
    # are decoded the same way as text ones.
    if isinstance(response, (str, bytes, bytearray)):
        response = json.loads(response)

    data = _field(response, "data")
    if not isinstance(data, list):
        raise TypeError(
            f"Unexpected embedding response type: {type(response).__name__}"
        )

    vectors: list[list[float]] = []
    for item in data:
        embedding = _field(item, "embedding")
        if not isinstance(embedding, list):
            raise TypeError(
                f"Unexpected embedding item type: {type(item).__name__}"
            )
        # Coerce so integer components (e.g. decoded from JSON) become floats.
        vectors.append([float(value) for value in embedding])

    if not vectors:
        raise ValueError("Embedding response did not include any vectors")
    return vectors
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from astrbot.core.provider.sources.openai_embedding_source import (
    OpenAIEmbeddingProvider,
)


def _make_provider() -> OpenAIEmbeddingProvider:
    """Build a provider whose ``client`` is replaced by async mocks."""
    config = {
        "id": "test-openai-embedding",
        "type": "openai_embedding",
        "embedding_api_key": "test-key",
        "embedding_api_base": "https://api.openai.com/v1",
        "embedding_model": "text-embedding-3-large",
        "embedding_dimensions": 3,
    }
    provider = OpenAIEmbeddingProvider(
        provider_config=config,
        provider_settings={},
    )
    # Swap the client for a stub so ``embeddings.create`` can be scripted
    # per test through ``return_value``.
    mocked_embeddings = SimpleNamespace(create=AsyncMock())
    provider.client = SimpleNamespace(
        embeddings=mocked_embeddings,
        close=AsyncMock(),
    )
    return provider


@pytest.mark.asyncio
async def test_get_embedding_accepts_sdk_object_response():
    """An attribute-style response yields the first item's vector."""
    provider = _make_provider()
    sdk_item = SimpleNamespace(embedding=[0.1, 0.2, 0.3])
    provider.client.embeddings.create.return_value = SimpleNamespace(
        data=[sdk_item]
    )

    try:
        vector = await provider.get_embedding("astrbot")
        assert vector == [0.1, 0.2, 0.3]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_get_embeddings_accepts_json_string_response():
    """A JSON string response is decoded and its values become floats."""
    provider = _make_provider()
    payload = {
        "data": [
            {"embedding": [0.1, 0.2, 0.3]},
            {"embedding": [1, 2, 3]},
        ]
    }
    provider.client.embeddings.create.return_value = json.dumps(payload)

    try:
        vectors = await provider.get_embeddings(["a", "b"])
        assert vectors == [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_get_embedding_rejects_empty_vectors():
    """A response with an empty ``data`` list raises ``ValueError``."""
    provider = _make_provider()
    provider.client.embeddings.create.return_value = {"data": []}
    expected_message = "Embedding response did not include any vectors"

    try:
        with pytest.raises(ValueError, match=expected_message):
            await provider.get_embedding("astrbot")
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_get_embedding_rejects_none_response():
    """A ``None`` response raises ``TypeError``."""
    provider = _make_provider()
    provider.client.embeddings.create.return_value = None
    expected_message = "Unexpected embedding response type"

    try:
        with pytest.raises(TypeError, match=expected_message):
            await provider.get_embedding("astrbot")
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_get_embeddings_rejects_none_item():
    """A ``None`` entry inside ``data`` raises ``TypeError``."""
    provider = _make_provider()
    provider.client.embeddings.create.return_value = {"data": [None]}
    expected_message = "Unexpected embedding item type"

    try:
        with pytest.raises(TypeError, match=expected_message):
            await provider.get_embeddings(["astrbot"])
    finally:
        await provider.terminate()