LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI #1913
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| """Pluggable capabilities for pydantic-ai agents in Lightspeed. | ||
| Provides safety, guardrail, and policy capabilities that hook into | ||
| pydantic-ai's AbstractCapability lifecycle to enforce constraints | ||
| before, during, or after agent runs. | ||
| """ | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """Question validity capability for agent input validation.""" | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity._capability import ( | ||
| QuestionValidity, | ||
| ) | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| """Question validity capability for filtering off-topic user queries. | ||
|
|
||
| This module implements a guardrail that classifies user questions as | ||
| Kubernetes/OpenShift-related or not (It can be customized to any | ||
| topic as well), using an LLM-based check before the main agent | ||
| processes the request. Invalid questions are rejected with a | ||
| predefined response, bypassing the primary agent entirely. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from dataclasses import dataclass, field | ||
| from string import Template | ||
|
|
||
| from pydantic_ai import AgentRunResult, RunContext | ||
| from pydantic_ai._agent_graph import GraphAgentState | ||
| from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler | ||
| from pydantic_ai.direct import model_request | ||
| from pydantic_ai.messages import ModelRequest, TextContent, UserContent | ||
| from pydantic_ai.models import Model, infer_model | ||
|
|
||
| from log import get_logger | ||
| from models.config import ( | ||
| QuestionValidityConfig, | ||
| ) | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
| SUBJECT_REJECTED = "REJECTED" | ||
| SUBJECT_ALLOWED = "ALLOWED" | ||
|
|
||
|
|
||
| def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str: | ||
| """Extract and combine all text content into a string from a UserContent sequence. | ||
|
|
||
| Parameters: | ||
| user_content: A sequence of user content items to extract text from. | ||
|
|
||
| Returns: | ||
| A single string with all text content joined by newlines. | ||
| """ | ||
| str_arr: list[str] = [] | ||
| for c in user_content: | ||
| match c: | ||
| case str() as s: | ||
| str_arr.append(s) | ||
| case TextContent(content=c): | ||
| str_arr.append(c) | ||
|
|
||
| return "\n".join(str_arr) | ||
|
|
||
|
|
||
| @dataclass | ||
| class QuestionValidity(AbstractCapability[None]): | ||
| """Block or modify user input based on a guardrail check. | ||
|
|
||
| The guard function receives the user prompt and returns True if safe. | ||
|
|
||
|
Comment on lines +56 to +59
Contributor
Update the class docstring to match the actual behavior. Line 106 describes a boolean guard-function contract, but this class performs LLM classification and short-circuiting. |
||
| Example: | ||
| ```python | ||
| from pydantic_ai import Agent | ||
| from pydantic_ai.models.openai import OpenAIResponsesModel | ||
|
|
||
| model = OpenAIResponsesModel("gpt-4o-mini") | ||
| agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)]) | ||
| ``` | ||
| """ | ||
|
|
||
| config: QuestionValidityConfig | ||
| _model: Model = field(init=False) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| """Initialize the model instance from the configured model ID.""" | ||
| self._model = infer_model(self.config.model_id) | ||
|
Contributor
This is conceptually wrong because we use Llama Stack as the OpenResponses provider, not raw pydantic-ai inference, so you cannot use … Create the Llama Stack provider from the client (pass it as a private attribute, for example), use minimal settings (store=false, no conversation, etc.), create …
Author
@asimurka Just curious: why can't we use raw pydantic-ai inference here (why does it have to come from Llama Stack)? Do we have any constraints here? 😁
Contributor
We have no constraints, but our consumers already use this shield with models supported by Llama Stack. I'm not sure exactly which models they use, but those may not be supported by Pydantic, you never know. If we use the Llama Stack provider, the migration will be much smoother for our customers.
Author
Got it! Thanks for filling me in, that makes sense. So we might have to address this set of models in the future when we completely remove Llama Stack. I'll change it to use the OpenAIResponsesModel for now.

```python
def _create_model_from_llama_stack_client(model_id: str) -> OpenAIResponsesModel:
    client = AsyncLlamaStackClientHolder().get_client()
    provider = _llama_stack_provider_from_client(client)
    settings = OpenAIResponsesModelSettings(openai_store=False)
    return OpenAIResponsesModel(model_id, provider=provider, settings=settings)
```

And we create it in the …

```python
def __post_init__(self) -> None:
    self._model = _create_model_from_llama_stack_client(self.config.model_id)
```

Note that because the implementation will use … Does this look good to you, @asimurka @jrobertboos?
Contributor
Yes, this is the best way :) Regarding: …
I would just remove the "_" and use the methods from pydantic_ai.py and then import them into the capability. (This increases the PR scope a little bit, but IMO it is worth it.) |
||
|
|
||
| def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str: | ||
| """Build the classification prompt from the user message. | ||
|
|
||
| Parameters: | ||
| message: The user input as a string, sequence of user content, or None. | ||
|
|
||
| Returns: | ||
| The rendered prompt string ready to send to the validity model. | ||
| """ | ||
| match message: | ||
| case str() as s: | ||
| _message = s | ||
| case Sequence() as seq: | ||
| _message = _extract_message_str_from_user_content(seq) | ||
| case None: | ||
| _message = "" | ||
|
|
||
| return Template(self.config.model_prompt).substitute( | ||
| message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED | ||
| ) | ||
|
|
||
| async def wrap_run( | ||
|
Contributor
For discussion: overriding …
Author
Just one thing that I would like to bring up for consideration. The output from … Follow the agent run lifecycle … It's just my thought. I'll follow whatever you think is the most suitable solution here. 😁
Contributor
OK, this sounds reasonable. |
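The wrap-style short-circuit discussed in this thread can be illustrated without pydantic-ai. The following is a minimal sketch, not the PR's implementation: `guarded_run`, `fake_classifier`, and the `REJECTION` text are hypothetical stand-ins, and plain strings replace `AgentRunResult`. It only shows the control flow, where the handler is awaited on an exact `ALLOWED` verdict and skipped otherwise.

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Optional

ALLOWED = "ALLOWED"  # mirrors SUBJECT_ALLOWED in the diff
REJECTION = "I can only answer questions about the configured topic."  # hypothetical text


async def guarded_run(
    classify: Callable[[str], Awaitable[Optional[str]]],
    handler: Callable[[], Awaitable[str]],
    prompt: str,
) -> str:
    """Await ``handler`` only when the classifier returns exactly ALLOWED."""
    verdict = await classify(prompt)
    if verdict is not None and verdict.strip() == ALLOWED:
        return await handler()  # proceed with the real run
    return REJECTION  # short-circuit: the handler is never awaited


async def demo() -> tuple[str, str, int]:
    calls = 0

    async def handler() -> str:
        nonlocal calls
        calls += 1
        return "main agent answer"

    async def fake_classifier(prompt: str) -> Optional[str]:
        # Stand-in for the LLM check: allow prompts that mention "pod".
        return ALLOWED if "pod" in prompt.lower() else "REJECTED"

    ok = await guarded_run(fake_classifier, handler, "Why is my pod crashing?")
    blocked = await guarded_run(fake_classifier, handler, "Write me a poem")
    return ok, blocked, calls


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Counting handler calls in `demo` makes the bypass observable: two prompts are submitted, but the handler runs once.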
||
| self, ctx: RunContext, *, handler: WrapRunHandler | ||
| ) -> AgentRunResult: | ||
| """Run the question validity check before delegating to the main agent. | ||
|
|
||
| Sends the user prompt to the validity model for classification. | ||
| If the question is allowed, the handler proceeds normally. | ||
| Otherwise, a rejection response is returned and the main agent | ||
| is bypassed. | ||
|
|
||
| Parameters: | ||
| ctx: The run context containing the user prompt and usage tracker. | ||
| handler: The handler that invokes the main agent run. | ||
|
|
||
| Returns: | ||
| The agent run result, either from the main agent or a rejection. | ||
| """ | ||
| prompt = self._build_prompt(ctx.prompt) | ||
|
|
||
| result = await model_request( | ||
| model=self._model, | ||
| messages=[ModelRequest.user_text_prompt(prompt)], | ||
| ) | ||
|
asimurka marked this conversation as resolved.
Contributor
If the model is set up correctly, you can use its …
Author
Model.request() feels like an internal function. model_request() is a thin wrapper that's more user-facing, I think. Also, if we're using request(), then we force ourselves to pass …
Contributor
Yeah, I think it's OK to use model_request(). I just saw a demo where they used the method, but this is OK IMO. |
||
|
|
||
| # Include token usage from the question validity request | ||
| ctx.usage.incr(result.usage) | ||
|
|
||
| if result.text is not None and result.text.strip() == SUBJECT_ALLOWED: | ||
| return await handler() # proceed with the real run | ||
|
|
||
| # short-circuit: return the rejection message with shield usage tracked | ||
| state = GraphAgentState(usage=ctx.usage) | ||
| return AgentRunResult( | ||
| output=self.config.invalid_question_response, _state=state | ||
| ) | ||
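The verdict check in `wrap_run` compares the stripped model text against `SUBJECT_ALLOWED` exactly, so it fails closed: any other reply, including a differently cased or punctuated one, is treated as a rejection. A small sketch of that comparison in isolation follows; `is_allowed` is a hypothetical helper name, not part of the PR.

```python
from typing import Optional

SUBJECT_ALLOWED = "ALLOWED"  # same constant as in the diff


def is_allowed(text: Optional[str]) -> bool:
    """Replicate the condition used in wrap_run for a raw model reply."""
    return text is not None and text.strip() == SUBJECT_ALLOWED


# Surrounding whitespace is tolerated; case and punctuation are not.
samples = ["ALLOWED", "  ALLOWED\n", "allowed", "ALLOWED.", "REJECTED", None]
verdicts = [(sample, is_allowed(sample)) for sample in samples]
```

Whether lenient matching (for example, case-insensitive) would be preferable is a design choice the PR does not discuss; the strict form errs toward rejecting when the validity model's reply is malformed.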
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for pydantic_ai_lightspeed capabilities.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for question validity capability.""" |
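The two test packages in this diff contain only docstrings. As an illustration of what a unit test for the prompt rendering in `_build_prompt` could check, here is a sketch that exercises `string.Template.substitute` directly. The `MODEL_PROMPT` text and the `render` helper are assumptions made for this example; the real prompt comes from `QuestionValidityConfig.model_prompt`, which is not shown in the diff.

```python
from string import Template

SUBJECT_ALLOWED = "ALLOWED"
SUBJECT_REJECTED = "REJECTED"

# Hypothetical prompt; the real text lives in QuestionValidityConfig.model_prompt.
MODEL_PROMPT = (
    "Answer ${allowed} if the question is about Kubernetes or OpenShift, "
    "otherwise answer ${rejected}.\nQuestion: ${message}"
)


def render(prompt_template: str, message: str) -> str:
    """Render a prompt the same way _build_prompt does for a plain string."""
    return Template(prompt_template).substitute(
        message=message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED
    )


def test_render_fills_all_placeholders() -> None:
    rendered = render(MODEL_PROMPT, "How do I scale a deployment?")
    assert "ALLOWED" in rendered and "REJECTED" in rendered
    assert rendered.endswith("Question: How do I scale a deployment?")


def test_dollar_in_user_message_is_left_alone() -> None:
    # Substituted values are not re-parsed, so "$HOME" in the message is safe.
    assert render(MODEL_PROMPT, "echo $HOME").endswith("echo $HOME")


def test_unknown_placeholder_raises() -> None:
    # substitute() raises KeyError for a placeholder it has no value for.
    try:
        render("${unknown} ${message}", "hi")
    except KeyError:
        pass
    else:
        raise AssertionError("expected KeyError")
```

The last test points at a configuration hazard: a custom prompt with a misspelled or extra placeholder makes `substitute` raise at request time rather than at startup.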