Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
from typing import Any

import httpx
from openai import AsyncOpenAI

Expand Down Expand Up @@ -44,7 +47,7 @@ async def get_embedding(self, text: str) -> list[float]:
model=self.model,
**kwargs,
)
return embedding.data[0].embedding
return self._normalize_embedding_response(embedding)[0]

async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
Expand All @@ -54,7 +57,39 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
model=self.model,
**kwargs,
)
return [item.embedding for item in embeddings.data]
return self._normalize_embedding_response(embeddings)

@staticmethod
def _normalize_embedding_response(response: Any) -> list[list[float]]:
if isinstance(response, str):
response = json.loads(response)

data = response and (
response.get("data")
if isinstance(response, dict)
else getattr(response, "data", None)
)
if not isinstance(data, list):
raise TypeError(
f"Unexpected embedding response type: {type(response).__name__}"
)

vectors: list[list[float]] = []
for item in data:
embedding = item and (
item.get("embedding")
if isinstance(item, dict)
else getattr(item, "embedding", None)
)
if not isinstance(embedding, list):
raise TypeError(
f"Unexpected embedding item type: {type(item).__name__}"
)
vectors.append([float(value) for value in embedding])

if not vectors:
raise ValueError("Embedding response did not include any vectors")
return vectors

def _embedding_kwargs(self) -> dict:
"""构建嵌入请求的可选参数"""
Expand Down
99 changes: 99 additions & 0 deletions tests/test_openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from astrbot.core.provider.sources.openai_embedding_source import (
OpenAIEmbeddingProvider,
)


def _make_provider() -> OpenAIEmbeddingProvider:
provider = OpenAIEmbeddingProvider(
provider_config={
"id": "test-openai-embedding",
"type": "openai_embedding",
"embedding_api_key": "test-key",
"embedding_api_base": "https://api.openai.com/v1",
"embedding_model": "text-embedding-3-large",
"embedding_dimensions": 3,
},
provider_settings={},
)
provider.client = SimpleNamespace(
embeddings=SimpleNamespace(create=AsyncMock()),
close=AsyncMock(),
)
return provider


@pytest.mark.asyncio
async def test_get_embedding_accepts_sdk_object_response():
provider = _make_provider()
provider.client.embeddings.create.return_value = SimpleNamespace(
data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])]
)

try:
result = await provider.get_embedding("astrbot")
assert result == [0.1, 0.2, 0.3]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_get_embeddings_accepts_json_string_response():
provider = _make_provider()
provider.client.embeddings.create.return_value = json.dumps(
{
"data": [
{"embedding": [0.1, 0.2, 0.3]},
{"embedding": [1, 2, 3]},
]
}
)

try:
result = await provider.get_embeddings(["a", "b"])
assert result == [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_get_embedding_rejects_empty_vectors():
provider = _make_provider()
provider.client.embeddings.create.return_value = {"data": []}

try:
with pytest.raises(
ValueError, match="Embedding response did not include any vectors"
):
await provider.get_embedding("astrbot")
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_get_embedding_rejects_none_response():
provider = _make_provider()
provider.client.embeddings.create.return_value = None

try:
with pytest.raises(TypeError, match="Unexpected embedding response type"):
await provider.get_embedding("astrbot")
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_get_embeddings_rejects_none_item():
provider = _make_provider()
provider.client.embeddings.create.return_value = {"data": [None]}

try:
with pytest.raises(TypeError, match="Unexpected embedding item type"):
await provider.get_embeddings(["astrbot"])
finally:
await provider.terminate()
Loading