Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@

from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
from astrbot.core.utils.io import ( # noqa: I001
download_audio_by_url,
download_file,
download_image_by_url,
file_to_base64,
save_temp_audio,
)


class ComponentType(str, Enum):
Expand Down Expand Up @@ -157,17 +163,12 @@ async def convert_to_file_path(self) -> str:
if self.file.startswith("file:///"):
return self.file[8:]
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
file_path = await download_audio_by_url(self.file)
return os.path.abspath(file_path)
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = os.path.join(
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
)
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
audio_bytes = base64.b64decode(bs64_data)
return os.path.abspath(save_temp_audio(audio_bytes))
if os.path.exists(self.file):
return os.path.abspath(self.file)
raise Exception(f"not a valid file: {self.file}")
Expand All @@ -185,7 +186,7 @@ async def convert_to_base64(self) -> str:
if self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
file_path = await download_audio_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file
Expand Down
40 changes: 40 additions & 0 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,46 @@
logger = logging.getLogger("astrbot")


def save_temp_audio(audio_data: bytes) -> str:
"""Save audio data to a temporary file with a proper extension."""
temp_dir = get_astrbot_temp_path()
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"recordseg_{timestamp}.audio")
with open(p, "wb") as f:
f.write(audio_data)
return p


async def download_audio_by_url(url: str) -> str:
"""Download audio from URL. Returns local file path."""
try:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True,
connector=connector,
) as session:
async with session.get(url) as resp:
resp.raise_for_status()
data = await resp.read()
return save_temp_audio(data)
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
logger.warning(
f"SSL certificate verification failed for {url}. "
"Disabling SSL verification (CERT_NONE) as a fallback. "
"This is insecure and exposes the application to man-in-the-middle attacks. "
"Please investigate and resolve certificate issues."
)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context) as resp:
resp.raise_for_status()
data = await resp.read()
return save_temp_audio(data)


def on_error(func, path, exc_info) -> None:
"""A callback of the rmtree function."""
import stat
Expand Down
Loading