Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions astrbot/core/astr_agent_run_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
)
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.provider import TTSProvider
from astrbot.core.repeat_reply_guard import (
DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
normalize_repeat_reply_guard_threshold,
)

AgentRunner = ToolLoopAgentRunner[AstrAgentContext]

Expand Down Expand Up @@ -87,17 +91,30 @@ def _build_tool_result_status_message(
return status_msg


def _build_chain_signature(msg_chain: MessageChain) -> str:
    """Return a whitespace-normalized fingerprint of a message chain.

    The chain's plain text is fetched with ``with_other_comps_mark=True``,
    surrounding whitespace is dropped, and every internal run of whitespace
    is collapsed to a single space.

    Args:
        msg_chain: The chain whose textual content should be fingerprinted.

    Returns:
        The normalized text, or ``""`` when the chain yields no visible text.
    """
    plain_text = msg_chain.get_plain_text(with_other_comps_mark=True)
    # split() with no separator discards leading/trailing whitespace and
    # treats each whitespace run as one delimiter, so re-joining with a single
    # space produces the normalized form (and "" for blank input).
    return " ".join(plain_text.split())


async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
guard_threshold = normalize_repeat_reply_guard_threshold(
repeat_reply_guard_threshold
)
guard_last_signature = ""
guard_repeat_count = 0
while step_idx < max_step + 1:
step_idx += 1

Expand Down Expand Up @@ -193,6 +210,38 @@ async def run_agent(
await astr_event.send(chain)
continue

if resp.type == "llm_result" and guard_threshold > 0:
chain_signature = _build_chain_signature(resp.data["chain"])
if chain_signature:
if chain_signature == guard_last_signature:
guard_repeat_count += 1
else:
guard_last_signature = chain_signature
guard_repeat_count = 1

if guard_repeat_count >= guard_threshold:
logger.warning(
"Agent repeated identical llm_result %d times; forcing convergence. threshold=%d",
guard_repeat_count,
guard_threshold,
)
if not agent_runner.done():
if agent_runner.req:
agent_runner.req.func_tool = None
agent_runner.run_context.messages.append(
Message(
role="user",
content=(
"检测到你连续多次输出相同回复。"
"请停止重复,基于已有信息给出最终答复,"
"不要再次调用工具。"
),
)
)
# Jump to the same convergence path as max-step limit.
step_idx = max_step
continue

if stream_to_general and resp.type == "streaming_delta":
continue

Expand Down Expand Up @@ -288,6 +337,7 @@ async def run_live_agent(
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
repeat_reply_guard_threshold: int = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS

Expand All @@ -311,6 +361,7 @@ async def run_live_agent(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
repeat_reply_guard_threshold=repeat_reply_guard_threshold,
):
yield chain
return
Expand Down Expand Up @@ -343,6 +394,7 @@ async def run_live_agent(
show_tool_use,
show_tool_call_result,
show_reasoning,
repeat_reply_guard_threshold,
)
)

Expand Down Expand Up @@ -430,6 +482,7 @@ async def _run_agent_feeder(
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
repeat_reply_guard_threshold: int,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
buffer = ""
Expand All @@ -441,6 +494,7 @@ async def _run_agent_feeder(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
repeat_reply_guard_threshold=repeat_reply_guard_threshold,
):
if chain is None:
continue
Expand Down
13 changes: 13 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from typing import Any, TypedDict

from astrbot.core.repeat_reply_guard import DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD
from astrbot.core.utils.astrbot_path import get_astrbot_data_path

VERSION = "4.22.0"
Expand Down Expand Up @@ -149,6 +150,7 @@
"unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30,
"repeat_reply_guard_threshold": DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
"tool_call_timeout": 120,
"tool_schema_mode": "full",
"llm_safety_mode": True,
Expand Down Expand Up @@ -2685,6 +2687,9 @@ class ChatProviderTemplate(TypedDict):
"max_agent_step": {
"type": "int",
},
"repeat_reply_guard_threshold": {
"type": "int",
},
"tool_call_timeout": {
"type": "int",
},
Expand Down Expand Up @@ -3430,6 +3435,14 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.repeat_reply_guard_threshold": {
"description": "连续相同回复拦截阈值",
"type": "int",
"hint": "同一轮 Agent 运行中连续出现相同回复达到该次数时,将触发防循环收敛。设置为 0 可关闭。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
LLMResponse,
ProviderRequest,
)
from astrbot.core.repeat_reply_guard import (
DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
normalize_config_repeat_reply_guard_threshold,
)
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
Expand Down Expand Up @@ -64,6 +68,15 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.tool_schema_mode = "full"
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.repeat_reply_guard_threshold: int = settings.get(
"repeat_reply_guard_threshold",
DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD,
)
self.repeat_reply_guard_threshold = (
normalize_config_repeat_reply_guard_threshold(
self.repeat_reply_guard_threshold
)
)
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.show_reasoning = settings.get("display_reasoning_text", False)
Expand Down Expand Up @@ -274,6 +287,7 @@ async def process(
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
),
),
)
Expand Down Expand Up @@ -304,6 +318,7 @@ async def process(
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
),
),
)
Expand Down Expand Up @@ -334,6 +349,7 @@ async def process(
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
):
yield

