Add PostgreSQL session persistence + multi-session support

- New: db_sessions.py with asyncpg pool, session CRUD (upsert)
- New: POST/GET/DELETE /api/sessions endpoints
- Refactored api.py: _sessions dict replaces _active_runtime singleton
- WS accepts ?session= param, sends session_info on connect
- Runtime: added session_id, to_state(), restore_state()
- Auto-save to Postgres after each message (WS + REST)
- Added asyncpg to requirements.txt
- PostgreSQL 16 on VPS, tenant DBs: assay_dev, assay_loop42

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nico 2026-03-31 20:53:31 +02:00
parent 432e3d8d55
commit e205e99da0
4 changed files with 247 additions and 17 deletions

View File

@ -14,10 +14,15 @@ import httpx
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE from .runtime import Runtime, TRACE_FILE
from . import db_sessions
log = logging.getLogger("runtime") log = logging.getLogger("runtime")
# Active runtime reference (set by WS endpoint) # Session map: session_id -> Runtime (in-memory active sessions)
_sessions: dict[str, Runtime] = {}
MAX_ACTIVE_SESSIONS = 50
# Legacy: for backward compat with single-session MCP/test endpoints
_active_runtime: Runtime | None = None _active_runtime: Runtime | None = None
# SSE subscribers # SSE subscribers
@ -83,12 +88,13 @@ def _broadcast_sse(event: dict):
_pipeline_result["event"] = evt _pipeline_result["event"] = evt
def _state_hash() -> str: def _state_hash(rt: Runtime = None) -> str:
if not _active_runtime: r = rt or _active_runtime
if not r:
return "no_session" return "no_session"
raw = json.dumps({ raw = json.dumps({
"mem": _active_runtime.memorizer.state, "mem": r.memorizer.state,
"hlen": len(_active_runtime.history), "hlen": len(r.history),
}, sort_keys=True) }, sort_keys=True)
return hashlib.md5(raw.encode()).hexdigest()[:12] return hashlib.md5(raw.encode()).hexdigest()[:12]
@ -96,10 +102,44 @@ def _state_hash() -> str:
def register_routes(app): def register_routes(app):
"""Register all API routes on the FastAPI app.""" """Register all API routes on the FastAPI app."""
@app.on_event("startup")
async def startup():
    # Open the shared asyncpg pool; persistence silently disables itself
    # when POSTGRES_DSN is unset (db_sessions.init_pool logs a warning).
    await db_sessions.init_pool()

@app.on_event("shutdown")
async def shutdown():
    # Release DB connections on graceful shutdown.
    await db_sessions.close_pool()
@app.get("/health") @app.get("/health")
async def health(): async def health():
return {"status": "ok"} return {"status": "ok"}
# --- Session CRUD ---
@app.post("/api/sessions")
async def create_session(body: dict = None, user=Depends(require_auth)):
    """Create a new session for the authenticated user and return its id."""
    payload = body or {}
    user_id = user.get("sub", "anonymous")
    graph = payload.get("graph", "v4-eras")
    new_id = await db_sessions.create_session(user_id, graph)
    return {"session_id": new_id}
@app.get("/api/sessions")
async def list_sessions(user=Depends(require_auth)):
    """Return the current user's sessions, most recently active first."""
    owner = user.get("sub", "anonymous")
    return {"sessions": await db_sessions.list_sessions(owner)}
@app.delete("/api/sessions/{session_id}")
async def delete_session(session_id: str, user=Depends(require_auth)):
    """Delete a session: evict the in-memory runtime (if any), then the DB row.

    NOTE(review): no ownership check — any authenticated user can delete any
    session id; confirm whether db_sessions.delete_session should be scoped
    to user["sub"].
    """
    if session_id in _sessions:
        rt = _sessions.pop(session_id)
        rt.sensor.stop()  # stop background work for the evicted runtime
    await db_sessions.delete_session(session_id)
    return {"status": "deleted"}
