diff --git a/docs/agents.md b/docs/agents.md index 1e04f7e9..23b18b64 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f !!! note - To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios: - - 1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none") - 2. When `tool_choice` is set to "required" AND there is only one tool available - - This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases. - + To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum. + If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing. diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 80b9cc6d..61ca4a0f 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -4,7 +4,7 @@ import dataclasses import inspect from collections.abc import Awaitable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, cast from openai.types.responses import ( @@ -77,6 +77,23 @@ class QueueCompleteSentinel: _NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None) +@dataclass +class AgentToolUseTracker: + agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list) + """Tuple of (agent, list of tools used). 
Can't use a dict because agents aren't hashable.""" + + def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None: + existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None) + if existing_data: + existing_data[1].extend(tool_names) + else: + self.agent_to_tools.append((agent, tool_names)) + + def has_used_tools(self, agent: Agent[Any]) -> bool: + existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None) + return existing_data is not None and len(existing_data[1]) > 0 + + @dataclass class ToolRunHandoff: handoff: Handoff @@ -101,6 +118,7 @@ class ProcessedResponse: handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] + tools_used: list[str] # Names of all tools used, including hosted tools def has_tools_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing @@ -208,29 +226,6 @@ async def execute_tools_and_side_effects( new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) - # Reset tool_choice to "auto" after tool execution to prevent infinite loops - if processed_response.functions or processed_response.computer_actions: - tools = agent.tools - - if ( - run_config.model_settings and - cls._should_reset_tool_choice(run_config.model_settings, tools) - ): - # update the run_config model settings with a copy - new_run_config_settings = dataclasses.replace( - run_config.model_settings, - tool_choice="auto" - ) - run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings) - - if cls._should_reset_tool_choice(agent.model_settings, tools): - # Create a modified copy instead of modifying the original agent - new_model_settings = dataclasses.replace( - agent.model_settings, - tool_choice="auto" - ) - agent = dataclasses.replace(agent, model_settings=new_model_settings) - # Second, check if there are any handoffs if run_handoffs := processed_response.handoffs: return await cls.execute_handoffs( @@ -322,22 +317,16 @@ async def execute_tools_and_side_effects( ) @classmethod - def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool: - if model_settings is None or model_settings.tool_choice is None: - return False + def maybe_reset_tool_choice( + cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings + ) -> ModelSettings: + """Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice + flag is True.""" - # for specific tool choices - if ( - isinstance(model_settings.tool_choice, str) and - model_settings.tool_choice not in ["auto", "required", "none"] - ): - return True + if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent): + return dataclasses.replace(model_settings, tool_choice=None) - # for one tool and required tool choice - if model_settings.tool_choice == "required": - return len(tools) == 1 - - return False + return model_settings @classmethod def process_model_response( @@ -354,7 +343,7 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] - + tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) @@ -364,12 +353,15 @@ def process_model_response( 
items.append(MessageOutputItem(raw_item=output, agent=agent)) elif isinstance(output, ResponseFileSearchToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("file_search") elif isinstance(output, ResponseFunctionWebSearch): items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("web_search") elif isinstance(output, ResponseReasoningItem): items.append(ReasoningItem(raw_item=output, agent=agent)) elif isinstance(output, ResponseComputerToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("computer_use") if not computer_tool: _error_tracing.attach_error_to_current_span( SpanError( @@ -391,6 +383,8 @@ def process_model_response( if not isinstance(output, ResponseFunctionToolCall): continue + tools_used.append(output.name) + # Handoffs if output.name in handoff_map: items.append(HandoffCallItem(raw_item=output, agent=agent)) @@ -422,6 +416,7 @@ def process_model_response( handoffs=run_handoffs, functions=functions, computer_actions=computer_actions, + tools_used=tools_used, ) @classmethod diff --git a/src/agents/agent.py b/src/agents/agent.py index 3258e15a..b31f00b1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -155,6 +155,10 @@ class Agent(Generic[TContext]): web search, etc are always processed by the LLM. """ + reset_tool_choice: bool = True + """Whether to reset the tool choice to the default value after a tool has been called. Defaults + to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" + def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. For example, you could do: ``` diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index b9547123..c5b591c6 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -208,8 +208,10 @@ async def _fetch_response( list_input = ItemHelpers.input_to_new_input_list(input) parallel_tool_calls = ( - True if model_settings.parallel_tool_calls and tools and len(tools) > 0 - else False if model_settings.parallel_tool_calls is False + True + if model_settings.parallel_tool_calls and tools and len(tools) > 0 + else False + if model_settings.parallel_tool_calls is False else NOT_GIVEN ) diff --git a/src/agents/run.py b/src/agents/run.py index b7ac85f9..5c21b709 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -10,6 +10,7 @@ from agents.tool import Tool from ._run_impl import ( + AgentToolUseTracker, NextStepFinalOutput, NextStepHandoff, NextStepRunAgain, @@ -151,6 +152,8 @@ async def run( if run_config is None: run_config = RunConfig() + tool_use_tracker = AgentToolUseTracker() + with TraceCtxManager( workflow_name=run_config.workflow_name, trace_id=run_config.trace_id, @@ -227,6 +230,7 @@ async def run( context_wrapper=context_wrapper, run_config=run_config, should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, ), ) else: @@ -239,6 +243,7 @@ async def run( context_wrapper=context_wrapper, run_config=run_config, should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, ) should_run_agent_start_hooks = False @@ -486,6 +491,7 @@ async def _run_streamed_impl( current_agent = starting_agent current_turn = 0 should_run_agent_start_hooks = True + tool_use_tracker = AgentToolUseTracker() streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) @@ -546,6 +552,7 @@ async def 
_run_streamed_impl( context_wrapper, run_config, should_run_agent_start_hooks, + tool_use_tracker, ) should_run_agent_start_hooks = False @@ -613,6 +620,7 @@ async def _run_single_turn_streamed( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, ) -> SingleStepResult: if should_run_agent_start_hooks: await asyncio.gather( @@ -635,6 +643,8 @@ async def _run_single_turn_streamed( all_tools = await cls._get_all_tools(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + final_response: ModelResponse | None = None input = ItemHelpers.input_to_new_input_list(streamed_result.input) @@ -687,6 +697,7 @@ async def _run_single_turn_streamed( hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, + tool_use_tracker=tool_use_tracker, ) RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) @@ -704,6 +715,7 @@ async def _run_single_turn( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, ) -> SingleStepResult: # Ensure we run the hooks before anything else if should_run_agent_start_hooks: @@ -732,6 +744,7 @@ async def _run_single_turn( handoffs, context_wrapper, run_config, + tool_use_tracker, ) return await cls._get_single_step_result_from_response( @@ -745,6 +758,7 @@ async def _run_single_turn( hooks=hooks, context_wrapper=context_wrapper, run_config=run_config, + tool_use_tracker=tool_use_tracker, ) @classmethod @@ -761,6 +775,7 @@ async def _get_single_step_result_from_response( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, @@ -769,6 +784,9 @@ async def _get_single_step_result_from_response( output_schema=output_schema, handoffs=handoffs, ) + + tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + return await RunImpl.execute_tools_and_side_effects( agent=agent, original_input=original_input, @@ -868,9 +886,12 @@ async def _get_new_response( handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, ) -> ModelResponse: model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + new_response = await model.get_response( system_instructions=system_prompt, input=input, diff --git a/tests/fake_model.py b/tests/fake_model.py index f2ba6229..ecbb7583 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import AsyncIterator +from typing import Any from openai.types.responses import Response, ResponseCompletedEvent @@ -31,6 +32,7 @@ def __init__( [initial_output] if initial_output else [] ) self.tracing_enabled = tracing_enabled + self.last_turn_args: dict[str, Any] = {} def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) @@ -53,6 +55,14 @@ async def get_response( handoffs: list[Handoff], tracing: ModelTracing, ) -> ModelResponse: + self.last_turn_args = { + 
"system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + } + with generation_span(disabled=not self.tracing_enabled) as span: output = self.get_next_output() diff --git a/tests/test_tool_choice_reset.py b/tests/test_tool_choice_reset.py index 7dae6f63..f95117fd 100644 --- a/tests/test_tool_choice_reset.py +++ b/tests/test_tool_choice_reset.py @@ -1,63 +1,78 @@ import pytest -from agents import Agent, ModelSettings, Runner, Tool -from agents._run_impl import RunImpl +from agents import Agent, ModelSettings, Runner +from agents._run_impl import AgentToolUseTracker, RunImpl from .fake_model import FakeModel -from .test_responses import ( - get_function_tool, - get_function_tool_call, - get_text_message, -) +from .test_responses import get_function_tool, get_function_tool_call, get_text_message class TestToolChoiceReset: - def test_should_reset_tool_choice_direct(self): """ Test the _should_reset_tool_choice method directly with various inputs to ensure it correctly identifies cases where reset is needed. """ - # Case 1: tool_choice = None should not reset + agent = Agent(name="test_agent") + + # Case 1: Empty tool use tracker should not change the "None" tool choice model_settings = ModelSettings(tool_choice=None) - tools1: list[Tool] = [get_function_tool("tool1")] - # Cast to list[Tool] to fix type checking issues - assert not RunImpl._should_reset_tool_choice(model_settings, tools1) + tracker = AgentToolUseTracker() + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert new_settings.tool_choice == model_settings.tool_choice - # Case 2: tool_choice = "auto" should not reset + # Case 2: Empty tool use tracker should not change the "auto" tool choice model_settings = ModelSettings(tool_choice="auto") - assert not RunImpl._should_reset_tool_choice(model_settings, tools1) + tracker = AgentToolUseTracker() + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert model_settings.tool_choice == new_settings.tool_choice - # Case 3: tool_choice = "none" should not reset - model_settings = ModelSettings(tool_choice="none") - assert not RunImpl._should_reset_tool_choice(model_settings, tools1) + # Case 3: Empty tool use tracker should not change the "required" tool choice + model_settings = ModelSettings(tool_choice="required") + tracker = AgentToolUseTracker() + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert model_settings.tool_choice == new_settings.tool_choice # Case 4: tool_choice = "required" with one tool should reset model_settings = ModelSettings(tool_choice="required") - assert RunImpl._should_reset_tool_choice(model_settings, tools1) + tracker = AgentToolUseTracker() + tracker.add_tool_use(agent, ["tool1"]) + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert new_settings.tool_choice is None - # Case 5: tool_choice = "required" with multiple tools should not reset + # Case 5: tool_choice = "required" with multiple tools should reset model_settings = ModelSettings(tool_choice="required") - tools2: list[Tool] = [get_function_tool("tool1"), get_function_tool("tool2")] - assert not RunImpl._should_reset_tool_choice(model_settings, tools2) - - # Case 6: Specific tool choice should reset - model_settings = ModelSettings(tool_choice="specific_tool") - assert RunImpl._should_reset_tool_choice(model_settings, tools1) + tracker = AgentToolUseTracker() + 
tracker.add_tool_use(agent, ["tool1", "tool2"]) + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert new_settings.tool_choice is None + + # Case 6: Tool usage on a different agent should not affect the tool choice + model_settings = ModelSettings(tool_choice="foo_bar") + tracker = AgentToolUseTracker() + tracker.add_tool_use(Agent(name="other_agent"), ["foo_bar", "baz"]) + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert new_settings.tool_choice == model_settings.tool_choice + + # Case 7: tool_choice = "foo_bar" with multiple tools should reset + model_settings = ModelSettings(tool_choice="foo_bar") + tracker = AgentToolUseTracker() + tracker.add_tool_use(agent, ["foo_bar", "baz"]) + new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings) + assert new_settings.tool_choice is None @pytest.mark.asyncio async def test_required_tool_choice_with_multiple_runs(self): """ - Test scenario 1: When multiple runs are executed with tool_choice="required" - Ensure each run works correctly and doesn't get stuck in infinite loop - Also verify that tool_choice remains "required" between runs + Test scenario 1: When multiple runs are executed with tool_choice="required", ensure each + run works correctly and doesn't get stuck in an infinite loop. Also verify that tool_choice + remains "required" between runs. """ # Set up our fake model with responses for two runs fake_model = FakeModel() - fake_model.add_multiple_turn_outputs([ - [get_text_message("First run response")], - [get_text_message("Second run response")] - ]) + fake_model.add_multiple_turn_outputs( + [[get_text_message("First run response")], [get_text_message("Second run response")]] + ) # Create agent with a custom tool and tool_choice="required" custom_tool = get_function_tool("custom_tool") @@ -71,24 +86,26 @@ async def test_required_tool_choice_with_multiple_runs(self): # First run should work correctly and preserve tool_choice result1 = await Runner.run(agent, "first run") assert result1.final_output == "First run response" - assert agent.model_settings.tool_choice == "required", "tool_choice should stay required" + assert fake_model.last_turn_args["model_settings"].tool_choice == "required", ( + "tool_choice should stay required" + ) # Second run should also work correctly with tool_choice still required result2 = await Runner.run(agent, "second run") assert result2.final_output == "Second run response" - assert agent.model_settings.tool_choice == "required", "tool_choice should stay required" + assert fake_model.last_turn_args["model_settings"].tool_choice == "required", ( + "tool_choice should stay required" + ) @pytest.mark.asyncio async def test_required_with_stop_at_tool_name(self): """ - Test scenario 2: When using required tool_choice with stop_at_tool_names behavior - Ensure it correctly stops at the specified tool + Test scenario 2: When using required tool_choice with stop_at_tool_names behavior, ensure + it correctly stops at the specified tool """ # Set up fake model to return a tool call for second_tool fake_model = FakeModel() - fake_model.set_next_output([ - get_function_tool_call("second_tool", "{}") - ]) + fake_model.set_next_output([get_function_tool_call("second_tool", "{}")]) # Create agent with two tools and tool_choice="required" and stop_at_tool behavior first_tool = get_function_tool("first_tool", return_value="first tool result") @@ -109,8 +126,8 @@ async def test_required_with_stop_at_tool_name(self): 
@pytest.mark.asyncio async def test_specific_tool_choice(self): """ - Test scenario 3: When using a specific tool choice name - Ensure it doesn't cause infinite loops + Test scenario 3: When using a specific tool choice name, ensure it doesn't cause infinite + loops. """ # Set up fake model to return a text message fake_model = FakeModel() @@ -135,17 +152,19 @@ async def test_specific_tool_choice(self): @pytest.mark.asyncio async def test_required_with_single_tool(self): """ - Test scenario 4: When using required tool_choice with only one tool - Ensure it doesn't cause infinite loops + Test scenario 4: When using required tool_choice with only one tool, ensure it doesn't cause + infinite loops. """ # Set up fake model to return a tool call followed by a text message fake_model = FakeModel() - fake_model.add_multiple_turn_outputs([ - # First call returns a tool call - [get_function_tool_call("custom_tool", "{}")], - # Second call returns a text message - [get_text_message("Final response")] - ]) + fake_model.add_multiple_turn_outputs( + [ + # First call returns a tool call + [get_function_tool_call("custom_tool", "{}")], + # Second call returns a text message + [get_text_message("Final response")], + ] + ) # Create agent with a single tool and tool_choice="required" custom_tool = get_function_tool("custom_tool", return_value="tool result") @@ -159,3 +178,33 @@ async def test_required_with_single_tool(self): # Run should complete without infinite loops result = await Runner.run(agent, "first run") assert result.final_output == "Final response" + + @pytest.mark.asyncio + async def test_dont_reset_tool_choice_if_not_required(self): + """ + Test scenario 5: When agent.reset_tool_choice is False, ensure tool_choice is not reset. + """ + # Set up fake model to return a tool call followed by a text message + fake_model = FakeModel() + fake_model.add_multiple_turn_outputs( + [ + # First call returns a tool call + [get_function_tool_call("custom_tool", "{}")], + # Second call returns a text message + [get_text_message("Final response")], + ] + ) + + # Create agent with a single tool and tool_choice="required" and reset_tool_choice=False + custom_tool = get_function_tool("custom_tool", return_value="tool result") + agent = Agent( + name="test_agent", + model=fake_model, + tools=[custom_tool], + model_settings=ModelSettings(tool_choice="required"), + reset_tool_choice=False, + ) + + await Runner.run(agent, "test") + + assert fake_model.last_turn_args["model_settings"].tool_choice == "required"
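
For readers reviewing this change, here is a minimal, hypothetical usage sketch (not part of the diff) of how the new `reset_tool_choice` behavior is expected to look from the caller's side. It assumes the SDK's public `function_tool` decorator and async `Runner.run` entry point; the agent and tool names are illustrative only.

```python
# Sketch only: demonstrates the tool_choice reset introduced by this PR from
# a user's perspective, under the assumption that `function_tool` and
# `Runner.run` are available as public API.
import asyncio

from agents import Agent, ModelSettings, Runner, function_tool


@function_tool
def get_weather(city: str) -> str:
    """Return a canned weather report for the given city."""
    return f"The weather in {city} is sunny."


async def main() -> None:
    # tool_choice="required" forces a tool call on the first turn. After the
    # tool runs, the framework resets tool_choice for the next turn (because
    # reset_tool_choice defaults to True), so the model can reply in plain
    # text instead of being forced into another tool call indefinitely.
    agent = Agent(
        name="weather_agent",
        tools=[get_weather],
        model_settings=ModelSettings(tool_choice="required"),
        # reset_tool_choice=False,  # opt out: keep forcing tool calls every turn
        # tool_use_behavior="stop_on_first_tool",  # or stop and return the tool output directly
    )

    result = await Runner.run(agent, "What's the weather in Tokyo?")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```

The reset is tracked per agent via `AgentToolUseTracker`, so a handoff to a different agent with its own `tool_choice` setting is unaffected until that agent has used a tool itself.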