Add PostgreSQL session persistence + multi-session support

- New: db_sessions.py with asyncpg pool, session CRUD (upsert)
- New: POST/GET/DELETE /api/sessions endpoints
- Refactored api.py: _sessions dict replaces _active_runtime singleton
- WS accepts ?session= param, sends session_info on connect
- Runtime: added session_id, to_state(), restore_state()
- Auto-save to Postgres after each message (WS + REST)
- Added asyncpg to requirements.txt
- PostgreSQL 16 on VPS, tenant DBs: assay_dev, assay_loop42

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nico 2026-03-31 20:53:31 +02:00
parent 432e3d8d55
commit e205e99da0
4 changed files with 247 additions and 17 deletions

View File

@ -14,10 +14,15 @@ import httpx
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE
from . import db_sessions
log = logging.getLogger("runtime")
# Active runtime reference (set by WS endpoint)
# Session map: session_id -> Runtime (in-memory active sessions)
_sessions: dict[str, Runtime] = {}
MAX_ACTIVE_SESSIONS = 50
# Legacy: for backward compat with single-session MCP/test endpoints
_active_runtime: Runtime | None = None
# SSE subscribers
@ -83,12 +88,13 @@ def _broadcast_sse(event: dict):
_pipeline_result["event"] = evt
def _state_hash(rt: Runtime = None) -> str:
    """Return a short fingerprint of a runtime's memorizer state and history length.

    Args:
        rt: Runtime to fingerprint. When omitted (or falsy), the legacy
            module-level ``_active_runtime`` is used instead.

    Returns:
        ``"no_session"`` when no runtime is available; otherwise the first
        12 hex characters of the MD5 digest of the JSON-encoded snapshot.
    """
    r = rt or _active_runtime
    if not r:
        return "no_session"
    # Read exclusively from the resolved runtime `r`. Touching
    # `_active_runtime` here would fail whenever an explicit `rt` is passed
    # while no legacy singleton exists.
    raw = json.dumps({
        "mem": r.memorizer.state,
        "hlen": len(r.history),
    }, sort_keys=True)
    # MD5 serves only as a cheap change-detection fingerprint here.
    return hashlib.md5(raw.encode()).hexdigest()[:12]
@ -96,10 +102,44 @@ def _state_hash() -> str:
def register_routes(app):
"""Register all API routes on the FastAPI app."""
@app.on_event("startup")
async def startup():
    """Startup hook: initialise the session-store connection pool."""
    await db_sessions.init_pool()
@app.on_event("shutdown")
async def shutdown():
    """Shutdown hook: close the session-store connection pool."""
    await db_sessions.close_pool()
@app.get("/health")
async def health():
    """Liveness probe; always answers with a static ``ok`` status."""
    payload = {"status": "ok"}
    return payload
# --- Session CRUD ---
@app.post("/api/sessions")
async def create_session(body: dict = None, user=Depends(require_auth)):
    """Create a session for the authenticated caller and return its id.

    The optional request body may carry a ``graph`` key; when the body or
    the key is absent the graph name falls back to ``"v4-eras"``. The owner
    is taken from the ``sub`` claim (``"anonymous"`` when missing).
    """
    owner = user.get("sub", "anonymous")
    payload = body if body else {}
    graph_name = payload.get("graph", "v4-eras")
    new_id = await db_sessions.create_session(owner, graph_name)
    return {"session_id": new_id}
@app.get("/api/sessions")
async def list_sessions(user=Depends(require_auth)):
    """Return the caller's stored sessions under the ``sessions`` key.

    The caller is identified by the ``sub`` claim (``"anonymous"`` when
    the claim is missing).
    """
    owner = user.get("sub", "anonymous")
    return {"sessions": await db_sessions.list_sessions(owner)}
@app.delete("/api/sessions/{session_id}")
async def delete_session(session_id: str, user=Depends(require_auth)):
    """Delete a session from the in-memory map and from the database.

    Args:
        session_id: Id of the session to remove (path parameter).
        user: Claims of the authenticated caller.

    Returns:
        ``{"status": "deleted"}`` once the session has been removed.

    Raises:
        HTTPException: 403 when the stored session belongs to another user.
    """
    # Imported locally so this handler does not depend on the module's
    # top-level import list.
    from fastapi import HTTPException

    global _active_runtime

    user_id = user.get("sub", "anonymous")

    # Only the owner may delete a persisted session. Rows written by
    # db_sessions.create_session store the caller's `sub` claim as user_id.
    # NOTE(review): rows first inserted through _save_session store
    # `rt.identity` instead — confirm that it matches `sub`.
    saved = await db_sessions.load_session(session_id)
    if saved is not None and saved["user_id"] != user_id:
        raise HTTPException(status_code=403, detail="session belongs to another user")

    rt = _sessions.pop(session_id, None)
    if rt is not None:
        rt.sensor.stop()
        # Do not leave the legacy singleton pointing at a stopped runtime.
        if _active_runtime is rt:
            _active_runtime = None
    await db_sessions.delete_session(session_id)
    return {"status": "deleted"}