Expand Down
18 changes: 18 additions & 0 deletions astrbot/core/repeat_reply_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Default repeat-reply guard threshold: the number of consecutive identical
# replies at which the guard triggers. It is also the fallback applied when a
# configured threshold value is invalid.
DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD = 3


def normalize_repeat_reply_guard_threshold(
    value: object, *, invalid_fallback: int = 0
) -> int:
    """Coerce an arbitrary value into a non-negative guard threshold.

    Args:
        value: Raw threshold, e.g. an int, a float, or a numeric string.
        invalid_fallback: Value returned when ``value`` cannot be interpreted
            as an integer. It is returned as-is, without clamping.

    Returns:
        ``int(value)`` clamped to a minimum of 0, or ``invalid_fallback`` when
        ``value`` is a bool or cannot be converted to an integer.
    """
    # bool is a subclass of int, so True/False would otherwise silently become
    # a threshold of 1/0; treat them as invalid input instead.
    if isinstance(value, bool):
        return invalid_fallback
    try:
        parsed = int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported types such as None or a list.
        # ValueError: non-numeric strings and float NaN.
        # OverflowError: float infinities, which int() cannot represent.
        return invalid_fallback
    return max(0, parsed)


def normalize_config_repeat_reply_guard_threshold(value) -> int:
    """Normalize a guard threshold that was read from configuration.

    Delegates to ``normalize_repeat_reply_guard_threshold`` and uses
    ``DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD`` as the result for values that
    cannot be interpreted.

    Args:
        value: Raw configuration value.

    Returns:
        The normalized, non-negative threshold.
    """
    fallback = DEFAULT_REPEAT_REPLY_GUARD_THRESHOLD
    return normalize_repeat_reply_guard_threshold(value, invalid_fallback=fallback)
148 changes: 148 additions & 0 deletions tests/unit/test_astr_agent_run_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from types import SimpleNamespace

import pytest

from astrbot.core.astr_agent_run_util import run_agent
from astrbot.core.message.message_event_result import MessageChain


def _llm_result_response(text: str):
    """Build a stand-in ``llm_result`` response whose chain carries *text*."""
    chain = MessageChain().message(text)
    return SimpleNamespace(type="llm_result", data={"chain": chain})


class _DummyTrace:
    """Trace stub that accepts any ``record`` call and discards it."""

    def record(self, *args, **kwargs) -> None:
        """Ignore the arguments; nothing is stored and ``None`` is returned."""


