-
Notifications
You must be signed in to change notification settings - Fork 383
Improve handling AIMessage.content
#741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,7 +77,9 @@ | |
| AgentResponse, | ||
| AIMessage, | ||
| BaseMessage, | ||
| ContentBlock, | ||
| HumanMessage, | ||
| OpaqueBlock, | ||
| OutputT, | ||
| StructuredOutputCall, | ||
| StructuredOutputMessage, | ||
|
|
@@ -87,6 +89,7 @@ | |
| SubagentStructuredResult, | ||
| SubagentTextResult, | ||
| SystemMessage, | ||
| TextBlock, | ||
| ToolCall, | ||
| ToolFailureResult, | ||
| ToolMessage, | ||
|
|
@@ -951,7 +954,7 @@ async def awrap_tool_call( | |
| return LC_ToolMessage( | ||
| name=_normalize_agent_name(call.name), | ||
| tool_call_id=call.id, | ||
| content=content, | ||
| content=_map_content_to_langchain(content), | ||
| status=status, | ||
| artifact=sdk_result, | ||
| ) | ||
|
|
@@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result( | |
| # This invariant is asserted via ModelResponse.__post_init__ | ||
| assert len(resp.message.structured_output_calls) <= 1 | ||
|
|
||
| lc_message = LC_AIMessage(content=resp.message.content) | ||
| lc_message = LC_AIMessage( | ||
| content=_map_content_to_langchain(resp.message.content), | ||
| additional_kwargs=resp.message.extras or {}, | ||
| ) | ||
| # This field can't be set via __init__() | ||
| lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls] | ||
|
|
||
|
|
@@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc( | |
| name=name, | ||
| tool_call_id=message.call_id, | ||
| status=status, | ||
| content=content, | ||
| content=_map_content_to_langchain(content), | ||
| artifact=artifact, | ||
| ) | ||
|
|
||
|
|
@@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe | |
| ai_message = model_response | ||
| structured_response = None | ||
|
|
||
| additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs) | ||
| return ModelResponse( | ||
| message=AIMessage( | ||
| content=ai_message.content.__str__(), | ||
| content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType] | ||
| calls=[ | ||
| _map_tool_call_from_langchain(tc) | ||
| for tc in ai_message.tool_calls | ||
|
|
@@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe | |
| for tc in ai_message.tool_calls | ||
| if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) | ||
| ], | ||
| extras=additional_kwargs, | ||
| ), | ||
| structured_output=structured_response, | ||
| ) | ||
|
|
@@ -1422,6 +1430,28 @@ def _is_agent_name_valid(name: str) -> bool: | |
| return set(name).issubset(AGENT_NAME_ALLOWED_CHARS) | ||
|
|
||
|
|
||
| def _parse_content_block(block: str | ContentBlock) -> str | None: | ||
| match block: | ||
| case TextBlock(): | ||
| return block.text | ||
| case str(): | ||
| return block | ||
| case _: | ||
| return None | ||
|
|
||
|
|
||
| def _parse_content(content: str | list[str | ContentBlock]) -> str: | ||
| """Parses the content from AIMessage and builds a single string out of it""" | ||
| if isinstance(content, str): | ||
| return content | ||
|
|
||
| return " ".join( | ||
| parsed_block | ||
| for block in content | ||
| if (parsed_block := _parse_content_block(block)) | ||
| ) | ||
|
|
||
|
|
||
| def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool: | ||
| if not agent.name: | ||
| raise AssertionError("Agent must have a name to be used by other Agents") | ||
|
|
@@ -1433,7 +1463,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool: | |
|
|
||
| async def invoke_agent( | ||
| message: HumanMessage, thread_id: str | None | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| result = await agent.invoke([message], thread_id=thread_id) | ||
|
|
||
| if agent.output_schema: | ||
|
|
@@ -1443,7 +1476,7 @@ async def invoke_agent( | |
| ) | ||
|
|
||
| return result.final_message.content, SubagentTextResult( | ||
| content=result.final_message.content | ||
| content=_parse_content(result.final_message.content) | ||
| ) | ||
|
|
||
| InputSchema = agent.input_schema | ||
|
|
@@ -1452,13 +1485,19 @@ async def invoke_agent( | |
|
|
||
| async def _run( # pyright: ignore[reportRedeclaration] | ||
| content: str, thread_id: str | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| return await invoke_agent(HumanMessage(content=content), thread_id) | ||
| else: | ||
|
|
||
| async def _run( # pyright: ignore[reportRedeclaration] | ||
| content: str, | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| return await invoke_agent(HumanMessage(content=content), None) | ||
|
|
||
| return StructuredTool.from_function( | ||
|
|
@@ -1471,7 +1510,10 @@ async def _run( # pyright: ignore[reportRedeclaration] | |
|
|
||
| async def invoke_agent_structured( | ||
| content: BaseModel, thread_id: str | None | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| result = await agent.invoke_with_data( | ||
| instructions="Follow the system prompt.", | ||
| data=content.model_dump(), | ||
|
|
@@ -1485,14 +1527,17 @@ async def invoke_agent_structured( | |
| ) | ||
|
|
||
| return result.final_message.content, SubagentTextResult( | ||
| content=result.final_message.content | ||
| content=_parse_content(result.final_message.content) | ||
| ) | ||
|
|
||
| if agent.conversation_store: | ||
|
|
||
| async def _run( | ||
| **kwargs: Any, # noqa: ANN401 | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| content: BaseModel = kwargs["content"] | ||
| thread_id: str = kwargs["thread_id"] | ||
| return await invoke_agent_structured(content, thread_id) | ||
|
|
@@ -1512,7 +1557,10 @@ async def _run( | |
|
|
||
| async def _run( | ||
| **kwargs: Any, # noqa: ANN401 | ||
| ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: | ||
| ) -> tuple[ | ||
| OutputT | str | list[str | ContentBlock], | ||
| SubagentStructuredResult | SubagentTextResult, | ||
| ]: | ||
| content = InputSchema(**kwargs) | ||
| return await invoke_agent_structured(content, None) | ||
|
|
||
|
|
@@ -1564,11 +1612,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall: | |
| return LC_ToolCall(id=call.id, name=name, args=args) | ||
|
|
||
|
|
||
| def _map_content_from_langchain( | ||
| content: str | list[str | dict[str, Any]], | ||
| ) -> str | list[str | ContentBlock]: | ||
| if isinstance(content, str): | ||
| return content | ||
|
|
||
| result_content = [_map_content_block_from_langchain(b) for b in content] | ||
|
|
||
| return result_content | ||
|
|
||
|
|
||
| def _map_content_block_from_langchain( | ||
| block: str | dict[str, Any], | ||
| ) -> str | ContentBlock: | ||
| if isinstance(block, str): | ||
| return block | ||
|
|
||
| match block.get("type"): | ||
| case "text": | ||
| return TextBlock( | ||
| text=block["text"], | ||
| extras=block.get("extras"), | ||
| ) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if providers add different fields here — will we ignore them?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds like this is not provider specific but LC specific, so all of the "extra" properties should be stored here. The |
||
| case _: | ||
| # NOTE: we return data we're not handling | ||
| # as opaque content blocks so they | ||
| # are preserved and sent back to the LLM | ||
| return OpaqueBlock(data=block) | ||
|
|
||
|
|
||
| def _map_content_to_langchain( | ||
| content: str | list[str | ContentBlock], | ||
| ) -> str | list[str | dict[str, Any]]: | ||
| if isinstance(content, str): | ||
| return content | ||
|
|
||
| result_content = [_map_content_block_to_langchain(b) for b in content] | ||
|
|
||
| return result_content | ||
|
|
||
|
|
||
| def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]: | ||
| if isinstance(block, str): | ||
| return block | ||
|
|
||
| match block: | ||
| case TextBlock(): | ||
| result: dict[str, Any] = {"type": "text", "text": block.text} | ||
| if block.extras: | ||
| result["extras"] = block.extras | ||
| return result | ||
| case OpaqueBlock(): | ||
| return block.data | ||
|
|
||
|
|
||
| def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: | ||
| match message: | ||
| case LC_AIMessage(): | ||
| return AIMessage( | ||
| content=message.content.__str__(), | ||
| content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType] | ||
| calls=[ | ||
| _map_tool_call_from_langchain(tc) | ||
| for tc in message.tool_calls | ||
|
|
@@ -1583,6 +1686,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: | |
| for tc in message.tool_calls | ||
| if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) | ||
| ], | ||
| extras=cast(dict[str, Any], message.additional_kwargs), | ||
| ) | ||
| case LC_HumanMessage(): | ||
| return HumanMessage(content=message.content.__str__()) | ||
|
|
@@ -1597,7 +1701,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: | |
| def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage: | ||
| match message: | ||
| case AIMessage(): | ||
| lc_message = LC_AIMessage(content=message.content) | ||
| lc_message = LC_AIMessage( | ||
| content=_map_content_to_langchain(message.content), | ||
| additional_kwargs=message.extras or {}, | ||
| ) | ||
| # This field can't be set via constructor | ||
| lc_message.tool_calls = [ | ||
| _map_tool_call_to_langchain(c) for c in message.calls | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,31 @@ | |
| from splunklib.ai.tools import ToolType | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TextBlock: | ||
| """Plain text content block returned by a model.""" | ||
|
|
||
| text: str | ||
| # TODO: should we have the id here as well? | ||
| # Provider-specific extras (e.g. Gemini thought signature on text blocks). | ||
| extras: dict[str, Any] | None = field(default=None) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class OpaqueBlock: | ||
| """Content block of an unrecognized or unsupported type. | ||
|
|
||
| The raw provider dict is preserved in `data` so it can be sent back | ||
| to the model unchanged on subsequent calls. | ||
| """ | ||
|
|
||
| data: dict[str, Any] | ||
|
|
||
|
|
||
| # Type alias for all content block variants. | ||
| ContentBlock = TextBlock | OpaqueBlock | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class ToolCall: | ||
| name: str | ||
|
|
@@ -85,12 +110,15 @@ class AIMessage(BaseMessage): | |
| """ | ||
|
|
||
| role: Literal["assistant"] = field(default="assistant", init=False) | ||
| content: str | ||
| content: str | list[str | ContentBlock] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if we made this a: content: str | list[str | dict] or content: str | list[str | OpaqueBlock] just as LC does? and provide a or instead of
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I imagined the solution to be more explicit, so that the SDK controls the type of content blocks that are "supported", leaving the OpaqueBlock strictly for preserving the data between the SDK and langchain layer, so that the LLMs receive all the context in the next call. This makes them explicitly ignore the OpaqueBlock, and if any new model returns some new content block, this would fall into the OpaqueBlock, meaning the already existing code shouldn't break once this happens. Of course someone could use the OpaqueBlock if they wanted to and create some logic around its contents, but they should expect this code to be changed in case the SDK adds first-class support for this specific block. And I don't see the issue with breaking such code in the future, since it wasn't officially supported. I don't think we can future-proof everything, at least without creating some other potential issues. By making the |
||
|
|
||
| calls: Sequence[ToolCall | SubagentCall] | ||
| structured_output_calls: Sequence[StructuredOutputCall] = field( | ||
| default_factory=tuple | ||
| ) | ||
| # Backend-specific metadata (e.g. provider additional_kwargs) not | ||
| # representable in the standard fields. Opaque to callers. | ||
| extras: dict[str, Any] | None = field(default=None) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am uncertain here, shouldn't we extract all the text contents and only return these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we just forward the return of inner
invoke_agent, what's wrong here? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My point is that we should only forward the text response as an output of an subagent, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there are others i don't think we should forward them??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would you want to filter them out or assert if any other block is OpaqueBlock?