@app.get("/auth/config")
async def auth_config():
from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
@ -110,18 +150,62 @@ def register_routes(app):
"projectId": ZITADEL_PROJECT_ID,
}
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
    """Resolve a session id to a live Runtime, creating one when necessary.

    Lookup order: the in-memory ``_sessions`` map, then the database, and
    finally a brand-new session. The runtime that is returned also becomes
    the legacy module-level ``_active_runtime``.

    NOTE(review): neither lookup compares the session's stored owner with
    ``user_claims``, and ``MAX_ACTIVE_SESSIONS`` is not enforced here —
    confirm whether both are intentional.
    """
    global _active_runtime

    # 1) Session already alive in this process.
    rt = _sessions[session_id] if session_id and session_id in _sessions else None

    # 2) Session known to the database: rebuild the runtime from its row.
    if rt is None and session_id:
        saved = await db_sessions.load_session(session_id)
        if saved:
            rt = Runtime(user_claims=user_claims, origin=origin,
                         broadcast=_broadcast_sse,
                         graph_name=saved["graph_name"],
                         session_id=session_id)
            rt.restore_state(saved)
            _sessions[session_id] = rt
            log.info(f"[api] restored session {session_id} from DB")

    # 3) No id given, or the id is unknown everywhere: start fresh.
    if rt is None:
        owner = (user_claims or {}).get("sub", "anonymous")
        new_id = await db_sessions.create_session(owner)
        rt = Runtime(user_claims=user_claims, origin=origin,
                     broadcast=_broadcast_sse, session_id=new_id)
        _sessions[new_id] = rt
        log.info(f"[api] created new session {new_id}")

    _active_runtime = rt
    return rt
async def _save_session(rt: Runtime):
    """Persist a runtime's state to the database (upsert), best-effort.

    The snapshot comes from ``rt.to_state()`` and is stored under
    ``rt.session_id`` together with the runtime's identity and graph name.

    This coroutine is also scheduled as a fire-and-forget task, where an
    escaping exception would go unreported. Persistence failures are
    therefore logged here instead of propagated, so a database outage does
    not turn an already-processed message into an error.
    """
    try:
        state = rt.to_state()
        await db_sessions.save_session(
            rt.session_id, state["history"],
            state["memorizer_state"], state["ui_state"],
            user_id=rt.identity, graph_name=rt.graph.get("name", "v4-eras"))
    except Exception:
        # Top-level boundary for auto-save: record the full traceback and
        # carry on. Task cancellation is not an Exception and still propagates.
        log.exception(f"[api] failed to save session {rt.session_id}")
