Production-Grade AI Agent Architecture: Prototype to Deployment Guide
将 AI Agent 从原型推向生产环境是一个系统工程。本文将从错误处理、状态管理、可观测性、成本优化等多个维度,详细介绍生产级 Agent 架构的设计模式和最佳实践。
生产级 Agent 架构全景
┌─────────────────────────────────────────────────────────────────┐
│ Load Balancer │
└───────────────────────────┬─────────────────────────────────────┘
│
┌───────────────────────────▼─────────────────────────────────────┐
│ API Gateway │
│ (Auth / Rate Limit / Request Validation) │
└───────────────────────────┬─────────────────────────────────────┘
│
┌───────────────────────────▼─────────────────────────────────────┐
│ Agent Orchestrator │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │Error │ │State │ │Cost │ │Circuit │ │
│ │Handler │ │Manager │ │Tracker │ │Breaker │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
└──────────┬──────────┬──────────┬──────────┬────────────────────┘
│ │ │ │
┌──────▼──┐ ┌─────▼──┐ ┌────▼───┐ ┌───▼──────┐
│ LLM │ │ Vector │ │ Cache │ │ Message │
│ Gateway │ │ DB │ │(Redis) │ │ Queue │
└─────────┘ └────────┘ └────────┘ └──────────┘
│
┌──────▼──────────────────┐
│ Observability Stack │
│ (Traces/Metrics/Logs) │
└─────────────────────────┘
错误处理策略
分层错误处理
生产级 Agent 需要多层次的错误处理机制:
import asyncio
import functools
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
class ErrorSeverity(Enum):
    """How severe an agent error is, and therefore how it is handled."""

    LOW = "low"            # ignorable -- continue execution
    MEDIUM = "medium"      # transient -- retry with backoff
    HIGH = "high"          # serious -- degrade gracefully
    CRITICAL = "critical"  # fatal -- abort and alert
@dataclass
class AgentError:
    """Normalized error record routed through AgentErrorHandler."""

    error_type: str        # machine-readable category, e.g. "llm_timeout", "tool_failure"
    message: str           # human-readable description
    severity: ErrorSeverity
    recoverable: bool      # whether a recovery path exists (not consulted in visible code)
    context: dict          # arbitrary diagnostic context, included in logs
    original_exception: Optional[Exception] = None  # underlying cause, if any
