diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index ce47f5fc8e..f30fc04789 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -1973,11 +1973,97 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t contents.extend(coalesced_contents) +def _content_items_text(items: Any) -> str | None: + """Return concatenated text when a content item list only contains text.""" + if not isinstance(items, list): + return None + text_parts: list[str] = [] + content_items = cast(list[object], items) + for item in content_items: + if not isinstance(item, Content) or item.type != "text": + return None + text_parts.append(item.text or "") + return "".join(text_parts) + + +def _merge_content_item_lists(existing: Any, incoming: Any) -> Any: + """Merge streamed nested content lists, replacing deltas with a later full value when present.""" + if incoming is None: + return existing + if existing is None: + return deepcopy(incoming) + + existing_text = _content_items_text(existing) + incoming_text = _content_items_text(incoming) + if existing_text is not None and incoming_text is not None: + if incoming_text.startswith(existing_text): + return deepcopy(incoming) + if existing_text.startswith(incoming_text): + return existing + + existing_items = cast(list[Content], existing) + merged = deepcopy(existing_items[0]) + merged.text = existing_text + incoming_text + return [merged] + + if isinstance(existing, list) and isinstance(incoming, list): + existing_list = cast(list[object], existing) + incoming_list = cast(list[object], incoming) + return [*existing_list, *deepcopy(incoming_list)] + return deepcopy(incoming) + + +def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None: + """Merge two code interpreter content items for the same logical call.""" + existing.inputs = _merge_content_item_lists(existing.inputs, incoming.inputs) + existing.outputs = _merge_content_item_lists(existing.outputs, incoming.outputs) + existing.annotations = _combine_annotations(existing.annotations, incoming.annotations) + existing.additional_properties = {**existing.additional_properties, **incoming.additional_properties} + existing.raw_representation = _combine_raw_representations(existing.raw_representation, incoming.raw_representation) + + +def _code_interpreter_key(content: Content) -> tuple[str, str] | None: + """Return the aggregation key for code interpreter call/result content.""" + if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}: + return None + call_id = content.call_id or content.additional_properties.get("item_id") + if not isinstance(call_id, str) or not call_id: + return None + return content.type, call_id + + +def _coalesce_code_interpreter_content(contents: list[Content]) -> None: + """Coalesce streaming code interpreter chunks by call id.""" + if not contents: + return + + coalesced_contents: list[Content] = [] + seen: dict[tuple[str, str], Content] = {} + for content in contents: + key = _code_interpreter_key(content) + if key is None: + coalesced_contents.append(content) + continue + + existing = seen.get(key) + if existing is None: + copied = deepcopy(content) + seen[key] = copied + coalesced_contents.append(copied) + continue + + _merge_code_interpreter_content(existing, content) + + contents.clear() + contents.extend(coalesced_contents) + + def _finalize_response(response: ChatResponse | AgentResponse) -> None: """Finalizes the response by performing any necessary post-processing.""" for msg in response.messages: _coalesce_text_content(msg.contents, "text") _coalesce_text_content(msg.contents, "text_reasoning") + _coalesce_code_interpreter_content(msg.contents) # region ContinuationToken diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index ebb91d0b0d..7c78dba209 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -307,6 +307,63 @@ async def test_after_run_stores_inputs_and_responses(self) -> None: assert provider.stored[0].text == "hello" assert provider.stored[1].text == "hi" + async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None: + from agent_framework import AgentResponse, AgentResponseUpdate, Content + + provider = ConcreteHistoryProvider("mem", store_inputs=False) + updates = [ + AgentResponseUpdate( + role="assistant", + contents=[ + Content.from_code_interpreter_tool_result( + call_id="ci_123", + outputs=[], + ) + ], + ), + AgentResponseUpdate( + contents=[ + Content.from_code_interpreter_tool_call( + call_id="ci_123", + inputs=[Content.from_text(text="import")], + additional_properties={"sequence_number": 1}, + ) + ], + ), + AgentResponseUpdate( + contents=[ + Content.from_code_interpreter_tool_call( + call_id="ci_123", + inputs=[Content.from_text(text=" pandas")], + additional_properties={"sequence_number": 2}, + ) + ], + ), + AgentResponseUpdate( + contents=[ + Content.from_code_interpreter_tool_call( + call_id="ci_123", + inputs=[Content.from_text(text="import pandas as pd")], + additional_properties={"sequence_number": 3}, + ) + ], + ), + ] + ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])]) + ctx._response = AgentResponse.from_updates(updates) + + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + + assert len(provider.stored) == 1 + stored_contents = provider.stored[0].contents + calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"] + results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"] + assert len(calls) == 1 + assert len(results) == 1 + assert calls[0].inputs is not None + assert len(calls[0].inputs) == 1 + assert calls[0].inputs[0].text == "import pandas as pd" + async def test_after_run_skips_inputs_when_disabled(self) -> None: from agent_framework import AgentResponse