def _ensure_runtime(user_claims=None, origin=""):
    """Legacy: get or create singleton runtime (backward compat for MCP/tests).

    Args:
        user_claims: Claims passed to a newly created Runtime.
        origin: Origin string passed to a newly created Runtime.

    Returns:
        The module-level ``_active_runtime``, created on first use. The
        arguments are ignored when a runtime already exists.
    """
    global _active_runtime
    if _active_runtime is None:
        _active_runtime = Runtime(user_claims=user_claims, origin=origin,
                                  broadcast=_broadcast_sse)
        # Register the singleton under its own id so that session-aware
        # code paths can look it up in `_sessions` as well.
        _sessions[_active_runtime.session_id] = _active_runtime
        log.info("[api] created persistent runtime (legacy)")
    return _active_runtime
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
access_token: str | None = Query(None)):
access_token: str | None = Query(None),
session: str | None = Query(None)):
user_claims = {"sub": "anonymous"}
if AUTH_ENABLED and token:
try:
@ -142,23 +226,30 @@ def register_routes(app):
origin = ws.headers.get("origin", ws.headers.get("host", ""))
await ws.accept()
# Get or create runtime, attach WS
runtime = _ensure_runtime(user_claims=user_claims, origin=origin)
# Get or create session, attach WS
runtime = await _get_or_create_session(
session_id=session, user_claims=user_claims, origin=origin)
runtime.update_identity(user_claims, origin)
runtime.attach_ws(ws)
# Tell client which session they're on
try:
await ws.send_text(json.dumps({
"type": "session_info", "session_id": runtime.session_id,
"graph": runtime.graph.get("name", "unknown")}))
except Exception:
pass
try:
while True:
data = await ws.receive_text()
msg = json.loads(data)
# Always use current runtime (may change after graph switch)
rt = _active_runtime or runtime
rt = _sessions.get(runtime.session_id, runtime)
try:
if msg.get("type") == "action":
action = msg.get("action", "unknown")
data_payload = msg.get("data")
if hasattr(rt, 'use_frames') and rt.use_frames:
# Frame engine handles actions as ACTION: prefix messages
action_text = f"ACTION:{action}"
if data_payload:
action_text += f"|data:{json.dumps(data_payload)}"
@ -169,6 +260,8 @@ def register_routes(app):
rt.process_manager.cancel(msg.get("pid", 0))
else:
await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
# Auto-save after each message
asyncio.create_task(_save_session(rt))
except Exception as e:
import traceback
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
@ -177,9 +270,8 @@ def register_routes(app):
except Exception:
pass
except WebSocketDisconnect:
if _active_runtime:
_active_runtime.detach_ws()
log.info("[api] WS disconnected — runtime stays alive")
runtime.detach_ws()
log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
"""Validate token for debug WS. Returns True if auth OK."""
@ -313,6 +405,8 @@ def register_routes(app):
"response": response,
"memorizer": runtime.memorizer.state,
}
# Persist session after message
await _save_session(runtime)
except Exception as e:
import traceback
log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")

108
agent/db_sessions.py Normal file
View File

@ -0,0 +1,108 @@
"""PostgreSQL session storage for assay runtime."""
import json
import logging
import os
from uuid import uuid4
import asyncpg
log = logging.getLogger("db_sessions")
_pool: asyncpg.Pool | None = None
async def init_pool(dsn: str = None):
    """Create the shared asyncpg pool; meant to be called once on startup.

    Does nothing when a pool already exists. The DSN comes from the
    argument or, failing that, the ``POSTGRES_DSN`` environment variable;
    with no DSN a warning is logged and no pool is created.
    """
    global _pool
    if _pool is None:
        dsn = dsn or os.environ.get("POSTGRES_DSN", "")
        if dsn:
            _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10)
            log.info("PostgreSQL pool ready")
        else:
            log.warning("POSTGRES_DSN not set — session persistence disabled")
async def close_pool():
    """Close the shared pool, if one exists, and reset the module reference."""
    global _pool
    if not _pool:
        return
    await _pool.close()
    _pool = None
async def create_session(user_id: str, graph_name: str = "v4-eras") -> str:
    """Allocate a fresh session id and, when a pool exists, insert its row.

    Without a pool the id is still returned, so callers can keep working
    with purely in-memory sessions.
    """
    session_id = str(uuid4())
    if _pool:
        await _pool.execute(
            """INSERT INTO sessions (id, user_id, graph_name)
               VALUES ($1, $2, $3)""",
            session_id, user_id, graph_name,
        )
    return session_id
async def load_session(session_id: str) -> dict | None:
    """Load session state from DB.

    Args:
        session_id: Primary key of the session row.

    Returns:
        ``None`` when persistence is disabled or no row matches; otherwise
        a dict with ``session_id``, ``user_id``, ``graph_name``,
        ``memorizer_state``, ``history`` and ``ui_state``. JSON columns that
        are NULL decode to an empty container instead of raising.
    """
    def _decode(value, default):
        # Rows inserted by create_session never set the JSON columns, so a
        # NULL can reach this point; json.loads(None) would raise TypeError.
        if value is None:
            return default
        # Decode text values. Anything else is assumed to be decoded already
        # (for example by a connection-level JSON codec) and passes through.
        if isinstance(value, (str, bytes, bytearray)):
            value = json.loads(value)
        # A stored JSON `null` also maps to the default.
        return default if value is None else value

    if not _pool:
        return None
    row = await _pool.fetchrow(
        """SELECT user_id, graph_name, memorizer_state, history, ui_state
           FROM sessions WHERE id = $1""",
        session_id,
    )
    if not row:
        return None
    return {
        "session_id": session_id,
        "user_id": row["user_id"],
        "graph_name": row["graph_name"],
        "memorizer_state": _decode(row["memorizer_state"], {}),
        "history": _decode(row["history"], []),
        "ui_state": _decode(row["ui_state"], {}),
    }