class AgentErrorHandler:
def __init__(self):
self.logger = logging.getLogger("agent.error")
self.error_counts = {}
async def handle(self, error: AgentError, retry_fn: Callable):
"""统一错误处理入口"""
# 记录错误
self._log_error(error)
# 根据严重程度处理
match error.severity:
case ErrorSeverity.LOW:
return await self._handle_low(error)
case ErrorSeverity.MEDIUM:
return await self._handle_medium(error, retry_fn)
case ErrorSeverity.HIGH:
return await self._handle_high(error)
case ErrorSeverity.CRITICAL:
return await self._handle_critical(error)
async def _handle_medium(self, error: AgentError, retry_fn: Callable):
"""中等错误:指数退避重试"""
max_retries = 3
base_delay = 1
for attempt in range(max_retries):
try:
delay = base_delay * (2 ** attempt)
await asyncio.sleep(delay)
self.logger.info(f"重试 {attempt + 1}/{max_retries}")
return await retry_fn()
except Exception as e:
if attempt == max_retries - 1:
return await self._handle_high(error)
return None
async def _handle_high(self, error: AgentError):
"""高严重度错误:降级处理"""
self.logger.warning(f"触发降级: {error.error_type}")
# 使用备选方案
if error.error_type == "llm_timeout":
return {"response": "系统繁忙,请稍后重试", "degraded": True}
elif error.error_type == "tool_failure":
return {"response": "部分功能暂时不可用", "degraded": True}
async def _handle_critical(self, error: AgentError):
"""严重错误:中断并告警"""
self.logger.critical(f"严重错误: {error.message}")
await self._send_alert(error)
raise SystemExit(1)
def _log_error(self, error: AgentError):
"""记录错误详情"""
self.logger.error(
f"Agent Error: {error.error_type} | "
f"Severity: {error.severity.value} | "
f"Message: {error.message} | "
f"Context: {error.context}"
)
LLM 调用的错误处理
LLM 调用是 Agent 中最不稳定的环节,需要特别处理:
class LLMGateway:
    """Fault-tolerant LLM gateway: tries providers in fallback order,
    skipping any whose circuit breaker is currently open.

    NOTE(review): RateLimitError and AllProvidersFailedError are assumed to
    be defined elsewhere in the project -- confirm their import path.
    """

    def __init__(self, providers: list, fallback_chain: list):
        self.providers = providers
        self.fallback_chain = fallback_chain
        # Breakers for every provider we may call. Previously only
        # `providers` was covered, so a provider present in fallback_chain
        # but not in providers raised KeyError inside call().
        self.circuit_breakers = {
            p.name: CircuitBreaker() for p in (*providers, *fallback_chain)
        }

    async def call(self, messages: list, tools: list = None, **kwargs):
        """Call the first healthy provider in the fallback chain.

        Raises AllProvidersFailedError (carrying the last exception seen)
        when every provider fails or is circuit-broken.
        """
        last_error = None
        for provider in self.fallback_chain:
            # setdefault guards against providers registered after __init__.
            breaker = self.circuit_breakers.setdefault(provider.name, CircuitBreaker())
            if breaker.is_open:
                continue  # skip circuit-broken providers
            try:
                result = await provider.call(messages, tools, **kwargs)
                breaker.record_success()
                return result
            except RateLimitError as e:
                breaker.record_failure()
                last_error = e
                # Honor the provider's suggested back-off before moving on.
                await asyncio.sleep(e.retry_after)
            except TimeoutError as e:
                breaker.record_failure()
                last_error = e
            except Exception as e:
                breaker.record_failure()
                last_error = e
        raise AllProvidersFailedError(last_error)
class CircuitBreaker:
    """Circuit-breaker pattern: stop calling a failing dependency until a
    cool-down period has elapsed, then allow a single probe call."""

    def __init__(self, failure_threshold=5, recovery_timeout=60):
        self.failure_count = 0
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.last_failure_time = None
        self.state = "closed"  # one of: closed, open, half-open

    def record_success(self):
        """Any success fully closes the breaker and clears the failure streak."""
        self.failure_count = 0
        self.state = "closed"

    def record_failure(self):
        """Count a failure; trip the breaker once the threshold is reached."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.failure_threshold:
            self.state = "open"

    @property
    def is_open(self):
        """True while calls must be blocked.

        An open breaker transitions to half-open (one trial call allowed)
        once the recovery timeout has elapsed.
        """
        if self.state == "open":
            elapsed = time.time() - self.last_failure_time
            if elapsed > self.recovery_timeout:
                # Cool-down expired: permit a single probe call.
                self.state = "half-open"
                return False
            return True
        # closed and half-open both permit calls
        return False
状态管理
会话状态设计
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Any, Optional
import json
import redis
class ConversationState(BaseModel):
    """Full per-session agent state, serialized to Redis as JSON."""

    session_id: str
    user_id: str
    # NOTE(review): datetime.utcnow is deprecated in Python 3.12 and returns a
    # naive datetime; switching to timezone-aware timestamps would change the
    # serialized format, so it is kept for compatibility -- confirm intent.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    # Conversation history.
    # default_factory (rather than a bare []/{}) is the documented pydantic
    # way to declare mutable defaults.
    messages: list[dict] = Field(default_factory=list)

    # Agent internal state
    current_step: str = "initial"
    tool_calls_history: list[dict] = Field(default_factory=list)
    intermediate_results: dict = Field(default_factory=dict)

    # Accounting metadata
    token_count: int = 0
    cost_so_far: float = 0.0
    error_count: int = 0

    # User context
    user_preferences: dict = Field(default_factory=dict)
    conversation_summary: Optional[str] = None
class StateManager:
    """Redis-backed conversation state store.

    NOTE(review): methods are async but use the synchronous redis client, so
    every call blocks the event loop -- consider redis.asyncio.
    """

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)
        self.default_ttl = 3600 * 24  # state expires after 24 hours

    async def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Load a session's state, or None if absent/expired."""
        data = self.redis.get(f"agent:state:{session_id}")
        if data:
            return ConversationState.model_validate_json(data)
        return None

    async def save_state(self, state: ConversationState):
        """Persist the state and refresh its TTL; bumps updated_at."""
        state.updated_at = datetime.utcnow()
        key = f"agent:state:{state.session_id}"
        self.redis.setex(
            key,
            self.default_ttl,
            state.model_dump_json()
        )

    async def update_state(self, session_id: str, updates: dict):
        """Apply field updates to an existing session; silently no-op if missing."""
        state = await self.get_state(session_id)
        if state:
            for field_name, value in updates.items():
                setattr(state, field_name, value)
            await self.save_state(state)

    async def delete_state(self, session_id: str):
        """Remove a session's state immediately."""
        self.redis.delete(f"agent:state:{session_id}")

    async def list_sessions(self, user_id: str) -> list[str]:
        """List all session ids belonging to user_id.

        Uses SCAN instead of KEYS: KEYS is O(N) and blocks the Redis server,
        which is unsafe in production. This is still O(total sessions); a
        per-user index set would be the scalable design.
        """
        sessions = []
        for key in self.redis.scan_iter("agent:state:*"):
            data = self.redis.get(key)
            if data:
                state = ConversationState.model_validate_json(data)
                if state.user_id == user_id:
                    sessions.append(state.session_id)
        return sessions
