diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 97b2b2fb49..66d15fee76 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -6,6 +6,8 @@ from typing import cast import aiofiles +import aiofiles.os +import aiofiles.ospath import botpy import botpy.errors import botpy.message @@ -22,6 +24,7 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_image_by_url, file_to_base64 +from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk @@ -608,7 +611,19 @@ async def _parse_to_qqofficial(message: MessageChain): temp_dir, f"qqofficial_{uuid.uuid4()}.silk", ) + converted_record_wav_path = None try: + if not await QQOfficialMessageEvent._is_wav_audio_file( + record_wav_path + ): + converted_record_wav_path = os.path.join( + temp_dir, + f"qqofficial_{uuid.uuid4()}.wav", + ) + record_wav_path = await convert_audio_to_wav( + record_wav_path, + converted_record_wav_path, + ) duration = await wav_to_tencent_silk( record_wav_path, record_tecent_silk_path, @@ -621,6 +636,16 @@ async def _parse_to_qqofficial(message: MessageChain): except Exception as e: logger.error(f"处理语音时出错: {e}") record_file_path = None + finally: + if converted_record_wav_path and await aiofiles.ospath.exists( + converted_record_wav_path + ): + try: + await aiofiles.os.remove(converted_record_wav_path) + except OSError as e: + logger.warning( + f"[QQOfficial] failed to remove converted audio file: {e}" + ) elif isinstance(i, Video) and not video_file_source: if i.file.startswith("file:///"): video_file_source = i.file[8:] @@ -648,3 +673,12 @@ async def _parse_to_qqofficial(message: MessageChain): 
@staticmethod
async def _is_wav_audio_file(file_path: str) -> bool:
    """Check whether a file begins with a RIFF/WAVE container header.

    Reads the first 12 bytes of ``file_path`` and compares bytes 0-3
    against ``RIFF`` and bytes 8-11 against ``WAVE``.

    Args:
        file_path: Path of the audio file to inspect.

    Returns:
        ``True`` only when both markers are present; ``False`` when the
        header is shorter than 12 bytes, does not match, or the file
        cannot be opened or read (``OSError``).
    """
    try:
        async with aiofiles.open(file_path, "rb") as stream:
            magic = await stream.read(12)
    except OSError:
        # An unreadable file is treated the same as a non-WAV file.
        return False

    if len(magic) < 12:
        return False
    riff_tag, form_type = magic[:4], magic[8:12]
    return riff_tag == b"RIFF" and form_type == b"WAVE"
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.wav_to_tencent_silk", + fake_wav_to_tencent_silk, + ) + + result = await QQOfficialMessageEvent._parse_to_qqofficial( + MessageChain([Record.fromFileSystem(str(source_path))]) + ) + + assert converted_paths + assert not converted_paths[0].exists() + assert result[3] is not None + assert str(result[3]).endswith(".silk") + + +@pytest.mark.asyncio +async def test_parse_to_qqofficial_skips_conversion_for_wav_record( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + source_path = tmp_path / "voice.wav" + source_path.write_bytes(_wav_header()) + conversions: list[tuple[str, str | None]] = [] + + async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None): + conversions.append((audio_path, output_path)) + return output_path or audio_path + + async def fake_wav_to_tencent_silk(wav_path: str, silk_path: str): + assert wav_path == str(source_path) + Path(silk_path).write_bytes(b"fake-silk") + return 800 + + monkeypatch.setattr( + "astrbot.core.platform.sources.qqofficial.qqofficial_message_event.get_astrbot_temp_path", + lambda: str(tmp_path), + ) + monkeypatch.setattr( + "astrbot.core.platform.sources.qqofficial.qqofficial_message_event.convert_audio_to_wav", + fake_convert_audio_to_wav, + ) + monkeypatch.setattr( + "astrbot.core.platform.sources.qqofficial.qqofficial_message_event.wav_to_tencent_silk", + fake_wav_to_tencent_silk, + ) + + result = await QQOfficialMessageEvent._parse_to_qqofficial( + MessageChain([Record.fromFileSystem(str(source_path))]) + ) + + assert conversions == [] + assert result[3] is not None + assert str(result[3]).endswith(".silk")