Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions examples/ai_modinput_app/bin/agentic_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from _collections_abc import dict_items
from typing import final, override

from splunklib.ai.messages import AIMessage, ContentBlock, TextBlock

# ! NOTE: This insert is only needed for splunk-sdk-python CI/CD to work.
# ! Remove this if you're modifying this example locally.
sys.path.insert(0, "/splunklib-deps")
Expand Down Expand Up @@ -95,9 +97,9 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:
weather_events += list(reader)

for weather_event in weather_events:
weather_event["human_readable"] = asyncio.run(
self.invoke_agent(weather_event)
)
result = asyncio.run(self.invoke_agent(weather_event))
weather_event["human_readable"] = self.parse_content(result)

logger.debug(f"{weather_event=}")

event = Event(
Expand All @@ -112,7 +114,7 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:

logger.debug(f"Finishing enrichment for {input_name} at {csv_file_path}")

async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
async def invoke_agent(self, weather_event: dict[str, str | int]) -> AIMessage:
if not self.service:
raise AssertionError("No Splunk connection available")

Expand All @@ -127,7 +129,27 @@ async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
data=weather_event,
)
logger.debug(f"{response=}")
return response.final_message.content
return response.final_message

def _parse_content_block(self, block: str | ContentBlock) -> str | None:
    """Return the textual payload of one content block, or None if it has none.

    Plain strings pass through unchanged; TextBlock yields its text; any
    other block kind (e.g. opaque provider data) is reported as None so
    callers can skip it.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    return None

def parse_content(self, message: AIMessage) -> str:
    """Parse the content from an AIMessage and build a single string out of it.

    String content is returned as-is. List content is flattened by
    extracting the text of each supported block via
    ``_parse_content_block`` — blocks that carry no text are dropped —
    and joining the remaining pieces with single spaces.
    """
    if isinstance(message.content, str):
        return message.content

    return " ".join(
        parsed_block
        for block in message.content
        if (parsed_block := self._parse_content_block(block))
    )


if __name__ == "__main__":
Expand Down
135 changes: 121 additions & 14 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@
AgentResponse,
AIMessage,
BaseMessage,
ContentBlock,
HumanMessage,
OpaqueBlock,
OutputT,
StructuredOutputCall,
StructuredOutputMessage,
Expand All @@ -87,6 +89,7 @@
SubagentStructuredResult,
SubagentTextResult,
SystemMessage,
TextBlock,
ToolCall,
ToolFailureResult,
ToolMessage,
Expand Down Expand Up @@ -951,7 +954,7 @@ async def awrap_tool_call(
return LC_ToolMessage(
name=_normalize_agent_name(call.name),
tool_call_id=call.id,
content=content,
content=_map_content_to_langchain(content),
status=status,
artifact=sdk_result,
)
Expand Down Expand Up @@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result(
# This invariant is asserted via ModelResponse.__post_init__
assert len(resp.message.structured_output_calls) <= 1

lc_message = LC_AIMessage(content=resp.message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(resp.message.content),
additional_kwargs=resp.message.extras or {},
)
# This field can't be set via __init__()
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]

Expand Down Expand Up @@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc(
name=name,
tool_call_id=message.call_id,
status=status,
content=content,
content=_map_content_to_langchain(content),
artifact=artifact,
)

Expand Down Expand Up @@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
ai_message = model_response
structured_response = None

additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs)
return ModelResponse(
message=AIMessage(
content=ai_message.content.__str__(),
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in ai_message.tool_calls
Expand All @@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
for tc in ai_message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=additional_kwargs,
),
structured_output=structured_response,
)
Expand Down Expand Up @@ -1422,6 +1430,28 @@ def _is_agent_name_valid(name: str) -> bool:
return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)


def _parse_content_block(block: str | ContentBlock) -> str | None:
    """Extract the text from a single content block, if it carries any.

    Strings are returned unchanged, TextBlock yields its text, and every
    other block kind maps to None.
    """
    if isinstance(block, str):
        return block
    if isinstance(block, TextBlock):
        return block.text
    return None


def _parse_content(content: str | list[str | ContentBlock]) -> str:
    """Parse the content from an AIMessage and build a single string out of it.

    String content is returned unchanged. List content is reduced to the
    text of its supported blocks, joined by single spaces; blocks that
    yield no text (e.g. opaque provider blocks) are omitted.
    """
    if isinstance(content, str):
        return content

    return " ".join(
        parsed_block
        for block in content
        if (parsed_block := _parse_content_block(block))
    )


def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
if not agent.name:
raise AssertionError("Agent must have a name to be used by other Agents")
Expand All @@ -1433,7 +1463,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:

async def invoke_agent(
message: HumanMessage, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am uncertain here, shouldn't we extract all the text contents and only return these?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we just forward the return of inner invoke_agent, what's wrong here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that we should only forward the text response as an output of an subagent, right?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are others i don't think we should forward them??

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you want to filter them out or assert if any other block is OpaqueBlock?

SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke([message], thread_id=thread_id)

if agent.output_schema:
Expand All @@ -1443,7 +1476,7 @@ async def invoke_agent(
)

return result.final_message.content, SubagentTextResult(
content=result.final_message.content
content=_parse_content(result.final_message.content)
)

InputSchema = agent.input_schema
Expand All @@ -1452,13 +1485,19 @@ async def invoke_agent(

async def _run( # pyright: ignore[reportRedeclaration]
content: str, thread_id: str
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), thread_id)
else:

async def _run( # pyright: ignore[reportRedeclaration]
content: str,
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), None)

return StructuredTool.from_function(
Expand All @@ -1471,7 +1510,10 @@ async def _run( # pyright: ignore[reportRedeclaration]

async def invoke_agent_structured(
content: BaseModel, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke_with_data(
instructions="Follow the system prompt.",
data=content.model_dump(),
Expand All @@ -1485,14 +1527,17 @@ async def invoke_agent_structured(
)

return result.final_message.content, SubagentTextResult(
content=result.final_message.content
content=_parse_content(result.final_message.content)
)

if agent.conversation_store:

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content: BaseModel = kwargs["content"]
thread_id: str = kwargs["thread_id"]
return await invoke_agent_structured(content, thread_id)
Expand All @@ -1512,7 +1557,10 @@ async def _run(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content = InputSchema(**kwargs)
return await invoke_agent_structured(content, None)

Expand Down Expand Up @@ -1564,11 +1612,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
return LC_ToolCall(id=call.id, name=name, args=args)


def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Translate LangChain message content into SDK content blocks.

    A bare string passes through untouched; a list is mapped entry by
    entry through ``_map_content_block_from_langchain``.
    """
    if isinstance(content, str):
        return content

    return [_map_content_block_from_langchain(entry) for entry in content]