检查点机制
对于长时间运行的 Agent 任务,检查点机制至关重要:
class CheckpointManager:
    """Persist and restore per-step checkpoints for long-running agent tasks."""

    CHECKPOINT_TTL = 3600 * 24 * 7  # keep checkpoints for 7 days

    def __init__(self, state_manager: StateManager):
        self.state_manager = state_manager

    async def create_checkpoint(self, session_id: str, step: str, data: dict):
        """Store a checkpoint for (session, step), overwriting any previous one."""
        checkpoint = {
            "step": step,
            "data": data,
            "timestamp": datetime.utcnow().isoformat()
        }
        key = f"agent:checkpoint:{session_id}:{step}"
        self.state_manager.redis.setex(
            key,
            self.CHECKPOINT_TTL,
            # default=str makes datetimes and other non-JSON types serializable
            json.dumps(checkpoint, default=str)
        )

    async def restore_from_checkpoint(self, session_id: str, step: str) -> Optional[dict]:
        """Return the stored checkpoint dict, or None if absent/expired."""
        key = f"agent:checkpoint:{session_id}:{step}"
        data = self.state_manager.redis.get(key)
        return json.loads(data) if data else None

    async def list_checkpoints(self, session_id: str) -> list[str]:
        """List checkpoint step names for a session.

        Uses SCAN instead of KEYS (KEYS blocks the Redis server) and handles
        both bytes and str keys depending on the client's decode_responses.
        """
        steps = []
        for key in self.state_manager.redis.scan_iter(f"agent:checkpoint:{session_id}:*"):
            text = key.decode() if isinstance(key, bytes) else key
            steps.append(text.split(":")[-1])
        return steps
可观测性
OpenTelemetry 集成
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
# BatchSpanProcessor was used below but never imported -- fixed here.
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.requests import RequestsInstrumentor

# --- Tracing setup ---
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("agent.service")
otlp_exporter = OTLPSpanExporter(endpoint="http://otel-collector:4317")
trace.get_tracer_provider().add_span_processor(
    BatchSpanProcessor(otlp_exporter)
)