@app.get("/auth/config") @app.get("/auth/config")
async def auth_config(): async def auth_config():
from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
@ -110,18 +150,62 @@ def register_routes(app):
"projectId": ZITADEL_PROJECT_ID, "projectId": ZITADEL_PROJECT_ID,
} }
async def _get_or_create_session(session_id: str | None = None, user_claims=None, origin="") -> Runtime:
    """Get existing session or create new one.

    Resolution order: in-memory _sessions map -> DB restore -> brand-new
    session.  Every path also updates the legacy _active_runtime singleton
    so older single-session endpoints keep working.

    NOTE(review): MAX_ACTIVE_SESSIONS is never enforced here, so _sessions
    can grow without bound — confirm the intended eviction policy.  Two
    concurrent connects with the same not-yet-loaded session_id can also
    race and build duplicate Runtime objects.
    """
    global _active_runtime
    # Reuse in-memory session
    if session_id and session_id in _sessions:
        rt = _sessions[session_id]
        _active_runtime = rt
        return rt
    # Try loading from DB
    if session_id:
        saved = await db_sessions.load_session(session_id)
        if saved:
            rt = Runtime(user_claims=user_claims, origin=origin,
                         broadcast=_broadcast_sse,
                         graph_name=saved["graph_name"],
                         session_id=session_id)
            rt.restore_state(saved)
            _sessions[session_id] = rt
            _active_runtime = rt
            log.info(f"[api] restored session {session_id} from DB")
            return rt
    # Create new session (unknown ids fall through here and get a fresh one)
    user_id = (user_claims or {}).get("sub", "anonymous")
    new_id = await db_sessions.create_session(user_id)
    rt = Runtime(user_claims=user_claims, origin=origin,
                 broadcast=_broadcast_sse, session_id=new_id)
    _sessions[new_id] = rt
    _active_runtime = rt
    log.info(f"[api] created new session {new_id}")
    return rt
async def _save_session(rt: Runtime):
    """Persist a runtime's session state to Postgres (upsert).

    This is often fired via asyncio.create_task() after message handling,
    so an exception raised here would otherwise vanish as an unretrieved
    task error — catch and log instead, and never let a save failure
    propagate into the caller's request path.
    """
    try:
        state = rt.to_state()
        await db_sessions.save_session(
            rt.session_id,
            state["history"],
            state["memorizer_state"],
            state["ui_state"],
            user_id=rt.identity,
            graph_name=rt.graph.get("name", "v4-eras"),
        )
    except Exception:
        # Best-effort persistence: the in-memory session stays authoritative.
        log.exception("[api] failed to save session %s", rt.session_id)