async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict,
                       user_id: str = "unknown", graph_name: str = "v4-eras"):
    """Persist session state to DB (upsert).

    Args:
        session_id: Primary key of the session row.
        history: Conversation history, stored as JSONB.
        memorizer_state: Memorizer snapshot, stored as JSONB.
        ui_state: UI snapshot, stored as JSONB.
        user_id: Owner recorded when the row is first inserted. It is never
            overwritten on update.
        graph_name: Graph the session currently runs on. Refreshed on every
            save so that a restore rebuilds the runtime on the right graph.

    Does nothing when no pool is configured.
    """
    if not _pool:
        return
    await _pool.execute(
        """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state)
           VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb)
           ON CONFLICT (id) DO UPDATE SET
               graph_name = EXCLUDED.graph_name,
               history = EXCLUDED.history,
               memorizer_state = EXCLUDED.memorizer_state,
               ui_state = EXCLUDED.ui_state,
               updated_at = now(),
               last_activity = now()""",
        session_id, user_id, graph_name,
        json.dumps(history, ensure_ascii=False),
        json.dumps(memorizer_state, ensure_ascii=False),
        json.dumps(ui_state, ensure_ascii=False),
    )
async def list_sessions(user_id: str) -> list[dict]:
    """Return up to 50 of a user's sessions, most recently active first.

    Each entry carries ``id``, ``graph_name`` and ``last_activity`` as an
    ISO-8601 string. An empty list is returned when no pool is configured.
    """
    if not _pool:
        return []
    rows = await _pool.fetch(
        """SELECT id, graph_name, last_activity
           FROM sessions WHERE user_id = $1
           ORDER BY last_activity DESC LIMIT 50""",
        user_id,
    )
    sessions = []
    for record in rows:
        sessions.append({
            "id": record["id"],
            "graph_name": record["graph_name"],
            "last_activity": record["last_activity"].isoformat(),
        })
    return sessions
async def delete_session(session_id: str):
    """Remove the session row with the given id; no-op without a pool."""
    if _pool:
        await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id)

View File

@ -6,6 +6,7 @@ import logging
import time
from pathlib import Path
from typing import Callable
from uuid import uuid4
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan
from .process import ProcessManager
@ -81,7 +82,9 @@ class OutputSink:
class Runtime:
def __init__(self, user_claims: dict = None, origin: str = "",
broadcast: Callable = None, graph_name: str = None):
broadcast: Callable = None, graph_name: str = None,
session_id: str = None):
self.session_id = session_id or str(uuid4())
self.sink = OutputSink()
self.history: list[dict] = []
self.MAX_HISTORY = 40
@ -410,3 +413,27 @@ class Runtime:
if len(self.history) > self.MAX_HISTORY:
self.history = self.history[-self.MAX_HISTORY:]
def to_state(self) -> dict:
    """Build the dict that is persisted for this session.

    The result has three keys: ``history``, ``memorizer_state`` and
    ``ui_state`` (itself a dict with ``state``, ``bindings`` and
    ``machines`` taken from the UI node). Values are passed through by
    reference; nothing is copied.
    """
    snapshot = {
        "history": self.history,
        "memorizer_state": self.memorizer.state,
    }
    node = self.ui_node
    snapshot["ui_state"] = {
        "state": node.state,
        "bindings": node.bindings,
        "machines": node.machines,
    }
    return snapshot
def restore_state(self, state: dict):
    """Apply a previously stored snapshot to this runtime.

    ``history`` is always replaced (an empty list when the key is absent).
    The memorizer state is replaced only when the snapshot holds a truthy
    ``memorizer_state``. The UI node's ``state``, ``bindings`` and
    ``machines`` are replaced only when ``ui_state`` is truthy; keys
    missing from it become empty dicts.
    """
    self.history = state.get("history", [])

    memo = state.get("memorizer_state")
    if memo:
        self.memorizer.state = memo

    ui = state.get("ui_state", {})
    if not ui:
        return
    node = self.ui_node
    for attr in ("state", "bindings", "machines"):
        setattr(node, attr, ui.get(attr, {}))

View File

@ -8,3 +8,4 @@ pydantic==2.12.5
PyJWT[crypto]==2.10.1
pymysql==1.1.1
mcp[sse]==1.9.3
asyncpg==0.30.0