# NOTE(review): RequestsInstrumentor is imported but never activated --
# call RequestsInstrumentor().instrument() if outbound HTTP tracing is wanted.

# --- Metrics setup ---
# NOTE(review): no MetricReader/exporter is configured on this MeterProvider,
# so recorded metrics are not exported anywhere -- confirm intent.
metrics.set_meter_provider(MeterProvider())
meter = metrics.get_meter("agent.service")

# --- Metric instruments ---
llm_call_counter = meter.create_counter(
    "agent.llm.calls",
    description="LLM 调用次数",
    unit="1"
)
llm_latency = meter.create_histogram(
    "agent.llm.latency",
    description="LLM 调用延迟",
    unit="ms"
)
tool_call_counter = meter.create_counter(
    "agent.tool.calls",
    description="工具调用次数",
    unit="1"
)
class AgentObservability:
    """Decorators that add tracing and metrics around async agent operations."""

    @staticmethod
    def traced(operation_name: str):
        """Wrap an async function in an OpenTelemetry span named operation_name."""
        def decorator(func):
            @functools.wraps(func)  # preserve name/docstring of the wrapped coroutine
            async def wrapper(*args, **kwargs):
                with tracer.start_as_current_span(operation_name) as span:
                    span.set_attribute("agent.operation", operation_name)
                    try:
                        result = await func(*args, **kwargs)
                        span.set_status(trace.Status(trace.StatusCode.OK))
                        return result
                    except Exception as e:
                        span.set_status(
                            trace.Status(trace.StatusCode.ERROR, str(e))
                        )
                        span.record_exception(e)
                        raise
            return wrapper
        return decorator

    @staticmethod
    def measure_llm(provider: str, model: str):
        """Record call count and latency metrics for an async LLM call."""
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = await func(*args, **kwargs)
                    latency = (time.time() - start_time) * 1000
                    llm_call_counter.add(1, {"provider": provider, "model": model})
                    llm_latency.record(latency, {"provider": provider, "model": model})
                    return result
                except Exception:
                    # Failed calls are counted with an explicit error tag;
                    # no latency is recorded for them (behavior preserved).
                    llm_call_counter.add(1, {
                        "provider": provider,
                        "model": model,
                        "status": "error"
                    })
                    raise
            return wrapper
        return decorator
结构化日志
import structlog
logger = structlog.get_logger()
class AgentLogger:
    """Thin structured-logging facade: each method emits one well-known event."""

    @staticmethod
    def log_agent_start(session_id: str, user_input: str):
        """Emit 'agent_started' with input size and a short preview."""
        fields = {
            "session_id": session_id,
            "input_length": len(user_input),
            "input_preview": user_input[:100],
        }
        logger.info("agent_started", **fields)

    @staticmethod
    def log_llm_call(session_id: str, provider: str, model: str,
                     tokens: int, latency_ms: float):
        """Emit 'llm_call' with provider/model usage and timing."""
        fields = {
            "session_id": session_id,
            "provider": provider,
            "model": model,
            "tokens": tokens,
            "latency_ms": latency_ms,
        }
        logger.info("llm_call", **fields)

    @staticmethod
    def log_tool_call(session_id: str, tool_name: str,
                      success: bool, latency_ms: float):
        """Emit 'tool_call' with outcome and timing."""
        fields = {
            "session_id": session_id,
            "tool_name": tool_name,
            "success": success,
            "latency_ms": latency_ms,
        }
        logger.info("tool_call", **fields)

    @staticmethod
    def log_error(session_id: str, error_type: str, message: str):
        """Emit 'agent_error' at error level."""
        fields = {
            "session_id": session_id,
            "error_type": error_type,
            "message": message,
        }
        logger.error("agent_error", **fields)
