Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import botpy
import botpy.message
from botpy import Client
from botpy.client import ConnectionSession, Robot, Token

from astrbot import logger
from astrbot.api.event import MessageChain
Expand All @@ -34,10 +35,43 @@


# QQ 机器人官方框架
class QQOfficialGatewayUnavailableError(RuntimeError):
"""botpy returned no usable gateway metadata during startup."""


class botClient(Client):
def set_platform(self, platform: QQOfficialPlatformAdapter) -> None:
self.platform = platform

async def _bot_login(self, token: Token) -> None:
user = await self.http.login(token)

# botpy may return None here after a transient /gateway/bot timeout.
self._ws_ap = await self.api.get_ws_url()
session_limit = (
self._ws_ap.get("session_start_limit")
if isinstance(self._ws_ap, dict)
else None
)
max_concurrency = (
session_limit.get("max_concurrency")
if isinstance(session_limit, dict)
else None
)
if not isinstance(max_concurrency, int):
raise QQOfficialGatewayUnavailableError(
"gateway metadata unavailable during qq_official startup"
)

self._connection = ConnectionSession(
max_async=max_concurrency,
connect=self.bot_connect,
dispatch=self.ws_dispatch,
loop=self.loop,
api=self.api,
)
self._connection.state.robot = Robot(user)

# 收到群消息
async def on_group_at_message_create(
self, message: botpy.message.GroupMessage
Expand Down Expand Up @@ -99,6 +133,8 @@ def _commit(self, abm: AstrBotMessage) -> None:

@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器")
class QQOfficialPlatformAdapter(Platform):
STARTUP_RETRY_DELAY_SECONDS = 5

def __init__(
self,
platform_config: dict,
Expand All @@ -123,18 +159,57 @@ def __init__(
public_guild_messages=True,
direct_message=guild_dm,
)
self.client = botClient(
self._shutdown_event = asyncio.Event()
self.client = self._create_client()

self._session_last_message_id: dict[str, str] = {}
self._session_scene: dict[str, str] = {}

self.test_mode = os.environ.get("TEST_MODE", "off") == "on"

def _create_client(self) -> botClient:
client = botClient(
intents=self.intents,
bot_log=False,
timeout=20,
)
client.set_platform(self)
return client

self.client.set_platform(self)
@staticmethod
def _should_retry_startup_error(error: Exception) -> bool:
return isinstance(
error,
(
asyncio.TimeoutError,
ConnectionError,
OSError,
QQOfficialGatewayUnavailableError,
),
)

self._session_last_message_id: dict[str, str] = {}
self._session_scene: dict[str, str] = {}
async def _restart_client(self) -> None:
try:
await self.client.close()
except asyncio.CancelledError:
raise
except Exception as e:
logger.warning(
"qq_official(%s): close client failed during recovery: %s",
self.meta().id,
e,
)
self.client = self._create_client()

self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
async def _sleep_until_retry_or_shutdown(self) -> bool:
try:
await asyncio.wait_for(
self._shutdown_event.wait(),
timeout=self.STARTUP_RETRY_DELAY_SECONDS,
)
return False
except asyncio.TimeoutError:
return True

async def send_by_session(
self,
Expand Down Expand Up @@ -500,12 +575,60 @@ def _parse_from_qqofficial(
abm.self_id = "qq_official"
return abm

def run(self):
return self.client.start(appid=self.appid, secret=self.secret)
async def run(self) -> None:
try:
while not self._shutdown_event.is_set():
try:
await self.client.start(appid=self.appid, secret=self.secret)
if self._shutdown_event.is_set():
break
logger.warning(
"qq_official(%s): client stopped unexpectedly, restarting in %ss",
self.meta().id,
self.STARTUP_RETRY_DELAY_SECONDS,
)
except asyncio.CancelledError:
raise
except Exception as e:
if not self._should_retry_startup_error(e):
raise
if self._shutdown_event.is_set():
break
logger.warning(
"qq_official(%s): startup failed, retrying in %ss: %s",
self.meta().id,
self.STARTUP_RETRY_DELAY_SECONDS,
e,
)

await self._restart_client()
if not await self._sleep_until_retry_or_shutdown():
break
finally:
try:
await self.client.close()
except asyncio.CancelledError:
raise
except Exception as e:
logger.warning(
"qq_official(%s): close client failed during shutdown: %s",
self.meta().id,
e,
)

def get_client(self) -> botClient:
return self.client

async def terminate(self) -> None:
await self.client.close()
self._shutdown_event.set()
try:
await self.client.close()
except asyncio.CancelledError:
raise
except Exception as e:
logger.warning(
"qq_official(%s): close client failed during shutdown: %s",
self.meta().id,
e,
)
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
115 changes: 115 additions & 0 deletions tests/test_qqofficial_platform_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from astrbot.core.platform.sources.qqofficial.qqofficial_platform_adapter import (
QQOfficialGatewayUnavailableError,
QQOfficialPlatformAdapter,
)


def _platform_config() -> dict:
return {
"id": "qq-official-test",
"appid": "appid",
"secret": "secret",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
}


@pytest.mark.asyncio
async def test_qqofficial_run_retries_after_gateway_timeout(monkeypatch):
first_client = SimpleNamespace(
start=AsyncMock(
side_effect=QQOfficialGatewayUnavailableError(
"gateway metadata unavailable during qq_official startup"
)
),
close=AsyncMock(),
)
adapter_holder: dict[str, QQOfficialPlatformAdapter] = {}

async def second_start(*args, **kwargs):
adapter_holder["adapter"]._shutdown_event.set()
return None

second_client = SimpleNamespace(
start=AsyncMock(side_effect=second_start),
close=AsyncMock(),
)
clients = iter([first_client, second_client])
monkeypatch.setattr(
QQOfficialPlatformAdapter,
"_create_client",
lambda self: next(clients),
)

adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue())
adapter_holder["adapter"] = adapter
adapter.STARTUP_RETRY_DELAY_SECONDS = 0

await adapter.run()

first_client.start.assert_awaited_once_with(appid="appid", secret="secret")
first_client.close.assert_awaited_once()
second_client.start.assert_awaited_once_with(appid="appid", secret="secret")
second_client.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_qqofficial_run_reraises_non_retryable_error(monkeypatch):
client = SimpleNamespace(
start=AsyncMock(side_effect=ValueError("invalid credentials")),
close=AsyncMock(),
)
monkeypatch.setattr(
QQOfficialPlatformAdapter,
"_create_client",
lambda self: client,
)

adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue())

with pytest.raises(ValueError, match="invalid credentials"):
await adapter.run()

client.start.assert_awaited_once_with(appid="appid", secret="secret")
client.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_qqofficial_bot_login_raises_gateway_error_when_metadata_missing():
adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue())
adapter.client.http = SimpleNamespace(login=AsyncMock(return_value=SimpleNamespace()))
adapter.client.api = SimpleNamespace(get_ws_url=AsyncMock(return_value=None))

with pytest.raises(
QQOfficialGatewayUnavailableError,
match="gateway metadata unavailable",
):
await adapter.client._bot_login(SimpleNamespace())

await adapter.terminate()


@pytest.mark.asyncio
async def test_qqofficial_run_propagates_cancelled_error(monkeypatch):
client = SimpleNamespace(
start=AsyncMock(side_effect=asyncio.CancelledError()),
close=AsyncMock(),
)
monkeypatch.setattr(
QQOfficialPlatformAdapter,
"_create_client",
lambda self: client,
)

adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue())

with pytest.raises(asyncio.CancelledError):
await adapter.run()

client.close.assert_awaited_once()
Loading