Add PostgreSQL session persistence + multi-session support
- New: db_sessions.py with asyncpg pool, session CRUD (upsert) - New: POST/GET/DELETE /api/sessions endpoints - Refactored api.py: _sessions dict replaces _active_runtime singleton - WS accepts ?session= param, sends session_info on connect - Runtime: added session_id, to_state(), restore_state() - Auto-save to Postgres after each message (WS + REST) - Added asyncpg to requirements.txt - PostgreSQL 16 on VPS, tenant DBs: assay_dev, assay_loop42 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
432e3d8d55
commit
e205e99da0
126
agent/api.py
126
agent/api.py
@ -14,10 +14,15 @@ import httpx
|
||||
|
||||
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
||||
from .runtime import Runtime, TRACE_FILE
|
||||
from . import db_sessions
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
|
||||
# Active runtime reference (set by WS endpoint)
|
||||
# Session map: session_id -> Runtime (in-memory active sessions)
|
||||
_sessions: dict[str, Runtime] = {}
|
||||
MAX_ACTIVE_SESSIONS = 50
|
||||
|
||||
# Legacy: for backward compat with single-session MCP/test endpoints
|
||||
_active_runtime: Runtime | None = None
|
||||
|
||||
# SSE subscribers
|
||||
@ -83,12 +88,13 @@ def _broadcast_sse(event: dict):
|
||||
_pipeline_result["event"] = evt
|
||||
|
||||
|
||||
def _state_hash() -> str:
|
||||
if not _active_runtime:
|
||||
def _state_hash(rt: Runtime = None) -> str:
|
||||
r = rt or _active_runtime
|
||||
if not r:
|
||||
return "no_session"
|
||||
raw = json.dumps({
|
||||
"mem": _active_runtime.memorizer.state,
|
||||
"hlen": len(_active_runtime.history),
|
||||
"mem": r.memorizer.state,
|
||||
"hlen": len(r.history),
|
||||
}, sort_keys=True)
|
||||
return hashlib.md5(raw.encode()).hexdigest()[:12]
|
||||
|
||||
@ -96,10 +102,44 @@ def _state_hash() -> str:
|
||||
def register_routes(app):
|
||||
"""Register all API routes on the FastAPI app."""
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
await db_sessions.init_pool()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
await db_sessions.close_pool()
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# --- Session CRUD ---
|
||||
|
||||
@app.post("/api/sessions")
|
||||
async def create_session(body: dict = None, user=Depends(require_auth)):
|
||||
"""Create a new session."""
|
||||
user_id = user.get("sub", "anonymous")
|
||||
graph = (body or {}).get("graph", "v4-eras")
|
||||
session_id = await db_sessions.create_session(user_id, graph)
|
||||
return {"session_id": session_id}
|
||||
|
||||
@app.get("/api/sessions")
|
||||
async def list_sessions(user=Depends(require_auth)):
|
||||
"""List sessions for current user."""
|
||||
user_id = user.get("sub", "anonymous")
|
||||
sessions = await db_sessions.list_sessions(user_id)
|
||||
return {"sessions": sessions}
|
||||
|
||||
@app.delete("/api/sessions/{session_id}")
|
||||
async def delete_session(session_id: str, user=Depends(require_auth)):
|
||||
"""Delete a session."""
|
||||
if session_id in _sessions:
|
||||
rt = _sessions.pop(session_id)
|
||||
rt.sensor.stop()
|
||||
await db_sessions.delete_session(session_id)
|
||||
return {"status": "deleted"}
|
||||
|
||||
@app.get("/auth/config")
|
||||
async def auth_config():
|
||||
from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
|
||||
@ -110,18 +150,62 @@ def register_routes(app):
|
||||
"projectId": ZITADEL_PROJECT_ID,
|
||||
}
|
||||
|
||||
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
|
||||
"""Get existing session or create new one."""
|
||||
global _active_runtime
|
||||
|
||||
# Reuse in-memory session
|
||||
if session_id and session_id in _sessions:
|
||||
rt = _sessions[session_id]
|
||||
_active_runtime = rt
|
||||
return rt
|
||||
|
||||
# Try loading from DB
|
||||
if session_id:
|
||||
saved = await db_sessions.load_session(session_id)
|
||||
if saved:
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse,
|
||||
graph_name=saved["graph_name"],
|
||||
session_id=session_id)
|
||||
rt.restore_state(saved)
|
||||
_sessions[session_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] restored session {session_id} from DB")
|
||||
return rt
|
||||
|
||||
# Create new session
|
||||
user_id = (user_claims or {}).get("sub", "anonymous")
|
||||
new_id = await db_sessions.create_session(user_id)
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse, session_id=new_id)
|
||||
_sessions[new_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] created new session {new_id}")
|
||||
return rt
|
||||
|
||||
async def _save_session(rt: Runtime):
|
||||
"""Persist session state to DB (upsert)."""
|
||||
state = rt.to_state()
|
||||
await db_sessions.save_session(
|
||||
rt.session_id, state["history"],
|
||||
state["memorizer_state"], state["ui_state"],
|
||||
user_id=rt.identity, graph_name=rt.graph.get("name", "v4-eras"))
|
||||
|
||||
def _ensure_runtime(user_claims=None, origin=""):
|
||||
"""Get or create the persistent runtime."""
|
||||
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
||||
global _active_runtime
|
||||
if _active_runtime is None:
|
||||
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse)
|
||||
log.info("[api] created persistent runtime")
|
||||
_sessions[_active_runtime.session_id] = _active_runtime
|
||||
log.info("[api] created persistent runtime (legacy)")
|
||||
return _active_runtime
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
|
||||
access_token: str | None = Query(None)):
|
||||
access_token: str | None = Query(None),
|
||||
session: str | None = Query(None)):
|
||||
user_claims = {"sub": "anonymous"}
|
||||
if AUTH_ENABLED and token:
|
||||
try:
|
||||
@ -142,23 +226,30 @@ def register_routes(app):
|
||||
origin = ws.headers.get("origin", ws.headers.get("host", ""))
|
||||
await ws.accept()
|
||||
|
||||
# Get or create runtime, attach WS
|
||||
runtime = _ensure_runtime(user_claims=user_claims, origin=origin)
|
||||
# Get or create session, attach WS
|
||||
runtime = await _get_or_create_session(
|
||||
session_id=session, user_claims=user_claims, origin=origin)
|
||||
runtime.update_identity(user_claims, origin)
|
||||
runtime.attach_ws(ws)
|
||||
|
||||
# Tell client which session they're on
|
||||
try:
|
||||
await ws.send_text(json.dumps({
|
||||
"type": "session_info", "session_id": runtime.session_id,
|
||||
"graph": runtime.graph.get("name", "unknown")}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await ws.receive_text()
|
||||
msg = json.loads(data)
|
||||
# Always use current runtime (may change after graph switch)
|
||||
rt = _active_runtime or runtime
|
||||
rt = _sessions.get(runtime.session_id, runtime)
|
||||
try:
|
||||
if msg.get("type") == "action":
|
||||
action = msg.get("action", "unknown")
|
||||
data_payload = msg.get("data")
|
||||
if hasattr(rt, 'use_frames') and rt.use_frames:
|
||||
# Frame engine handles actions as ACTION: prefix messages
|
||||
action_text = f"ACTION:{action}"
|
||||
if data_payload:
|
||||
action_text += f"|data:{json.dumps(data_payload)}"
|
||||
@ -169,6 +260,8 @@ def register_routes(app):
|
||||
rt.process_manager.cancel(msg.get("pid", 0))
|
||||
else:
|
||||
await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
|
||||
# Auto-save after each message
|
||||
asyncio.create_task(_save_session(rt))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
|
||||
@ -177,9 +270,8 @@ def register_routes(app):
|
||||
except Exception:
|
||||
pass
|
||||
except WebSocketDisconnect:
|
||||
if _active_runtime:
|
||||
_active_runtime.detach_ws()
|
||||
log.info("[api] WS disconnected — runtime stays alive")
|
||||
runtime.detach_ws()
|
||||
log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
|
||||
|
||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||
"""Validate token for debug WS. Returns True if auth OK."""
|
||||
@ -313,6 +405,8 @@ def register_routes(app):
|
||||
"response": response,
|
||||
"memorizer": runtime.memorizer.state,
|
||||
}
|
||||
# Persist session after message
|
||||
await _save_session(runtime)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")
|
||||
|
||||
108
agent/db_sessions.py
Normal file
108
agent/db_sessions.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""PostgreSQL session storage for assay runtime."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncpg
|
||||
|
||||
log = logging.getLogger("db_sessions")
|
||||
|
||||
_pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def init_pool(dsn: str | None = None):
    """Create the shared asyncpg connection pool. Call once on startup.

    Idempotent: a second call is a no-op. When neither *dsn* nor the
    POSTGRES_DSN environment variable is set, persistence is disabled and
    the module runs with ``_pool`` left as None (all CRUD helpers no-op).

    Args:
        dsn: PostgreSQL connection string; defaults to ``POSTGRES_DSN``.
    """
    global _pool
    if _pool is not None:
        return  # already initialised
    dsn = dsn or os.environ.get("POSTGRES_DSN", "")
    if not dsn:
        log.warning("POSTGRES_DSN not set — session persistence disabled")
        return
    _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
    log.info("PostgreSQL pool ready")
|
||||
|
||||
|
||||
async def close_pool():
    """Dispose of the connection pool, if one was ever created."""
    global _pool
    if _pool is None:
        return
    await _pool.close()
    _pool = None
|
||||
|
||||
|
||||
async def create_session(user_id: str, graph_name: str = "v4-eras") -> str:
    """Mint a fresh session id and record it in the DB when persistence is on.

    The id is always returned, even without a pool — the session then lives
    purely in memory.
    """
    new_id = str(uuid4())
    if _pool:
        await _pool.execute(
            """INSERT INTO sessions (id, user_id, graph_name)
            VALUES ($1, $2, $3)""",
            new_id, user_id, graph_name,
        )
    return new_id
|
||||
|
||||
|
||||
async def load_session(session_id: str) -> dict | None:
    """Load session state from DB.

    Returns None when persistence is disabled or the row does not exist.

    Bug fix: jsonb columns may be NULL for sessions that were created but
    never saved (create_session only inserts id/user_id/graph_name), and
    ``json.loads(None)`` raises TypeError — substitute empty defaults.
    """
    if not _pool:
        return None
    row = await _pool.fetchrow(
        """SELECT user_id, graph_name, memorizer_state, history, ui_state
        FROM sessions WHERE id = $1""",
        session_id,
    )
    if not row:
        return None

    def _jsonb(value, default):
        # asyncpg returns jsonb as str by default; NULL comes back as None.
        return json.loads(value) if value is not None else default

    return {
        "session_id": session_id,
        "user_id": row["user_id"],
        "graph_name": row["graph_name"],
        "memorizer_state": _jsonb(row["memorizer_state"], {}),
        "history": _jsonb(row["history"], []),
        "ui_state": _jsonb(row["ui_state"], {}),
    }
|
||||
|
||||
|
||||
async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict,
                       user_id: str = "unknown", graph_name: str = "v4-eras"):
    """Persist session state to DB (upsert keyed on session id).

    No-op when the pool was never initialised (persistence disabled).
    """
    if not _pool:
        return
    # Serialize the three jsonb payloads up front, in column order.
    payloads = [json.dumps(obj, ensure_ascii=False)
                for obj in (history, memorizer_state, ui_state)]
    await _pool.execute(
        """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state)
        VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb)
        ON CONFLICT (id) DO UPDATE SET
        history = EXCLUDED.history,
        memorizer_state = EXCLUDED.memorizer_state,
        ui_state = EXCLUDED.ui_state,
        updated_at = now(),
        last_activity = now()""",
        session_id, user_id, graph_name, *payloads,
    )
