Add PostgreSQL session persistence + multi-session support
- New: db_sessions.py with asyncpg pool, session CRUD (upsert)
- New: POST/GET/DELETE /api/sessions endpoints
- Refactored api.py: _sessions dict replaces _active_runtime singleton
- WS accepts ?session= param, sends session_info on connect
- Runtime: added session_id, to_state(), restore_state()
- Auto-save to Postgres after each message (WS + REST)
- Added asyncpg to requirements.txt
- PostgreSQL 16 on VPS, tenant DBs: assay_dev, assay_loop42

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
432e3d8d55
commit
e205e99da0
126
agent/api.py
126
agent/api.py
@ -14,10 +14,15 @@ import httpx
|
|||||||
|
|
||||||
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
||||||
from .runtime import Runtime, TRACE_FILE
|
from .runtime import Runtime, TRACE_FILE
|
||||||
|
from . import db_sessions
|
||||||
|
|
||||||
log = logging.getLogger("runtime")
|
log = logging.getLogger("runtime")
|
||||||
|
|
||||||
# Active runtime reference (set by WS endpoint)
|
# Session map: session_id -> Runtime (in-memory active sessions)
|
||||||
|
_sessions: dict[str, Runtime] = {}
|
||||||
|
MAX_ACTIVE_SESSIONS = 50
|
||||||
|
|
||||||
|
# Legacy: for backward compat with single-session MCP/test endpoints
|
||||||
_active_runtime: Runtime | None = None
|
_active_runtime: Runtime | None = None
|
||||||
|
|
||||||
# SSE subscribers
|
# SSE subscribers
|
||||||
@ -83,12 +88,13 @@ def _broadcast_sse(event: dict):
|
|||||||
_pipeline_result["event"] = evt
|
_pipeline_result["event"] = evt
|
||||||
|
|
||||||
|
|
||||||
def _state_hash() -> str:
|
def _state_hash(rt: Runtime = None) -> str:
|
||||||
if not _active_runtime:
|
r = rt or _active_runtime
|
||||||
|
if not r:
|
||||||
return "no_session"
|
return "no_session"
|
||||||
raw = json.dumps({
|
raw = json.dumps({
|
||||||
"mem": _active_runtime.memorizer.state,
|
"mem": r.memorizer.state,
|
||||||
"hlen": len(_active_runtime.history),
|
"hlen": len(r.history),
|
||||||
}, sort_keys=True)
|
}, sort_keys=True)
|
||||||
return hashlib.md5(raw.encode()).hexdigest()[:12]
|
return hashlib.md5(raw.encode()).hexdigest()[:12]
|
||||||
|
|
||||||
@ -96,10 +102,44 @@ def _state_hash() -> str:
|
|||||||
def register_routes(app):
|
def register_routes(app):
|
||||||
"""Register all API routes on the FastAPI app."""
|
"""Register all API routes on the FastAPI app."""
|
||||||
|
|
||||||
|
@app.on_event("startup")
async def startup():
    """Open the session-store connection pool when the app starts."""
    await db_sessions.init_pool()
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
async def shutdown():
    """Close the session-store connection pool when the app stops."""
    await db_sessions.close_pool()
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
# --- Session CRUD ---
|
||||||
|
|
||||||
|
@app.post("/api/sessions")
async def create_session(body: dict = None, user=Depends(require_auth)):
    """Create a session for the authenticated caller and return its id.

    The optional JSON body may carry a "graph" key; when it is absent the
    session is created on the "v4-eras" graph. The owner is the caller's
    "sub" claim, falling back to "anonymous".
    """
    payload = body if body else {}
    owner = user.get("sub", "anonymous")
    new_id = await db_sessions.create_session(owner, payload.get("graph", "v4-eras"))
    return {"session_id": new_id}
|
||||||
|
|
||||||
|
@app.get("/api/sessions")
async def list_sessions(user=Depends(require_auth)):
    """Return the stored sessions owned by the authenticated caller."""
    owner = user.get("sub", "anonymous")
    return {"sessions": await db_sessions.list_sessions(owner)}
