Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import cast

import aiofiles
import aiofiles.os
import aiofiles.ospath
import botpy
import botpy.errors
import botpy.message
Expand All @@ -22,6 +24,7 @@
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk


Expand Down Expand Up @@ -608,7 +611,19 @@ async def _parse_to_qqofficial(message: MessageChain):
temp_dir,
f"qqofficial_{uuid.uuid4()}.silk",
)
converted_record_wav_path = None
try:
if not await QQOfficialMessageEvent._is_wav_audio_file(
record_wav_path
):
converted_record_wav_path = os.path.join(
temp_dir,
f"qqofficial_{uuid.uuid4()}.wav",
)
record_wav_path = await convert_audio_to_wav(
record_wav_path,
converted_record_wav_path,
)
duration = await wav_to_tencent_silk(
record_wav_path,
record_tecent_silk_path,
Expand All @@ -621,6 +636,16 @@ async def _parse_to_qqofficial(message: MessageChain):
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
finally:
if converted_record_wav_path and await aiofiles.ospath.exists(
converted_record_wav_path
):
try:
await aiofiles.os.remove(converted_record_wav_path)
except OSError as e:
logger.warning(
f"[QQOfficial] failed to remove converted audio file: {e}"
)
elif isinstance(i, Video) and not video_file_source:
if i.file.startswith("file:///"):
video_file_source = i.file[8:]
Expand Down Expand Up @@ -648,3 +673,12 @@ async def _parse_to_qqofficial(message: MessageChain):
file_source,
file_name,
)

@staticmethod
async def _is_wav_audio_file(file_path: str) -> bool:
try:
async with aiofiles.open(file_path, "rb") as f:
header = await f.read(12)
except OSError:
return False
return len(header) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WAVE"
97 changes: 97 additions & 0 deletions tests/test_qqofficial_message_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from pathlib import Path

import pytest

from astrbot.core.message.components import Record
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.sources.qqofficial.qqofficial_message_event import (
QQOfficialMessageEvent,
)


def _wav_header() -> bytes:
return b"RIFF\x00\x00\x00\x00WAVEfmt "


@pytest.mark.asyncio
async def test_parse_to_qqofficial_converts_non_wav_record_before_silk(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
source_path = tmp_path / "voice.mp3"
source_path.write_bytes(b"ID3fake-mp3")
converted_paths: list[Path] = []

async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None):
assert audio_path == str(source_path)
assert output_path is not None
converted_path = Path(output_path)
converted_path.write_bytes(_wav_header())
converted_paths.append(converted_path)
return output_path

async def fake_wav_to_tencent_silk(wav_path: str, silk_path: str):
assert converted_paths
assert wav_path == str(converted_paths[0])
Path(silk_path).write_bytes(b"fake-silk")
return 1200

monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.get_astrbot_temp_path",
lambda: str(tmp_path),
)
monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.convert_audio_to_wav",
fake_convert_audio_to_wav,
)
monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.wav_to_tencent_silk",
fake_wav_to_tencent_silk,
)

result = await QQOfficialMessageEvent._parse_to_qqofficial(
MessageChain([Record.fromFileSystem(str(source_path))])
)

assert converted_paths
assert not converted_paths[0].exists()
assert result[3] is not None
assert str(result[3]).endswith(".silk")


@pytest.mark.asyncio
async def test_parse_to_qqofficial_skips_conversion_for_wav_record(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
source_path = tmp_path / "voice.wav"
source_path.write_bytes(_wav_header())
conversions: list[tuple[str, str | None]] = []

async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None):
conversions.append((audio_path, output_path))
return output_path or audio_path

async def fake_wav_to_tencent_silk(wav_path: str, silk_path: str):
assert wav_path == str(source_path)
Path(silk_path).write_bytes(b"fake-silk")
return 800

monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.get_astrbot_temp_path",
lambda: str(tmp_path),
)
monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.convert_audio_to_wav",
fake_convert_audio_to_wav,
)
monkeypatch.setattr(
"astrbot.core.platform.sources.qqofficial.qqofficial_message_event.wav_to_tencent_silk",
fake_wav_to_tencent_silk,
)

result = await QQOfficialMessageEvent._parse_to_qqofficial(
MessageChain([Record.fromFileSystem(str(source_path))])
)

assert conversions == []
assert result[3] is not None
assert str(result[3]).endswith(".silk")
Loading