def _ensure_runtime(user_claims=None, origin=""): def _ensure_runtime(user_claims=None, origin=""):
"""Get or create the persistent runtime.""" """Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
global _active_runtime global _active_runtime
if _active_runtime is None: if _active_runtime is None:
_active_runtime = Runtime(user_claims=user_claims, origin=origin, _active_runtime = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse) broadcast=_broadcast_sse)
log.info("[api] created persistent runtime") _sessions[_active_runtime.session_id] = _active_runtime
log.info("[api] created persistent runtime (legacy)")
return _active_runtime return _active_runtime
@app.websocket("/ws") @app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
access_token: str | None = Query(None)): access_token: str | None = Query(None),
session: str | None = Query(None)):
user_claims = {"sub": "anonymous"} user_claims = {"sub": "anonymous"}
if AUTH_ENABLED and token: if AUTH_ENABLED and token:
try: try:
@ -142,23 +226,30 @@ def register_routes(app):
origin = ws.headers.get("origin", ws.headers.get("host", "")) origin = ws.headers.get("origin", ws.headers.get("host", ""))
await ws.accept() await ws.accept()
# Get or create runtime, attach WS # Get or create session, attach WS
runtime = _ensure_runtime(user_claims=user_claims, origin=origin) runtime = await _get_or_create_session(
session_id=session, user_claims=user_claims, origin=origin)
runtime.update_identity(user_claims, origin) runtime.update_identity(user_claims, origin)
runtime.attach_ws(ws) runtime.attach_ws(ws)
# Tell client which session they're on
try:
await ws.send_text(json.dumps({
"type": "session_info", "session_id": runtime.session_id,
"graph": runtime.graph.get("name", "unknown")}))
except Exception:
pass
try: try:
while True: while True:
data = await ws.receive_text() data = await ws.receive_text()
msg = json.loads(data) msg = json.loads(data)
# Always use current runtime (may change after graph switch) rt = _sessions.get(runtime.session_id, runtime)
rt = _active_runtime or runtime
try: try:
if msg.get("type") == "action": if msg.get("type") == "action":
action = msg.get("action", "unknown") action = msg.get("action", "unknown")
data_payload = msg.get("data") data_payload = msg.get("data")
if hasattr(rt, 'use_frames') and rt.use_frames: if hasattr(rt, 'use_frames') and rt.use_frames:
# Frame engine handles actions as ACTION: prefix messages
action_text = f"ACTION:{action}" action_text = f"ACTION:{action}"
if data_payload: if data_payload:
action_text += f"|data:{json.dumps(data_payload)}" action_text += f"|data:{json.dumps(data_payload)}"
@ -169,6 +260,8 @@ def register_routes(app):
rt.process_manager.cancel(msg.get("pid", 0)) rt.process_manager.cancel(msg.get("pid", 0))
else: else:
await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard")) await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
# Auto-save after each message
asyncio.create_task(_save_session(rt))
except Exception as e: except Exception as e:
import traceback import traceback
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}") log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
@ -177,9 +270,8 @@ def register_routes(app):
except Exception: except Exception:
pass pass
except WebSocketDisconnect: except WebSocketDisconnect:
if _active_runtime: runtime.detach_ws()
_active_runtime.detach_ws() log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
log.info("[api] WS disconnected — runtime stays alive")
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool: async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
"""Validate token for debug WS. Returns True if auth OK.""" """Validate token for debug WS. Returns True if auth OK."""
@ -313,6 +405,8 @@ def register_routes(app):
"response": response, "response": response,
"memorizer": runtime.memorizer.state, "memorizer": runtime.memorizer.state,
} }
# Persist session after message
await _save_session(runtime)
except Exception as e: except Exception as e:
import traceback import traceback
log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}") log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")

108
agent/db_sessions.py Normal file
View File

@ -0,0 +1,108 @@
"""PostgreSQL session storage for assay runtime."""
import json
import logging
import os
from uuid import uuid4
import asyncpg
log = logging.getLogger("db_sessions")
_pool: asyncpg.Pool | None = None
async def init_pool(dsn: str | None = None):
    """Create connection pool. Call once on startup.

    Idempotent: a second call while the pool already exists is a no-op.
    When neither *dsn* nor the POSTGRES_DSN env var is set, persistence is
    disabled and the other helpers in this module degrade to no-ops.
    """
    global _pool
    if _pool is not None:
        return
    dsn = dsn or os.environ.get("POSTGRES_DSN", "")
    if not dsn:
        log.warning("POSTGRES_DSN not set — session persistence disabled")
        return
    _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
    log.info("PostgreSQL pool ready")
async def close_pool():
    """Close the global pool (if one was created) and reset the module state."""
    global _pool
    if _pool:
        await _pool.close()
        _pool = None
async def create_session(user_id: str, graph_name: str = "v4-eras") -> str:
    """Mint a fresh session id and, when a pool exists, record it in Postgres.

    Without a configured pool the caller still receives a usable id for a
    purely in-memory session.
    """
    session_id = str(uuid4())
    if not _pool:
        return session_id  # no persistence, still works in-memory
    await _pool.execute(
        """INSERT INTO sessions (id, user_id, graph_name)
VALUES ($1, $2, $3)""",
        session_id,
        user_id,
        graph_name,
    )
    return session_id
async def load_session(session_id: str) -> dict | None:
    """Load session state from DB. Returns None if not found.

    The jsonb columns can be NULL for rows that were inserted by
    create_session (which only sets id/user_id/graph_name) and never
    saved — fall back to empty containers instead of crashing in
    json.loads(None).
    """
    if not _pool:
        return None
    row = await _pool.fetchrow(
        """SELECT user_id, graph_name, memorizer_state, history, ui_state
