From 83dc4c9a1eec14ab6a81f0cf91359fd3ad7f2988 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 30 Dec 2024 13:23:03 +0000 Subject: [PATCH] extend "RunContext" --- pydantic_ai_slim/pydantic_ai/agent.py | 45 ++++++++++++++------------- pydantic_ai_slim/pydantic_ai/tools.py | 16 ++++++---- tests/test_logfire.py | 10 +++--- 3 files changed, 38 insertions(+), 33 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3a0959ab..b8fd0ac5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -40,6 +40,16 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') +# while waiting for https://github.com/pydantic/logfire/issues/745 +try: + import logfire._internal.stack_info +except ImportError: + pass +else: + from pathlib import Path + + logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) + NoneType = type(None) EndStrategy = Literal['early', 'exhaustive'] """The strategy for handling multiple tool calls when a final result is found. @@ -215,7 +225,7 @@ async def run( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - model_used, mode_selection = await self._get_model(model) + model_used = await self._get_model(model) deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 @@ -224,11 +234,10 @@ async def run( '{agent_name} run {prompt=}', prompt=user_prompt, agent=self, - mode_selection=mode_selection, model_name=model_used.name(), agent_name=self.name or 'agent', ) as run_span: - run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage()) + run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt) messages = await self._prepare_messages(user_prompt, message_history, run_context) run_context.messages = messages @@ -238,15 +247,14 @@ async def run( model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or UsageLimits() - run_step = 0 while True: usage_limits.check_before_request(run_context.usage) - run_step += 1 - with _logfire.span('preparing model and tools {run_step=}', run_step=run_step): + run_context.run_step += 1 + with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step): agent_model = await self._prepare_model(run_context) - with _logfire.span('model request', run_step=run_step) as model_req_span: + with _logfire.span('model request', run_step=run_context.run_step) as model_req_span: model_response, request_usage = await agent_model.request(messages, model_settings) model_req_span.set_attribute('response', model_response) model_req_span.set_attribute('usage', request_usage) @@ -255,7 +263,7 @@ async def run( run_context.usage.incr(request_usage, requests=1) usage_limits.check_tokens(run_context.usage) - with _logfire.span('handle model response', run_step=run_step) as handle_span: + with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span: final_result, tool_responses = await self._handle_model_response(model_response, run_context) if tool_responses: @@ -377,7 +385,7 @@ async def main(): # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) - model_used, mode_selection = await self._get_model(model) + model_used = await self._get_model(model) deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 @@ -386,11 +394,10 @@ async def main(): '{agent_name} run stream {prompt=}', prompt=user_prompt, agent=self, - mode_selection=mode_selection, model_name=model_used.name(), agent_name=self.name or 'agent', ) as run_span: - run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage()) + run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt) messages = await self._prepare_messages(user_prompt, message_history, run_context) run_context.messages = messages @@ -400,15 +407,14 @@ async def main(): model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or UsageLimits() - run_step = 0 while True: - run_step += 1 + run_context.run_step += 1 usage_limits.check_before_request(run_context.usage) - with _logfire.span('preparing model and tools {run_step=}', run_step=run_step): + with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step): agent_model = await self._prepare_model(run_context) - with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span: + with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span: async with agent_model.request_stream(messages, model_settings) as model_response: run_context.usage.requests += 1 model_req_span.set_attribute('response_type', model_response.__class__.__name__) @@ -781,14 +787,14 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None: self._function_tools[tool.name] = tool - async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]: + async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model: """Create a model configured for this agent. Args: model: model to use for this run, required if `model` was not set when creating the agent. Returns: - a tuple of `(model used, how the model was selected)` + The model used """ model_: models.Model if some_model := self._override_model: @@ -799,18 +805,15 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) - '(Even when `override(model=...)` is customizing the model that will actually be called)' ) model_ = some_model.value - mode_selection = 'override-model' elif model is not None: model_ = models.infer_model(model) - mode_selection = 'custom' elif self.model is not None: # noinspection PyTypeChecker model_ = self.model = models.infer_model(self.model) - mode_selection = 'from-agent' else: raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.') - return model_, mode_selection + return model_ async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel: """Build tools and create an agent model.""" diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 34015b4f..746a3ebc 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -40,16 +40,20 @@ class RunContext(Generic[AgentDeps]): deps: AgentDeps """Dependencies for the agent.""" - retry: int - """Number of retries so far.""" - messages: list[_messages.ModelMessage] - """Messages exchanged in the conversation so far.""" - tool_name: str | None - """Name of the tool being called.""" model: models.Model """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" + prompt: str + """The original user prompt passed to the run.""" + messages: list[_messages.ModelMessage] = field(default_factory=list) + """Messages exchanged in the conversation so far.""" + tool_name: str | None = None + """Name of the tool being called.""" + retry: int = 0 + """Number of retries so far.""" + run_step: int = 0 + """The current step in the run.""" def replace_with( self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 5d4caf4d..762a51cf 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -91,8 +91,8 @@ async def my_ret(x: int) -> str: ) assert summary.attributes[0] == snapshot( { - 'code.filepath': 'agent.py', - 'code.function': 'run', + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_logfire', 'code.lineno': 123, 'prompt': 'Hello', 'agent': IsJson( @@ -111,7 +111,6 @@ async def my_ret(x: int) -> str: 'model_settings': None, } ), - 'mode_selection': 'from-agent', 'model_name': 'test-model', 'agent_name': 'my_agent', 'logfire.msg_template': '{agent_name} run {prompt=}', @@ -176,7 +175,6 @@ async def my_ret(x: int) -> str: 'model': {'type': 'object', 'title': 'TestModel', 'x-python-datatype': 'dataclass'} }, }, - 'mode_selection': {}, 'model_name': {}, 'agent_name': {}, 'all_messages': { @@ -263,8 +261,8 @@ async def my_ret(x: int) -> str: ) assert summary.attributes[1] == snapshot( { - 'code.filepath': 'agent.py', - 'code.function': 'run', + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_logfire', 'code.lineno': IsInt(), 'run_step': 1, 'logfire.msg_template': 'preparing model and tools {run_step=}',