diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 7e31536a16..99d5719ca8 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -11,6 +11,7 @@ import botpy import botpy.message from botpy import Client +from botpy.client import ConnectionSession, Robot, Token from astrbot import logger from astrbot.api.event import MessageChain @@ -34,10 +35,43 @@ # QQ 机器人官方框架 +class QQOfficialGatewayUnavailableError(RuntimeError): + """botpy returned no usable gateway metadata during startup.""" + + class botClient(Client): def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform + async def _bot_login(self, token: Token) -> None: + user = await self.http.login(token) + + # botpy may return None here after a transient /gateway/bot timeout. + self._ws_ap = await self.api.get_ws_url() + session_limit = ( + self._ws_ap.get("session_start_limit") + if isinstance(self._ws_ap, dict) + else None + ) + max_concurrency = ( + session_limit.get("max_concurrency") + if isinstance(session_limit, dict) + else None + ) + if not isinstance(max_concurrency, int): + raise QQOfficialGatewayUnavailableError( + "gateway metadata unavailable during qq_official startup" + ) + + self._connection = ConnectionSession( + max_async=max_concurrency, + connect=self.bot_connect, + dispatch=self.ws_dispatch, + loop=self.loop, + api=self.api, + ) + self._connection.state.robot = Robot(user) + # 收到群消息 async def on_group_at_message_create( self, message: botpy.message.GroupMessage @@ -99,6 +133,8 @@ def _commit(self, abm: AstrBotMessage) -> None: @register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器") class QQOfficialPlatformAdapter(Platform): + STARTUP_RETRY_DELAY_SECONDS = 5 + def __init__( self, platform_config: dict, @@ -123,18 +159,57 @@ def __init__( public_guild_messages=True, direct_message=guild_dm, ) - self.client = botClient( + self._shutdown_event = asyncio.Event() + self.client = self._create_client() + + self._session_last_message_id: dict[str, str] = {} + self._session_scene: dict[str, str] = {} + + self.test_mode = os.environ.get("TEST_MODE", "off") == "on" + + def _create_client(self) -> botClient: + client = botClient( intents=self.intents, bot_log=False, timeout=20, ) + client.set_platform(self) + return client - self.client.set_platform(self) + @staticmethod + def _should_retry_startup_error(error: Exception) -> bool: + return isinstance( + error, + ( + asyncio.TimeoutError, + ConnectionError, + OSError, + QQOfficialGatewayUnavailableError, + ), + ) - self._session_last_message_id: dict[str, str] = {} - self._session_scene: dict[str, str] = {} + async def _restart_client(self) -> None: + try: + await self.client.close() + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "qq_official(%s): close client failed during recovery: %s", + self.meta().id, + e, + ) + self.client = self._create_client() - self.test_mode = os.environ.get("TEST_MODE", "off") == "on" + async def _sleep_until_retry_or_shutdown(self) -> bool: + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=self.STARTUP_RETRY_DELAY_SECONDS, + ) + return False + except asyncio.TimeoutError: + return True async def send_by_session( self, @@ -500,12 +575,60 @@ def _parse_from_qqofficial( abm.self_id = "qq_official" return abm - def run(self): - return self.client.start(appid=self.appid, secret=self.secret) + async def run(self) -> None: + try: + while not self._shutdown_event.is_set(): + try: + await self.client.start(appid=self.appid, secret=self.secret) + if self._shutdown_event.is_set(): + break + logger.warning( + "qq_official(%s): client stopped unexpectedly, restarting in %ss", + self.meta().id, + self.STARTUP_RETRY_DELAY_SECONDS, + ) + except asyncio.CancelledError: + raise + except Exception as e: + if not self._should_retry_startup_error(e): + raise + if self._shutdown_event.is_set(): + break + logger.warning( + "qq_official(%s): startup failed, retrying in %ss: %s", + self.meta().id, + self.STARTUP_RETRY_DELAY_SECONDS, + e, + ) + + await self._restart_client() + if not await self._sleep_until_retry_or_shutdown(): + break + finally: + try: + await self.client.close() + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "qq_official(%s): close client failed during shutdown: %s", + self.meta().id, + e, + ) def get_client(self) -> botClient: return self.client async def terminate(self) -> None: - await self.client.close() + self._shutdown_event.set() + try: + await self.client.close() + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "qq_official(%s): close client failed during shutdown: %s", + self.meta().id, + e, + ) logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/tests/test_qqofficial_platform_adapter.py b/tests/test_qqofficial_platform_adapter.py new file mode 100644 index 0000000000..40e0c42488 --- /dev/null +++ b/tests/test_qqofficial_platform_adapter.py @@ -0,0 +1,115 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.platform.sources.qqofficial.qqofficial_platform_adapter import ( + QQOfficialGatewayUnavailableError, + QQOfficialPlatformAdapter, +) + + +def _platform_config() -> dict: + return { + "id": "qq-official-test", + "appid": "appid", + "secret": "secret", + "enable_group_c2c": True, + "enable_guild_direct_message": True, + } + + +@pytest.mark.asyncio +async def test_qqofficial_run_retries_after_gateway_timeout(monkeypatch): + first_client = SimpleNamespace( + start=AsyncMock( + side_effect=QQOfficialGatewayUnavailableError( + "gateway metadata unavailable during qq_official startup" + ) + ), + close=AsyncMock(), + ) + adapter_holder: dict[str, QQOfficialPlatformAdapter] = {} + + async def second_start(*args, **kwargs): + adapter_holder["adapter"]._shutdown_event.set() + return None + + second_client = SimpleNamespace( + start=AsyncMock(side_effect=second_start), + close=AsyncMock(), + ) + clients = iter([first_client, second_client]) + monkeypatch.setattr( + QQOfficialPlatformAdapter, + "_create_client", + lambda self: next(clients), + ) + + adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue()) + adapter_holder["adapter"] = adapter + adapter.STARTUP_RETRY_DELAY_SECONDS = 0 + + await adapter.run() + + first_client.start.assert_awaited_once_with(appid="appid", secret="secret") + first_client.close.assert_awaited_once() + second_client.start.assert_awaited_once_with(appid="appid", secret="secret") + second_client.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_qqofficial_run_reraises_non_retryable_error(monkeypatch): + client = SimpleNamespace( + start=AsyncMock(side_effect=ValueError("invalid credentials")), + close=AsyncMock(), + ) + monkeypatch.setattr( + QQOfficialPlatformAdapter, + "_create_client", + lambda self: client, + ) + + adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue()) + + with pytest.raises(ValueError, match="invalid credentials"): + await adapter.run() + + client.start.assert_awaited_once_with(appid="appid", secret="secret") + client.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_qqofficial_bot_login_raises_gateway_error_when_metadata_missing(): + adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue()) + adapter.client.http = SimpleNamespace(login=AsyncMock(return_value=SimpleNamespace())) + adapter.client.api = SimpleNamespace(get_ws_url=AsyncMock(return_value=None)) + + with pytest.raises( + QQOfficialGatewayUnavailableError, + match="gateway metadata unavailable", + ): + await adapter.client._bot_login(SimpleNamespace()) + + await adapter.terminate() + + +@pytest.mark.asyncio +async def test_qqofficial_run_propagates_cancelled_error(monkeypatch): + client = SimpleNamespace( + start=AsyncMock(side_effect=asyncio.CancelledError()), + close=AsyncMock(), + ) + monkeypatch.setattr( + QQOfficialPlatformAdapter, + "_create_client", + lambda self: client, + ) + + adapter = QQOfficialPlatformAdapter(_platform_config(), {}, asyncio.Queue()) + + with pytest.raises(asyncio.CancelledError): + await adapter.run() + + client.close.assert_awaited_once()