Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions mellea/stdlib/frameworks/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
"""

import pydantic

# from PIL import Image as PILImage
from mellea.backends.model_options import ModelOption
from mellea.core.backend import Backend, BaseModelSubclass
Expand All @@ -24,6 +26,14 @@
from mellea.stdlib.context import ChatContext


class TrueOrFalse(pydantic.BaseModel):
"""Response indicating whether the ReACT agent has completed its task."""

answer: bool = pydantic.Field(
description="True if you have enough information to answer the user's question, False if you need more tool calls"
)


async def react(
goal: str,
context: ChatContext,
Expand Down Expand Up @@ -105,9 +115,43 @@ async def react(
if tool_res.name == MELLEA_FINALIZER_TOOL:
is_final = True

# Check for special case where model already has the answer, but it won't call the finalize tool.
# Instead of letting this run out of iterations and fail, let's ask.
# Only do this before we fail on iteration limit as a last resort because it's hard to justify doing it earlier for now.
elif -1 < loop_budget <= turn_num and step.value:
# If the turn number has reached the end of loop budget (and budget is not unlimited),
# then it's time to check if the model is just loopy and already has the answer.
print("### Done Check")
print("STEP_TOOL_CALLS:", step.tool_calls)
print("STEP:", step)
print("CONTEXT:", context)
content = mfuncs.chat(
content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
context=context,
backend=backend,
format=TrueOrFalse,
)[0].content
have_answer = TrueOrFalse.model_validate_json(content).answer

print("### Done Check ANSWER: ", have_answer)
if have_answer:
# Create a synthetic finalizer tool response to be consistent with normal loop
finalizer_response = ToolMessage(
role="tool",
content=step.value,
tool_output=step.value,
name=MELLEA_FINALIZER_TOOL,
args={},
tool=None, # type: ignore
)
tool_responses = [finalizer_response]
context = context.add(finalizer_response)
is_final = True

if is_final:
assert len(tool_responses) == 1, "multiple tools were called with 'final'"

# Apply format if requested
if format is not None:
step, next_context = await mfuncs.aact(
action=ReactThought(),
Expand Down
75 changes: 75 additions & 0 deletions test/stdlib/test_react_direct_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Test ReACT framework handling of direct answers without tool calls."""

import pytest

from mellea.backends.tools import tool
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session


@pytest.mark.ollama
@pytest.mark.llm
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.llm
@pytest.mark.e2e

We updated our markers (llm->e2e).

async def test_react_direct_answer_without_tools():
"""Test that ReACT handles direct answers when model doesn't call tools.

This tests the case where the model provides a direct answer in step.value
without making any tool calls. The fix ensures the loop terminates properly
instead of continuing until loop_budget is exhausted.
"""
m = start_session()

# Ask a simple question that doesn't require tools
# The model should provide a direct answer without calling any tools
out, _ = await react(
goal="What is 2 + 2?",
context=ChatContext(),
backend=m.backend,
tools=[], # No tools provided
loop_budget=3, # Should complete in 1 iteration, not exhaust budget
)

# Verify we got an answer
assert out.value is not None
assert len(out.value) > 0

# The answer should contain "4" or "four"
answer_lower = out.value.lower()
assert "4" in answer_lower or "four" in answer_lower


@pytest.mark.ollama
@pytest.mark.llm
async def test_react_direct_answer_with_unused_tools():
"""Test that ReACT handles direct answers even when tools are available.

This tests the case where tools are provided but the model chooses to
answer directly without using them.
"""
m = start_session()

# Create a dummy tool that won't be needed
@tool
def search_web(query: str) -> str:
"""Search the web for information."""
return "Search results"

# Ask a question that doesn't need the tool
out, _ = await react(
goal="What is the capital of France?",
context=ChatContext(),
backend=m.backend,
tools=[search_web],
loop_budget=3,
)

# Verify we got an answer
assert out.value is not None
assert len(out.value) > 0

# The answer should mention Paris
answer_lower = out.value.lower()
assert "paris" in answer_lower


# Made with Bob
Loading