def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert one LangChain content entry into an SDK content block.

    Strings pass through unchanged; dicts tagged ``type == "text"``
    become TextBlock (carrying any provider extras); everything else is
    wrapped unmodified.
    """
    if isinstance(block, str):
        return block

    if block.get("type") == "text":
        return TextBlock(text=block["text"], extras=block.get("extras"))

    # Unrecognized block types are wrapped as opaque content blocks so
    # the raw provider payload is preserved and sent back to the LLM
    # unchanged on subsequent calls.
    return OpaqueBlock(data=block)


def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Translate SDK message content into the shape LangChain expects.

    A bare string passes through untouched; a list is mapped entry by
    entry through ``_map_content_block_to_langchain``.
    """
    if isinstance(content, str):
        return content

    return [_map_content_block_to_langchain(entry) for entry in content]


def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
    """Convert one SDK content block into LangChain's str/dict representation.

    TextBlock becomes a ``{"type": "text", ...}`` dict (extras included
    only when present); OpaqueBlock hands back its raw provider dict;
    strings pass through unchanged.
    """
    if isinstance(block, str):
        return block

    if isinstance(block, TextBlock):
        payload: dict[str, Any] = {"type": "text", "text": block.text}
        if block.extras:
            payload["extras"] = block.extras
        return payload

    if isinstance(block, OpaqueBlock):
        return block.data