|
||||||
|
|
||||||
|
@app.delete("/api/sessions/{session_id}")
async def delete_session(session_id: str, user=Depends(require_auth)):
    """Delete a session owned by the authenticated caller.

    Drops the in-memory runtime (stopping its sensor) and removes the row
    from the session store. Deleting an id that does not exist is still
    reported as "deleted", so the call stays idempotent.

    Raises:
        HTTPException: 403 when the stored session belongs to another user.
    """
    # Local import: only FastAPI names used elsewhere in this file are known
    # to be imported at module level.
    from fastapi import HTTPException

    caller_id = user.get("sub", "anonymous")

    # Without this check any authenticated caller could delete any session
    # just by knowing (or guessing) its id.
    saved = await db_sessions.load_session(session_id)
    if saved is not None and saved.get("user_id") != caller_id:
        raise HTTPException(status_code=403, detail="session belongs to another user")
    # NOTE(review): sessions that exist only in memory (persistence disabled)
    # have no stored owner, so they cannot be verified here.

    rt = _sessions.pop(session_id, None)
    if rt is not None:
        rt.sensor.stop()
    await db_sessions.delete_session(session_id)
    return {"status": "deleted"}
|
||||||
|
|
||||||
@app.get("/auth/config")
|
@app.get("/auth/config")
|
||||||
async def auth_config():
|
async def auth_config():
|
||||||
from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
|
from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
|
||||||
@ -110,18 +150,62 @@ def register_routes(app):
|
|||||||
"projectId": ZITADEL_PROJECT_ID,
|
"projectId": ZITADEL_PROJECT_ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
    """Return the runtime for ``session_id``, restoring or creating as needed.

    Resolution order:
      1. an active in-memory runtime registered under ``session_id``;
      2. a session row loaded from the DB, rebuilt into a new runtime;
      3. a brand-new session (also used when ``session_id`` is missing,
         unknown, or owned by a different user).

    The returned runtime also becomes ``_active_runtime`` so the legacy
    single-session endpoints keep working.

    Args:
        session_id: Session requested by the client, if any.
        user_claims: Token claims of the caller; ``sub`` is used as owner id.
        origin: Origin/host string forwarded to the runtime.
    """
    global _active_runtime

    caller_id = (user_claims or {}).get("sub", "anonymous")

    if session_id:
        rt = _sessions.get(session_id)
        if rt is not None:
            # ``owner_id`` is an ad-hoc attribute stamped by this function;
            # runtimes registered elsewhere (legacy path) lack it and are
            # treated as accessible, matching the previous behaviour.
            if getattr(rt, "owner_id", caller_id) == caller_id:
                _active_runtime = rt
                return rt
            log.warning(f"[api] session {session_id} requested by non-owner; creating new session")
        else:
            saved = await db_sessions.load_session(session_id)
            # Only the stored owner may resume a persisted session; anyone
            # else falls through and gets a fresh one instead of the history.
            if saved and saved.get("user_id") == caller_id:
                rt = Runtime(user_claims=user_claims, origin=origin,
                             broadcast=_broadcast_sse,
                             graph_name=saved["graph_name"],
                             session_id=session_id)
                rt.owner_id = caller_id
                rt.restore_state(saved)
                _sessions[session_id] = rt
                _active_runtime = rt
                log.info(f"[api] restored session {session_id} from DB")
                return rt
            if saved:
                log.warning(f"[api] session {session_id} requested by non-owner; creating new session")

    # NOTE(review): MAX_ACTIVE_SESSIONS is declared in this module but nothing
    # enforces it, so _sessions grows without bound — confirm eviction policy.
    new_id = await db_sessions.create_session(caller_id)
    rt = Runtime(user_claims=user_claims, origin=origin,
                 broadcast=_broadcast_sse, session_id=new_id)
    rt.owner_id = caller_id
    _sessions[new_id] = rt
    _active_runtime = rt
    log.info(f"[api] created new session {new_id}")
    return rt