成本优化
Token 使用优化
class TokenOptimizer:
    """Trim conversation history to fit a token budget."""

    def __init__(self):
        self.token_budget_per_session = 50000  # overall budget per session
        self.token_usage = {}  # per-session usage ledger (filled by callers)

    def estimate_tokens(self, text: str) -> int:
        """Rough token estimate: ~4 characters per token.

        NOTE(review): this heuristic is English-centric; CJK text averages
        closer to 1-2 characters per token -- confirm acceptable accuracy.
        """
        return len(text) // 4

    def optimize_messages(self, messages: list, max_tokens: int) -> list:
        """Return messages trimmed to at most max_tokens (estimated).

        Strategy: always keep system messages; then keep the most recent
        non-system messages that fit, summarizing (or dropping) any that
        don't. Chronological order of kept messages is preserved.
        """
        total_tokens = sum(self.estimate_tokens(m["content"]) for m in messages)
        if total_tokens <= max_tokens:
            return messages

        # System prompts are never trimmed.
        system_msgs = [m for m in messages if m["role"] == "system"]
        other_msgs = [m for m in messages if m["role"] != "system"]

        remaining_tokens = max_tokens - sum(
            self.estimate_tokens(m["content"]) for m in system_msgs
        )

        # Walk newest-to-oldest appending kept messages, then reverse once.
        # (The previous insert(0, ...) inside the loop was O(n^2).)
        kept_reversed = []
        for msg in reversed(other_msgs):
            msg_tokens = self.estimate_tokens(msg["content"])
            if msg_tokens <= remaining_tokens:
                kept_reversed.append(msg)
                remaining_tokens -= msg_tokens
            else:
                # Message too long: keep a summary if it fits, else drop it.
                summary = self.summarize_message(msg["content"])
                summary_tokens = self.estimate_tokens(summary)
                if summary_tokens <= remaining_tokens:
                    kept_reversed.append({
                        "role": msg["role"],
                        "content": f"[摘要] {summary}"
                    })
                    remaining_tokens -= summary_tokens
        return system_msgs + kept_reversed[::-1]

    def summarize_message(self, content: str) -> str:
        """Naive summary: first and last 100 chars; short content passes through."""
        if len(content) <= 200:
            return content
        return content[:100] + "..." + content[-100:]
智能模型选择
class ModelRouter:
    """Pick a model tier matching the estimated complexity of a task."""

    # Task complexity on a 1-10 scale; unknown tasks default to 5.
    COMPLEXITY_SCORES = {
        "simple_qa": 1,
        "tool_use": 3,
        "code_generation": 5,
        "complex_reasoning": 8,
        "multi_agent_coordination": 9
    }

    # Tier -> model configuration (cost is USD per 1K tokens).
    MODEL_CONFIGS = {
        "fast": {
            "model": "claude-haiku-3.5",
            "max_tokens": 4096,
            "cost_per_1k": 0.00025
        },
        "balanced": {
            "model": "claude-sonnet-4-20250514",
            "max_tokens": 8192,
            "cost_per_1k": 0.003
        },
        "powerful": {
            "model": "claude-opus-4-20250514",
            "max_tokens": 16384,
            "cost_per_1k": 0.015
        }
    }

    def select_model(self, task_type: str, context: dict = None) -> dict:
        """Return the model config for task_type (context is currently unused)."""
        score = self.COMPLEXITY_SCORES.get(task_type, 5)
        if score > 6:
            tier = "powerful"
        elif score > 3:
            tier = "balanced"
        else:
            tier = "fast"
        return self.MODEL_CONFIGS[tier]