def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
match message:
case LC_AIMessage():
return AIMessage(
content=message.content.__str__(),
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in message.tool_calls
Expand All @@ -1583,6 +1686,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=cast(dict[str, Any], message.additional_kwargs),
)
case LC_HumanMessage():
return HumanMessage(content=message.content.__str__())
Expand All @@ -1597,7 +1701,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
match message:
case AIMessage():
lc_message = LC_AIMessage(content=message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(message.content),
additional_kwargs=message.extras or {},
)
# This field can't be set via constructor
lc_message.tool_calls = [
_map_tool_call_to_langchain(c) for c in message.calls
Expand Down
2 changes: 0 additions & 2 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,3 @@ async def model_middleware(
if self._deadline is not None and monotonic() >= self._deadline:
raise TimeoutExceededException(timeout_seconds=self._seconds)
return await handler(request)


30 changes: 29 additions & 1 deletion splunklib/ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,31 @@
from splunklib.ai.tools import ToolType


@dataclass(frozen=True)
class TextBlock:
"""Plain text content block returned by a model."""

text: str
# TODO: should we have the id here as well?
# Provider-specific extras (e.g. Gemini thought signature on text blocks).
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
class OpaqueBlock:
    """Content block of an unrecognized or unsupported type.

    The raw provider dict is preserved in `data` so it can be sent back
    to the model unchanged on subsequent calls.
    """

    # Raw provider payload, kept verbatim for round-tripping to the model.
    data: dict[str, Any]


# Type alias for all content block variants the SDK recognizes.
ContentBlock = TextBlock | OpaqueBlock


@dataclass(frozen=True)
class ToolCall:
name: str
Expand Down Expand Up @@ -85,12 +110,15 @@ class AIMessage(BaseMessage):
"""

role: Literal["assistant"] = field(default="assistant", init=False)
content: str
content: str | list[str | ContentBlock]
Copy link
Copy Markdown
Member

@mateusz834 mateusz834 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we made this a:

    content: str | list[str | dict]

or

    content: str | list[str | OpaqueBlock]

just as LC does? and provide a content_block property that converts this to blocks (same just as LC)?

class ContentBlock:
    extras: dict[str, Any] | None
    pass

class TextBlock(ContentBlock):
     text: str

or instead of content_block we could provide texts property that returns a list[str], extracted from the content? And then later if needed we can figure out the blocks thing.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagined the solution to be more explicit, so that the SDK controls the type of content blocks that are "supported", leaving the OpaqueBlock strictly for preserving the data between the SDK and langchain layer, so that the LLMs receive all the context in the next call.

This makes them explicitly ignore the OpaqueBlock and if any new model returns some new content block, this would fall into the OpaqueBlock, meaning the already existing code shouldn't break once this happens.

Of course someone could use the OpaqueBlock if they wanted to and create some logic around it's contents, but they should expect this code to be changed in case SDK adds first class support to this specific block. And I don't see the issue with breaking such code in the future, since it wasn't officially supported.

I don't think we can future proof everything, at least without creating some other potential issues.

By making the OpaqueBlock or in this case more of a base ContentBlock we could add a helper method for getting the text list from them (if possible).
I'm worried this could introduce issues in the developers code as the content-blocks and their contents might change in the future, but I'm open to discussion.
However, In this case we could also just have the TextBlock accept any content-block that contains the text field, or any other content-block that you can get the text from. And leave the OpaqueBlock as is.


calls: Sequence[ToolCall | SubagentCall]
structured_output_calls: Sequence[StructuredOutputCall] = field(
default_factory=tuple
)
# Backend-specific metadata (e.g. provider additional_kwargs) not
# representable in the standard fields. Opaque to callers.
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
Expand Down
Loading