class _DummyEvent:
    """In-memory event double used by the ``run_agent`` tests below.

    It keeps extras in a plain dict and records the plain text of every result
    passed to ``set_result`` in ``result_texts`` so tests can assert on it.
    """

    # Platform name and platform id both report this fixed value.
    _PLATFORM = "slack"

    def __init__(self) -> None:
        self.trace = _DummyTrace()
        self.result_texts: list[str] = []
        self._stopped = False
        self._extras: dict = {}

    def is_stopped(self) -> bool:
        """Return the stop flag (``False`` unless changed externally)."""
        return self._stopped

    def get_extra(self, key: str, default=None):
        """Return the extra stored under *key*, or *default* when absent."""
        return self._extras.get(key, default)

    def set_extra(self, key: str, value) -> None:
        """Store *value* under *key*."""
        self._extras[key] = value

    def set_result(self, result) -> None:
        """Record the plain text of *result* (component marks included)."""
        text = result.get_plain_text(with_other_comps_mark=True)
        self.result_texts.append(text)

    def clear_result(self) -> None:
        """Do nothing; previously recorded texts are kept."""

    def get_platform_name(self) -> str:
        """Return the fixed platform name."""
        return self._PLATFORM

    def get_platform_id(self) -> str:
        """Return the fixed platform id."""
        return self._PLATFORM

    async def send(self, _msg_chain) -> None:
        """Accept and drop an outgoing message chain."""


class _FakeRunner:
def __init__(self, steps: list[list[SimpleNamespace]]) -> None:
self._steps = steps
self._step_idx = 0
self._done = False
self.streaming = False
self.req = SimpleNamespace(func_tool=object())
self.run_context = SimpleNamespace(
context=SimpleNamespace(event=_DummyEvent()),
messages=[],
)
self.stats = SimpleNamespace(to_dict=lambda: {})

def done(self) -> bool:
return self._done

def request_stop(self) -> None:
self.run_context.context.event.set_extra("agent_stop_requested", True)

def was_aborted(self) -> bool:
return False

async def step(self):
if self._step_idx >= len(self._steps):
self._done = True
return

current = self._steps[self._step_idx]
self._step_idx += 1
for resp in current:
yield resp

if self._step_idx >= len(self._steps):
self._done = True


@pytest.mark.asyncio
async def test_repeat_reply_guard_forces_convergence():
    """Three identical replies at threshold 3 trip the guard.

    The third repeat must not reach the event, tools must be cleared on the
    request, and a user message asking the model to stop repeating must be
    appended to the run context.
    """
    repeated = "重复输出"
    script = [[_llm_result_response(repeated)] for _ in range(3)]
    script.append([_llm_result_response("最终答案")])
    runner = _FakeRunner(script)

    agent_run = run_agent(
        runner,
        max_step=8,
        show_tool_use=False,
        show_tool_call_result=False,
        repeat_reply_guard_threshold=3,
    )
    async for _ in agent_run:
        pass

    event = runner.run_context.context.event
    # Only two of the three repeats are delivered before the final answer.
    assert event.result_texts == [repeated, repeated, "最终答案"]
    assert runner.req.func_tool is None

    guard_prompts = [
        msg
        for msg in runner.run_context.messages
        if msg.role == "user" and "检测到你连续多次输出相同回复" in str(msg.content)
    ]
    assert guard_prompts


@pytest.mark.asyncio
async def test_repeat_reply_guard_can_be_disabled_with_zero_threshold():
    """A threshold of 0 turns the guard off.

    Every repeated reply is delivered to the event and the request's tool set
    is left untouched.
    """
    repeated = "重复输出"
    script = [[_llm_result_response(repeated)] for _ in range(3)]
    script.append([_llm_result_response("最终答案")])
    runner = _FakeRunner(script)
    tools_before = runner.req.func_tool

    agent_run = run_agent(
        runner,
        max_step=8,
        show_tool_use=False,
        show_tool_call_result=False,
        repeat_reply_guard_threshold=0,
    )
    async for _ in agent_run:
        pass

    event = runner.run_context.context.event
    assert event.result_texts == [repeated, repeated, repeated, "最终答案"]
    assert runner.req.func_tool is tools_before
Loading