|
||||
|
||||
|
||||
async def list_sessions(user_id: str) -> list[dict]:
    """Return up to 50 of the user's sessions, most recently active first."""
    if not _pool:
        return []
    rows = await _pool.fetch(
        """SELECT id, graph_name, last_activity
        FROM sessions WHERE user_id = $1
        ORDER BY last_activity DESC LIMIT 50""",
        user_id,
    )

    def _as_dict(row):
        # last_activity is a timestamp; serialize for JSON transport.
        return {
            "id": row["id"],
            "graph_name": row["graph_name"],
            "last_activity": row["last_activity"].isoformat(),
        }

    return [_as_dict(row) for row in rows]
|
||||
|
||||
|
||||
async def delete_session(session_id: str):
    """Remove a session row; silently does nothing when persistence is off."""
    if _pool:
        await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id)
|
||||
@ -6,6 +6,7 @@ import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from uuid import uuid4
|
||||
|
||||
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan
|
||||
from .process import ProcessManager
|
||||
@ -81,7 +82,9 @@ class OutputSink:
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, user_claims: dict = None, origin: str = "",
|
||||
broadcast: Callable = None, graph_name: str = None):
|
||||
broadcast: Callable = None, graph_name: str = None,
|
||||
session_id: str = None):
|
||||
self.session_id = session_id or str(uuid4())
|
||||
self.sink = OutputSink()
|
||||
self.history: list[dict] = []
|
||||
self.MAX_HISTORY = 40
|
||||
@ -410,3 +413,27 @@ class Runtime:
|
||||
|
||||
if len(self.history) > self.MAX_HISTORY:
|
||||
self.history = self.history[-self.MAX_HISTORY:]
|
||||
|
||||
def to_state(self) -> dict:
    """Snapshot the pieces of session state that get persisted to the DB."""
    ui_snapshot = {
        "state": self.ui_node.state,
        "bindings": self.ui_node.bindings,
        "machines": self.ui_node.machines,
    }
    return {
        "history": self.history,
        "memorizer_state": self.memorizer.state,
        "ui_state": ui_snapshot,
    }
|
||||
|
||||
def restore_state(self, state: dict):
    """Re-hydrate session state previously produced by to_state().

    Missing/empty memorizer or UI sections leave the current values in place;
    history always gets overwritten (empty list when absent).
    """
    self.history = state.get("history", [])
    memorizer_state = state.get("memorizer_state")
    if memorizer_state:
        self.memorizer.state = memorizer_state
    ui_state = state.get("ui_state", {})
    if ui_state:
        for attr in ("state", "bindings", "machines"):
            setattr(self.ui_node, attr, ui_state.get(attr, {}))
|
||||
|
||||
@ -8,3 +8,4 @@ pydantic==2.12.5
|
||||
PyJWT[crypto]==2.10.1
|
||||
pymysql==1.1.1
|
||||
mcp[sse]==1.9.3
|
||||
asyncpg==0.30.0
|
||||
|
||||
Reference in New Issue
Block a user