缓存策略
import hashlib
from typing import Optional
class SemanticCache:
    """Response cache keyed by an exact hash of (prompt, context).

    NOTE(review): despite the name, lookups are exact-match only --
    similarity_threshold is stored but never consulted; presumably reserved
    for a future embedding-based nearest-neighbour lookup -- confirm.
    """

    def __init__(self, redis_client, similarity_threshold=0.95):
        self.redis = redis_client
        self.threshold = similarity_threshold  # currently unused (see class note)

    def _compute_hash(self, prompt: str, context: str = "") -> str:
        """Derive a stable cache key from prompt and context."""
        digest = hashlib.sha256(f"{prompt}:{context}".encode())
        return digest.hexdigest()

    async def get(self, prompt: str, context: str = "") -> Optional[str]:
        """Return the cached response, or None on a miss."""
        entry = self.redis.get(f"agent:cache:{self._compute_hash(prompt, context)}")
        return entry.decode() if entry else None

    async def set(self, prompt: str, response: str,
                  context: str = "", ttl: int = 3600):
        """Store a response under the (prompt, context) key with a TTL."""
        key = f"agent:cache:{self._compute_hash(prompt, context)}"
        self.redis.setex(key, ttl, response)

    async def invalidate_pattern(self, pattern: str):
        """Delete every cache entry whose key contains *pattern*."""
        matches = self.redis.keys(f"agent:cache:*{pattern}*")
        if matches:
            self.redis.delete(*matches)
限流与配额管理
class RateLimiter:
    """Fixed-window rate limiter backed by a Redis counter.

    (The original docstring called this a "token bucket", but the
    implementation is a per-window request counter.)
    """

    def __init__(self, redis_client):
        self.redis = redis_client

    async def check_rate_limit(self, key: str, max_requests: int,
                               window_seconds: int) -> bool:
        """Return True if the request is allowed, False when the limit is hit.

        The TTL is now set only when the counter is created: the previous
        code re-issued EXPIRE on every request, so steady traffic kept
        pushing the window forward and the counter never reset.

        NOTE(review): check-then-increment is not atomic across clients; a
        Lua script or an INCR-first scheme would close the race.
        """
        counter_key = f"ratelimit:{key}"
        current = self.redis.get(counter_key)
        if current is not None and int(current) >= max_requests:
            return False
        count = self.redis.incr(counter_key)
        if count == 1:
            # First hit in this window: start the window timer.
            self.redis.expire(counter_key, window_seconds)
        return True
class QuotaManager:
    """Per-user usage quota enforcement backed by Redis counters."""

    def __init__(self, redis_client):
        self.redis = redis_client
        # A limit of -1 means unlimited for that dimension.
        self.quota_limits = {
            "free": {"daily_requests": 50, "monthly_tokens": 100000},
            "pro": {"daily_requests": 500, "monthly_tokens": 5000000},
            "enterprise": {"daily_requests": -1, "monthly_tokens": -1}
        }

    async def check_quota(self, user_id: str, tier: str) -> dict:
        """Report whether the user may proceed and what quota remains.

        Unknown tiers fall back to the "free" limits. Remaining values are
        -1 when the corresponding limit is unlimited.
        """
        limits = self.quota_limits.get(tier, self.quota_limits["free"])
        now = datetime.utcnow()
        daily_key = f"quota:{user_id}:daily:{now.strftime('%Y%m%d')}"
        monthly_key = f"quota:{user_id}:monthly:{now.strftime('%Y%m')}"
        used_requests = int(self.redis.get(daily_key) or 0)
        used_tokens = int(self.redis.get(monthly_key) or 0)

        daily_cap = limits["daily_requests"]
        monthly_cap = limits["monthly_tokens"]
        daily_ok = daily_cap == -1 or used_requests < daily_cap
        monthly_ok = monthly_cap == -1 or used_tokens < monthly_cap
        return {
            "allowed": daily_ok and monthly_ok,
            "daily_remaining": daily_cap - used_requests if daily_cap > 0 else -1,
            "monthly_tokens_remaining": monthly_cap - used_tokens if monthly_cap > 0 else -1
        }
