diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 6311681cd6..5ab2765a1f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -36,7 +36,13 @@ from astrbot.core import astrbot_config, file_token_service, logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path -from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 +from astrbot.core.utils.io import ( # noqa: I001 + download_audio_by_url, + download_file, + download_image_by_url, + file_to_base64, + save_temp_audio, +) class ComponentType(str, Enum): @@ -157,17 +163,12 @@ async def convert_to_file_path(self) -> str: if self.file.startswith("file:///"): return self.file[8:] if self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + file_path = await download_audio_by_url(self.file) return os.path.abspath(file_path) if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") - image_bytes = base64.b64decode(bs64_data) - file_path = os.path.join( - get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" - ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) + audio_bytes = base64.b64decode(bs64_data) + return os.path.abspath(save_temp_audio(audio_bytes)) if os.path.exists(self.file): return os.path.abspath(self.file) raise Exception(f"not a valid file: {self.file}") @@ -185,7 +186,7 @@ async def convert_to_base64(self) -> str: if self.file.startswith("file:///"): bs64_data = file_to_base64(self.file[8:]) elif self.file.startswith("http"): - file_path = await download_image_by_url(self.file) + file_path = await download_audio_by_url(self.file) bs64_data = file_to_base64(file_path) elif self.file.startswith("base64://"): bs64_data = self.file diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..ca95aad9c6 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -19,6 +19,46 @@ logger = logging.getLogger("astrbot") +def save_temp_audio(audio_data: bytes) -> str: + """Save audio data to a temporary file with a proper extension.""" + temp_dir = get_astrbot_temp_path() + timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" + p = os.path.join(temp_dir, f"recordseg_{timestamp}.audio") + with open(p, "wb") as f: + f.write(audio_data) + return p + + +async def download_audio_by_url(url: str) -> str: + """Download audio from URL. Returns local file path.""" + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session: + async with session.get(url) as resp: + resp.raise_for_status() + data = await resp.read() + return save_temp_audio(data) + except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): + logger.warning( + f"SSL certificate verification failed for {url}. " + "Disabling SSL verification (CERT_NONE) as a fallback. " + "This is insecure and exposes the application to man-in-the-middle attacks. " + "Please investigate and resolve certificate issues." + ) + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + async with aiohttp.ClientSession() as session: + async with session.get(url, ssl=ssl_context) as resp: + resp.raise_for_status() + data = await resp.read() + return save_temp_audio(data) + + def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat