Custom Guardrails
While the library provides 16 built-in guardrails, you’ll often need custom validation logic for your specific use case. This guide shows you how to write your own guardrails.
Basic Structure
Every guardrail is a function that returns a GuardrailResult:
```python
from pydantic_ai_guardrails import GuardrailResult

async def my_guardrail(prompt: str) -> GuardrailResult:
    # should_block() stands in for your own validation logic
    if should_block(prompt):
        return {
            'tripwire_triggered': True,
            'message': 'Why it was blocked',
            'severity': 'high',
            'suggestion': 'How to fix it',
        }
    return {'tripwire_triggered': False}
```
GuardrailResult Fields
| Field | Required | Type | Description |
|---|---|---|---|
| tripwire_triggered | Yes | bool | True to block, False to pass |
| message | No | str | Human-readable explanation |
| severity | No | 'low', 'medium', 'high', or 'critical' | Severity level |
| suggestion | No | str | How to fix the issue (used in auto-retry) |
| metadata | No | dict | Additional structured data |
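The table describes a plain dictionary. To try out guardrail logic before the library is installed, a local TypedDict can mirror those fields. The sketch below assumes that shape: GuardrailResultSketch is a stand-in written for this example, not the library's own GuardrailResult definition, and limit_prompt_length and MAX_PROMPT_CHARS are hypothetical names used only for illustration.

```python
import asyncio
from typing import Any, Literal, TypedDict


class _RequiredFields(TypedDict):
    tripwire_triggered: bool  # the only required field


class GuardrailResultSketch(_RequiredFields, total=False):
    # Local stand-in mirroring the table above (an assumption, not the library type)
    message: str
    severity: Literal['low', 'medium', 'high', 'critical']
    suggestion: str
    metadata: dict[str, Any]


MAX_PROMPT_CHARS = 2000  # hypothetical limit, for illustration only


async def limit_prompt_length(prompt: str) -> GuardrailResultSketch:
    """Hypothetical guardrail that fills in every field from the table."""
    length = len(prompt)
    if length > MAX_PROMPT_CHARS:
        return {
            'tripwire_triggered': True,
            'message': f'Prompt is {length} characters; the limit is {MAX_PROMPT_CHARS}',
            'severity': 'low',
            'suggestion': 'Shorten the prompt or split it into smaller requests',
            'metadata': {'length': length, 'limit': MAX_PROMPT_CHARS},
        }
    # A passing result only needs the required key
    return {'tripwire_triggered': False}


print(asyncio.run(limit_prompt_length('Summarize this paragraph.')))
print(asyncio.run(limit_prompt_length('x' * 2500))['metadata'])
```

Running it prints {'tripwire_triggered': False} for the short prompt and {'length': 2500, 'limit': 2000} for the long one. Only tripwire_triggered is required, so passing results can stay minimal while blocking results carry the optional context.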
Input Guardrail Example
1. Define your validation function

   ```python
   from pydantic_ai_guardrails import GuardrailResult

   async def block_competitors(prompt: str) -> GuardrailResult:
       """Block mentions of competitor products."""
       competitors = ['competitor_a', 'competitor_b', 'competitor_c']
       prompt_lower = prompt.lower()
       found = [c for c in competitors if c in prompt_lower]
       if found:
           return {
               'tripwire_triggered': True,
               'message': f'Competitor mentions detected: {found}',
               'severity': 'medium',
               'metadata': {'competitors_found': found},
           }
       return {'tripwire_triggered': False}
   ```

2. Wrap it in InputGuardrail

   ```python
   from pydantic_ai_guardrails import InputGuardrail

   guardrail = InputGuardrail(
       block_competitors,
       name='competitor_blocker',
       description='Blocks mentions of competitor products',
   )
   ```

3. Add to your agent

   ```python
   from pydantic_ai import Agent
   from pydantic_ai_guardrails import GuardedAgent

   agent = Agent('openai:gpt-4o')  # any existing pydantic_ai Agent works here

   guarded_agent = GuardedAgent(
       agent,
       input_guardrails=[guardrail],
   )
   ```
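Because the validation function from step 1 is ordinary async Python, you can sanity-check it before wrapping it in InputGuardrail. The sketch below repeats block_competitors with a plain dict return type so it runs without the library installed, then calls it through asyncio.run; the sample prompts are illustrative.

```python
import asyncio


# Self-contained copy of block_competitors from step 1, returning a plain dict
async def block_competitors(prompt: str) -> dict:
    competitors = ['competitor_a', 'competitor_b', 'competitor_c']
    prompt_lower = prompt.lower()  # case-insensitive comparison
    found = [c for c in competitors if c in prompt_lower]
    if found:
        return {
            'tripwire_triggered': True,
            'message': f'Competitor mentions detected: {found}',
            'severity': 'medium',
            'metadata': {'competitors_found': found},
        }
    return {'tripwire_triggered': False}


async def main() -> None:
    for prompt in ['Tell me about your product', 'Is Competitor_B cheaper?']:
        result = await block_competitors(prompt)
        print(prompt, '->', result['tripwire_triggered'])


asyncio.run(main())
```

This prints False for the first prompt and True for the second, since matching happens on the lowercased text. Note that the check is a substring test, so a longer token such as competitor_abc is also flagged; switch to a word-boundary regular expression if that matters for your product names.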
Output Guardrail Example
```python
from pydantic_ai_guardrails import GuardrailResult, OutputGuardrail

async def check_response_quality(output: str) -> GuardrailResult:
    """Ensure response meets quality standards."""
    issues = []

    # Check length
    if len(output) < 50:
        issues.append('Response too short')

    # Check for placeholder text
    if '[TODO]' in output or '[PLACEHOLDER]' in output:
        issues.append('Contains placeholder text')

    # Check for confidence hedging. The phrases are lowercase because
    # they are compared against output.lower().
    hedge_phrases = ['i think', 'maybe', 'probably', 'not sure']
    if any(phrase in output.lower() for phrase in hedge_phrases):
        issues.append('Contains hedging language')

    if issues:
        return {
            'tripwire_triggered': True,
            'message': f'Quality issues: {", ".join(issues)}',
            'severity': 'medium',
            'suggestion': 'Provide a more confident, complete response',
        }

    return {'tripwire_triggered': False}

guardrail = OutputGuardrail(check_response_quality, name='quality_check')
```
Accessing Dependencies
Use GuardrailContext to access injected dependencies:
```python
from pydantic_ai_guardrails import GuardrailContext, GuardrailResult

async def check_user_permissions(
    ctx: GuardrailContext,  # First parameter
    prompt: str,            # Second parameter
) -> GuardrailResult:
    """Check if user has permission to use this agent."""

    # Access dependencies
    user_service = ctx.deps['user_service']
    user_id = ctx.deps['user_id']

    user = await user_service.get_user(user_id)

    if not user.has_agent_access:
        return {
            'tripwire_triggered': True,
            'message': f'User {user_id} not authorized',
            'severity': 'critical',
            'metadata': {'user_id': user_id, 'tier': user.tier},
        }

    return {'tripwire_triggered': False}
```
Usage:
```python
result = await guarded_agent.run(
    'Hello',
    deps={
        'user_service': UserService(),
        'user_id': 'user_123',
    },
)
```
Sync vs Async
Both sync and async functions work. The library auto-detects which you’re using:
```python
from pydantic_ai_guardrails import GuardrailResult

# Async (recommended for I/O operations)
async def async_guardrail(prompt: str) -> GuardrailResult:
    # some_external_api stands in for any awaitable check you call
    result = await some_external_api(prompt)
    return {'tripwire_triggered': result.is_bad}

# Sync (runs in thread pool automatically)
def sync_guardrail(prompt: str) -> GuardrailResult:
    # CPU-bound work is fine here
    return {'tripwire_triggered': False}
```
Output Guardrails with Message History
Output guardrails can access the original prompt (ctx.prompt) and the full conversation (ctx.messages) through GuardrailContext:
```python
from pydantic_ai.messages import ToolCallPart
from pydantic_ai_guardrails import GuardrailContext, GuardrailResult

async def validate_conversation(
    ctx: GuardrailContext,
    output: str,
) -> GuardrailResult:
    """Validate based on the full conversation."""

    # Original prompt
    original_prompt = ctx.prompt

    # Full message history
    messages = ctx.messages or []

    # Count tool calls. Check for ToolCallPart explicitly: tool return
    # parts also have a tool_name attribute, so a hasattr() check
    # would count returns as well as calls.
    tool_calls = 0
    for msg in messages:
        for part in getattr(msg, 'parts', []):
            if isinstance(part, ToolCallPart):
                tool_calls += 1

    # Example: require at least one tool call
    if tool_calls == 0:
        return {
            'tripwire_triggered': True,
            'message': 'No tools were used to generate response',
            'severity': 'medium',
        }

    return {'tripwire_triggered': False}
```
Calling External APIs
```python
import httpx
from pydantic_ai_guardrails import GuardrailContext, GuardrailResult

async def check_content_moderation(
    ctx: GuardrailContext,
    prompt: str,
) -> GuardrailResult:
    """Call external moderation API."""

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(
            'https://api.moderation.example.com/check',
            json={'content': prompt},
            headers={'Authorization': f'Bearer {ctx.deps["api_key"]}'},
        )
        # Raise on 4xx/5xx instead of parsing an error body as a verdict
        response.raise_for_status()
        result = response.json()

    if result['flagged']:
        return {
            'tripwire_triggered': True,
            'message': f'Content flagged: {result["categories"]}',
            'severity': 'high',
            'metadata': result,
        }

    return {'tripwire_triggered': False}
```
Stateful Guardrails with Classes
For guardrails that need initialization or state:
```python
import re

from pydantic_ai_guardrails import GuardedAgent, GuardrailResult, InputGuardrail

class ContentFilterGuardrail:
    def __init__(self, blocked_patterns: list[str]):
        # Compile the patterns once, not on every call
        self.patterns = [re.compile(p, re.IGNORECASE) for p in blocked_patterns]

    async def __call__(self, prompt: str) -> GuardrailResult:
        for pattern in self.patterns:
            if pattern.search(prompt):
                return {
                    'tripwire_triggered': True,
                    'message': f'Blocked pattern found: {pattern.pattern}',
                    'severity': 'high',
                }
        return {'tripwire_triggered': False}

# Usage
filter_guardrail = ContentFilterGuardrail([
    r'hack\s+into',
    r'bypass\s+security',
])

guarded_agent = GuardedAgent(
    agent,
    input_guardrails=[InputGuardrail(filter_guardrail, name='content_filter')],
)
```
Writing Good Suggestions
The suggestion field is used in auto-retry to help the LLM fix issues:
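```python
import re

from pydantic_ai_guardrails import GuardrailResult

# A deliberately simple email detector; swap in your own PII logic
EMAIL_RE = re.compile(r'[\w.+-]+@[\w-]+\.[\w.-]+')

def contains_email(text: str) -> bool:
    return EMAIL_RE.search(text) is not None

async def check_pii(output: str) -> GuardrailResult:
    if contains_email(output):
        return {
            'tripwire_triggered': True,
            'message': 'Email address detected in output',
            'severity': 'high',
            # Good suggestion: specific and actionable
            'suggestion': (
                'Replace all email addresses with placeholders like '
                '[EMAIL] or describe them generically without including '
                'the actual address.'
            ),
        }
    return {'tripwire_triggered': False}
```
Testing Custom Guardrails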
Use the built-in testing utilities:
```python
from pydantic_ai_guardrails import (
    InputGuardrail,
    assert_guardrail_passes,
    assert_guardrail_blocks,
    create_test_context,
)

# block_competitors is the function from the Input Guardrail Example.
# Async tests need an async-aware runner such as pytest-asyncio.
async def test_competitor_guardrail():
    guardrail = InputGuardrail(block_competitors)

    # Should pass
    await assert_guardrail_passes(
        guardrail,
        'Tell me about your product',
    )

    # Should block
    await assert_guardrail_blocks(
        guardrail,
        'How does your product compare to competitor_a?',
    )
```

See Testing for more details.
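The built-in helpers exercise the wrapped guardrail. It can also pay to unit-test the raw check logic with plain asserts, especially for easy-to-miss bugs such as comparing a capitalized phrase like 'I think' against lowercased text, which can never match. The sketch below uses a hypothetical check_hedging, a cut-down version of the hedging rule in check_response_quality; it needs only the standard library, and pytest would collect test_check_hedging as-is.

```python
import asyncio

# Phrases are lowercase because they are compared against output.lower()
HEDGE_PHRASES = ['i think', 'maybe', 'probably', 'not sure']


async def check_hedging(output: str) -> dict:
    """Hypothetical cut-down version of the hedging rule in check_response_quality."""
    if any(phrase in output.lower() for phrase in HEDGE_PHRASES):
        return {'tripwire_triggered': True, 'message': 'Contains hedging language'}
    return {'tripwire_triggered': False}


def test_check_hedging() -> None:
    # Regression cases: capitalized phrases must still be detected
    assert asyncio.run(check_hedging('I think the answer is 42.'))['tripwire_triggered']
    assert asyncio.run(check_hedging('It will PROBABLY rain.'))['tripwire_triggered']
    # Confident text passes
    assert not asyncio.run(check_hedging('The answer is 42.'))['tripwire_triggered']


if __name__ == '__main__':
    test_check_hedging()
```

Testing the bare function keeps these cases fast and independent of the agent; the library's assert_guardrail_passes and assert_guardrail_blocks then cover the wrapped guardrail.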
Next Steps
- Auto-Retry - Use suggestions for LLM self-correction
- Error Handling - Handle violations gracefully
- Human-in-the-Loop - Add human review to guardrails