测试策略
单元测试
import pytest
from unittest.mock import AsyncMock, patch
class TestAgentToolUse:
    """Unit tests for tool selection, error recovery and multi-step reasoning.

    NOTE(review): Agent, MockSearchTool and MockCalculatorTool are assumed
    to be importable in this test module -- confirm the fixtures exist.
    """

    @pytest.fixture
    def agent(self):
        # Fresh agent wired with mock tools for each test.
        return Agent(tools=[MockSearchTool(), MockCalculatorTool()])

    @pytest.mark.asyncio
    async def test_tool_selection(self, agent):
        """A search-style query should route to the search tool first."""
        response = await agent.run("搜索关于 Python 的信息")
        assert "search" in response.tool_calls[0].name

    @pytest.mark.asyncio
    async def test_error_recovery(self, agent):
        """A failing tool should trigger the fallback path, not raise."""
        with patch.object(MockSearchTool, 'run', side_effect=Exception("API Error")):
            response = await agent.run("搜索信息")
            assert response.has_fallback

    @pytest.mark.asyncio
    async def test_multi_step_reasoning(self, agent):
        """A compound request should produce at least two tool calls."""
        response = await agent.run("计算 2^10 并搜索其数学性质")
        assert len(response.tool_calls) >= 2
集成测试
@pytest.mark.integration
class TestAgentIntegration:
    """End-to-end conversation tests against a real ProductionAgent.

    NOTE(review): relies on uuid and ProductionAgent being imported by this
    module -- uuid is not imported in the visible snippet; confirm.
    """

    @pytest.mark.asyncio
    async def test_full_conversation_flow(self):
        """Memory check: a fact stated in turn 1 must be recalled in turn 2."""
        session_id = str(uuid.uuid4())
        agent = ProductionAgent(session_id=session_id)
        # Turn 1: introduce a name.
        response1 = await agent.chat("我的名字是张三")
        assert "张三" in response1 or "记住" in response1
        # Turn 2: the agent should remember it.
        response2 = await agent.chat("我叫什么名字?")
        assert "张三" in response2
性能测试
@pytest.mark.performance
class TestAgentPerformance:
    """Load test: the agent should sustain concurrent sessions with <5% errors."""

    @pytest.mark.asyncio
    async def test_concurrent_sessions(self):
        """Run 100 sessions concurrently and count the failures."""
        async def run_session(session_id):
            # One independent agent per session.
            agent = ProductionAgent(session_id=session_id)
            return await agent.chat("测试消息")
        # Launch 100 concurrent sessions; collect exceptions instead of raising.
        tasks = [run_session(f"session-{i}") for i in range(100)]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        errors = [r for r in results if isinstance(r, Exception)]
        assert len(errors) < 5  # error rate below 5%
部署架构
Kubernetes 部署
# agent-deployment.yaml
# Deployment + HPA for the AI agent service.
# (Indentation was lost in the original extraction, making the manifest
# invalid YAML; the structure is restored here.)
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ai-agent
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ai-agent
  template:
    metadata:
      labels:
        app: ai-agent
    spec:
      containers:
        - name: agent
          # NOTE(review): pin a versioned tag instead of :latest for
          # reproducible rollouts.
          image: ai-agent:latest
          resources:
            requests:
              memory: "512Mi"
              cpu: "500m"
            limits:
              memory: "2Gi"
              cpu: "2"
          env:
            - name: ANTHROPIC_API_KEY
              valueFrom:
                secretKeyRef:
                  name: llm-secrets
                  key: anthropic-key
            - name: REDIS_URL
              value: "redis://redis-master:6379"
          livenessProbe:
            httpGet:
              path: /health
              port: 8080
            initialDelaySeconds: 30
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /ready
              port: 8080
            initialDelaySeconds: 5
            periodSeconds: 5
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ai-agent-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ai-agent
  minReplicas: 3
  maxReplicas: 20
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
总结
生产级 Agent 架构的核心要素:分层的错误处理与熔断降级、基于 Redis 的状态管理与检查点机制、基于 OpenTelemetry 的全链路可观测性、Token 优化与智能模型选择带来的成本控制、限流与配额管理,以及覆盖单元、集成、性能的测试策略和可自动伸缩的 Kubernetes 部署。
将这些模式组合使用,可以构建出可靠、高效、可扩展的生产级 AI Agent 系统。