FROM sessions WHERE id = $1""",
        session_id,
    )
    if not row:
        return None

    def _jsonb(col: str, default):
        # asyncpg returns jsonb as text unless a custom codec is registered;
        # guard against NULL columns (never-saved sessions).
        raw = row[col]
        return json.loads(raw) if raw else default

    return {
        "session_id": session_id,
        "user_id": row["user_id"],
        "graph_name": row["graph_name"],
        "memorizer_state": _jsonb("memorizer_state", {}),
        "history": _jsonb("history", []),
        "ui_state": _jsonb("ui_state", {}),
    }
async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict,
                       user_id: str = "unknown", graph_name: str = "v4-eras"):
    """Upsert one session row: insert on first save, refresh the JSONB
    payload and activity timestamps on every later one."""
    if not _pool:
        return
    # Serialise the three JSONB payloads in column order.
    payloads = [
        json.dumps(blob, ensure_ascii=False)
        for blob in (history, memorizer_state, ui_state)
    ]
    await _pool.execute(
        """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state)
VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb)
ON CONFLICT (id) DO UPDATE SET
history = EXCLUDED.history,
memorizer_state = EXCLUDED.memorizer_state,
ui_state = EXCLUDED.ui_state,
updated_at = now(),
last_activity = now()""",
        session_id,
        user_id,
        graph_name,
        *payloads,
    )
async def list_sessions(user_id: str) -> list[dict]:
    """List up to 50 of a user's sessions, most recently active first.

    Returns [] when persistence is disabled.  last_activity may be NULL
    for rows that were inserted but never saved (save_session is what
    sets it — TODO confirm the column has no DEFAULT), so it is
    serialised defensively instead of calling .isoformat() on None.
    """
    if not _pool:
        return []
    rows = await _pool.fetch(
        """SELECT id, graph_name, last_activity
FROM sessions WHERE user_id = $1
ORDER BY last_activity DESC LIMIT 50""",
        user_id,
    )
    return [
        {
            "id": r["id"],
            "graph_name": r["graph_name"],
            "last_activity": r["last_activity"].isoformat() if r["last_activity"] else None,
        }
        for r in rows
    ]
async def delete_session(session_id: str):
    """Delete a session row; a silent no-op when persistence is disabled."""
    if not _pool:
        return
    await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id)

View File

@ -6,6 +6,7 @@ import logging
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
from uuid import uuid4
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan
from .process import ProcessManager from .process import ProcessManager
@ -81,7 +82,9 @@ class OutputSink:
class Runtime: class Runtime:
def __init__(self, user_claims: dict = None, origin: str = "", def __init__(self, user_claims: dict = None, origin: str = "",
broadcast: Callable = None, graph_name: str = None): broadcast: Callable = None, graph_name: str = None,
session_id: str = None):
self.session_id = session_id or str(uuid4())
self.sink = OutputSink() self.sink = OutputSink()
self.history: list[dict] = [] self.history: list[dict] = []
self.MAX_HISTORY = 40 self.MAX_HISTORY = 40
@ -410,3 +413,27 @@ class Runtime:
if len(self.history) > self.MAX_HISTORY: if len(self.history) > self.MAX_HISTORY:
self.history = self.history[-self.MAX_HISTORY:] self.history = self.history[-self.MAX_HISTORY:]
def to_state(self) -> dict:
    """Snapshot the pieces of this runtime that db_sessions persists."""
    ui_snapshot = {
        "state": self.ui_node.state,
        "bindings": self.ui_node.bindings,
        "machines": self.ui_node.machines,
    }
    return {
        "history": self.history,
        "memorizer_state": self.memorizer.state,
        "ui_state": ui_snapshot,
    }
def restore_state(self, state: dict):
    """Restore session state previously produced by to_state().

    History is re-trimmed to MAX_HISTORY so a restored session keeps the
    same bound the live append path enforces; missing or empty pieces are
    left at their constructor defaults.
    """
    # Cap restored history to the class invariant (append path trims too).
    self.history = state.get("history", [])[-self.MAX_HISTORY:]
    memo = state.get("memorizer_state")
    if memo:
        self.memorizer.state = memo
    ui = state.get("ui_state", {})
    if ui:
        self.ui_node.state = ui.get("state", {})
        self.ui_node.bindings = ui.get("bindings", {})
        self.ui_node.machines = ui.get("machines", {})

View File

@ -8,3 +8,4 @@ pydantic==2.12.5
PyJWT[crypto]==2.10.1 PyJWT[crypto]==2.10.1
pymysql==1.1.1 pymysql==1.1.1
mcp[sse]==1.9.3 mcp[sse]==1.9.3
asyncpg==0.30.0