|
||||||
|
|
||||||
|
async def _save_session(rt: Runtime):
    """Write the runtime's current snapshot to the session store (upsert).

    The snapshot comes from ``rt.to_state()``; the owner is ``rt.identity``
    and the graph name falls back to "v4-eras" when the graph has no name.
    """
    snapshot = rt.to_state()
    await db_sessions.save_session(
        rt.session_id,
        snapshot["history"],
        snapshot["memorizer_state"],
        snapshot["ui_state"],
        user_id=rt.identity,
        graph_name=rt.graph.get("name", "v4-eras"),
    )
|
||||||
|
|
||||||
def _ensure_runtime(user_claims=None, origin=""):
|
def _ensure_runtime(user_claims=None, origin=""):
|
||||||
"""Get or create the persistent runtime."""
|
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
||||||
global _active_runtime
|
global _active_runtime
|
||||||
if _active_runtime is None:
|
if _active_runtime is None:
|
||||||
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
||||||
broadcast=_broadcast_sse)
|
broadcast=_broadcast_sse)
|
||||||
log.info("[api] created persistent runtime")
|
_sessions[_active_runtime.session_id] = _active_runtime
|
||||||
|
log.info("[api] created persistent runtime (legacy)")
|
||||||
return _active_runtime
|
return _active_runtime
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
|
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
|
||||||
access_token: str | None = Query(None)):
|
access_token: str | None = Query(None),
|
||||||
|
session: str | None = Query(None)):
|
||||||
user_claims = {"sub": "anonymous"}
|
user_claims = {"sub": "anonymous"}
|
||||||
if AUTH_ENABLED and token:
|
if AUTH_ENABLED and token:
|
||||||
try:
|
try:
|
||||||
@ -142,23 +226,30 @@ def register_routes(app):
|
|||||||
origin = ws.headers.get("origin", ws.headers.get("host", ""))
|
origin = ws.headers.get("origin", ws.headers.get("host", ""))
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
|
|
||||||
# Get or create runtime, attach WS
|
# Get or create session, attach WS
|
||||||
runtime = _ensure_runtime(user_claims=user_claims, origin=origin)
|
runtime = await _get_or_create_session(
|
||||||
|
session_id=session, user_claims=user_claims, origin=origin)
|
||||||
runtime.update_identity(user_claims, origin)
|
runtime.update_identity(user_claims, origin)
|
||||||
runtime.attach_ws(ws)
|
runtime.attach_ws(ws)
|
||||||
|
|
||||||
|
# Tell client which session they're on
|
||||||
|
try:
|
||||||
|
await ws.send_text(json.dumps({
|
||||||
|
"type": "session_info", "session_id": runtime.session_id,
|
||||||
|
"graph": runtime.graph.get("name", "unknown")}))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await ws.receive_text()
|
data = await ws.receive_text()
|
||||||
msg = json.loads(data)
|
msg = json.loads(data)
|
||||||
# Always use current runtime (may change after graph switch)
|
rt = _sessions.get(runtime.session_id, runtime)
|
||||||
rt = _active_runtime or runtime
|
|
||||||
try:
|
try:
|
||||||
if msg.get("type") == "action":
|
if msg.get("type") == "action":
|
||||||
action = msg.get("action", "unknown")
|
action = msg.get("action", "unknown")
|
||||||
data_payload = msg.get("data")
|
data_payload = msg.get("data")
|
||||||
if hasattr(rt, 'use_frames') and rt.use_frames:
|
if hasattr(rt, 'use_frames') and rt.use_frames:
|
||||||
# Frame engine handles actions as ACTION: prefix messages
|
|
||||||
action_text = f"ACTION:{action}"
|
action_text = f"ACTION:{action}"
|
||||||
if data_payload:
|
if data_payload:
|
||||||
action_text += f"|data:{json.dumps(data_payload)}"
|
action_text += f"|data:{json.dumps(data_payload)}"
|
||||||
@ -169,6 +260,8 @@ def register_routes(app):
|
|||||||
rt.process_manager.cancel(msg.get("pid", 0))
|
rt.process_manager.cancel(msg.get("pid", 0))
|
||||||
else:
|
else:
|
||||||
await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
|
await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
|
||||||
|
# Auto-save after each message
|
||||||
|
asyncio.create_task(_save_session(rt))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
|
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
|
||||||
@ -177,9 +270,8 @@ def register_routes(app):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
if _active_runtime:
|
runtime.detach_ws()
|
||||||
_active_runtime.detach_ws()
|
log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
|
||||||
log.info("[api] WS disconnected — runtime stays alive")
|
|
||||||
|
|
||||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||||
"""Validate token for debug WS. Returns True if auth OK."""
|
"""Validate token for debug WS. Returns True if auth OK."""
|
||||||
@ -313,6 +405,8 @@ def register_routes(app):
|
|||||||
"response": response,
|
"response": response,
|
||||||
"memorizer": runtime.memorizer.state,
|
"memorizer": runtime.memorizer.state,
|
||||||
}
|
}
|
||||||
|
# Persist session after message
|
||||||
|
await _save_session(runtime)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")
|
log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")
|
||||||
|
|||||||
108
agent/db_sessions.py
Normal file
108
agent/db_sessions.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
"""PostgreSQL session storage for assay runtime."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
log = logging.getLogger("db_sessions")
|
||||||
|
|
||||||
|
_pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def init_pool(dsn: str = None):
    """Create the shared connection pool; intended to run once at startup.

    Does nothing when a pool already exists. The DSN comes from ``dsn`` or,
    failing that, the ``POSTGRES_DSN`` environment variable; with neither
    set, a warning is logged and persistence stays disabled.
    """
    global _pool
    if _pool is None:
        resolved_dsn = dsn if dsn else os.environ.get("POSTGRES_DSN", "")
        if resolved_dsn:
            _pool = await asyncpg.create_pool(resolved_dsn, min_size=2, max_size=10)
            log.info("PostgreSQL pool ready")
        else:
            log.warning("POSTGRES_DSN not set — session persistence disabled")
|
||||||
|
|
||||||
|
|
||||||
|
async def close_pool():
    """Close the shared connection pool, if one is open, and forget it."""
    global _pool
    if not _pool:
        return
    await _pool.close()
    _pool = None
|
||||||
|
|
||||||
|
|
||||||
|
async def create_session(user_id: str, graph_name: str = "v4-eras") -> str:
    """Allocate a fresh session id and, when a pool exists, insert its row.

    Without a pool the id is still returned so callers can keep working with
    an in-memory-only session.
    """
    new_id = str(uuid4())
    if _pool:
        await _pool.execute(
            """INSERT INTO sessions (id, user_id, graph_name)
            VALUES ($1, $2, $3)""",
            new_id, user_id, graph_name,
        )
    return new_id
|
||||||
|
|
||||||
|
|
||||||
|
async def load_session(session_id: str) -> dict | None:
|
||||||
|
"""Load session state from DB. Returns None if not found."""
|
||||||
|
if not _pool:
|
||||||
|
return None
|
||||||
|
row = await _pool.fetchrow(
|
||||||
|
"""SELECT user_id, graph_name, memorizer_state, history, ui_state
|
||||||
|
FROM sessions WHERE id = $1""",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_id": row["user_id"],
|
||||||
|
"graph_name": row["graph_name"],
|
||||||
|
"memorizer_state": json.loads(row["memorizer_state"]),
|
||||||
|
"history": json.loads(row["history"]),
|
||||||
|
"ui_state": json.loads(row["ui_state"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict,
                       user_id: str = "unknown", graph_name: str = "v4-eras"):
    """Persist session state to the DB (upsert); no-op without a pool.

    A new row is inserted with every field. For an existing row the history,
    memorizer state, UI state, graph name and both timestamps are refreshed;
    ``user_id`` is left as first written, so the original owner is kept.

    Args:
        session_id: Primary key of the session row.
        history: Conversation history, stored as JSON.
        memorizer_state: Memorizer state, stored as JSON.
        ui_state: UI state, stored as JSON.
        user_id: Owner recorded when the row is first inserted.
        graph_name: Graph the session currently runs on.
    """
    if not _pool:
        return
    await _pool.execute(
        # graph_name is part of the update so a session saved after a graph
        # switch is later restored on the graph its state belongs to.
        """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state)
        VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb)
        ON CONFLICT (id) DO UPDATE SET
        graph_name = EXCLUDED.graph_name,
        history = EXCLUDED.history,
        memorizer_state = EXCLUDED.memorizer_state,
        ui_state = EXCLUDED.ui_state,
        updated_at = now(),
        last_activity = now()""",
        session_id, user_id, graph_name,
        json.dumps(history, ensure_ascii=False),
        json.dumps(memorizer_state, ensure_ascii=False),
        json.dumps(ui_state, ensure_ascii=False),
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def list_sessions(user_id: str) -> list[dict]:
    """Return up to 50 of a user's sessions, most recently active first.

    Each entry carries ``id``, ``graph_name`` and ``last_activity`` (ISO-8601
    string). Without a pool the result is an empty list.
    """
    if not _pool:
        return []
    records = await _pool.fetch(
        """SELECT id, graph_name, last_activity
        FROM sessions WHERE user_id = $1
        ORDER BY last_activity DESC LIMIT 50""",
        user_id,
    )
    summaries = []
    for record in records:
        summaries.append({
            "id": record["id"],
            "graph_name": record["graph_name"],
            "last_activity": record["last_activity"].isoformat(),
        })
    return summaries
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_session(session_id: str):
    """Remove the session row with this id; no-op when there is no pool."""
    if _pool:
        await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id)
|
||||||
@ -6,6 +6,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan
|
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan
|
||||||
from .process import ProcessManager
|
from .process import ProcessManager
|
||||||
@ -81,7 +82,9 @@ class OutputSink:
|
|||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(self, user_claims: dict = None, origin: str = "",
|
def __init__(self, user_claims: dict = None, origin: str = "",
|
||||||
broadcast: Callable = None, graph_name: str = None):
|
broadcast: Callable = None, graph_name: str = None,
|
||||||
|
session_id: str = None):
|
||||||
|
self.session_id = session_id or str(uuid4())
|
||||||
self.sink = OutputSink()
|
self.sink = OutputSink()
|
||||||
self.history: list[dict] = []
|
self.history: list[dict] = []
|
||||||
self.MAX_HISTORY = 40
|
self.MAX_HISTORY = 40
|
||||||
@ -410,3 +413,27 @@ class Runtime:
|
|||||||
|
|
||||||
if len(self.history) > self.MAX_HISTORY:
|
if len(self.history) > self.MAX_HISTORY:
|
||||||
self.history = self.history[-self.MAX_HISTORY:]
|
self.history = self.history[-self.MAX_HISTORY:]
|
||||||
|
|
||||||
|
def to_state(self) -> dict:
    """Build the snapshot of this session that gets stored in the DB.

    The snapshot references the live history, memorizer state and UI node
    fields (state, bindings, machines); nothing is copied.
    """
    ui_node = self.ui_node
    ui_snapshot = {
        "state": ui_node.state,
        "bindings": ui_node.bindings,
        "machines": ui_node.machines,
    }
    return {
        "history": self.history,
        "memorizer_state": self.memorizer.state,
        "ui_state": ui_snapshot,
    }
|
||||||
|
|
||||||
|
def restore_state(self, state: dict):
    """Apply a snapshot loaded from the DB to this runtime.

    History is always replaced (empty list when absent). Memorizer state and
    the UI node fields are only overwritten when the snapshot holds a
    non-empty value for them.
    """
    self.history = state.get("history", [])

    memorizer_snapshot = state.get("memorizer_state")
    if memorizer_snapshot:
        self.memorizer.state = memorizer_snapshot

    ui_snapshot = state.get("ui_state", {})
    if not ui_snapshot:
        return
    for field in ("state", "bindings", "machines"):
        setattr(self.ui_node, field, ui_snapshot.get(field, {}))
|
||||||
|
|||||||
@ -8,3 +8,4 @@ pydantic==2.12.5
|
|||||||
PyJWT[crypto]==2.10.1
|
PyJWT[crypto]==2.10.1
|
||||||
pymysql==1.1.1
|
pymysql==1.1.1
|
||||||
mcp[sse]==1.9.3
|
mcp[sse]==1.9.3
|
||||||
|
asyncpg==0.30.0
|
||||||
|
|||||||
Reference in New Issue
Block a user