diff --git a/agent/api.py b/agent/api.py index 6dc8b54..6db925d 100644 --- a/agent/api.py +++ b/agent/api.py @@ -14,10 +14,15 @@ import httpx from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth from .runtime import Runtime, TRACE_FILE +from . import db_sessions log = logging.getLogger("runtime") -# Active runtime reference (set by WS endpoint) +# Session map: session_id -> Runtime (in-memory active sessions) +_sessions: dict[str, Runtime] = {} +MAX_ACTIVE_SESSIONS = 50 + +# Legacy: for backward compat with single-session MCP/test endpoints _active_runtime: Runtime | None = None # SSE subscribers @@ -83,12 +88,13 @@ def _broadcast_sse(event: dict): _pipeline_result["event"] = evt -def _state_hash() -> str: - if not _active_runtime: +def _state_hash(rt: Runtime = None) -> str: + r = rt or _active_runtime + if not r: return "no_session" raw = json.dumps({ - "mem": _active_runtime.memorizer.state, - "hlen": len(_active_runtime.history), + "mem": r.memorizer.state, + "hlen": len(r.history), }, sort_keys=True) return hashlib.md5(raw.encode()).hexdigest()[:12] @@ -96,10 +102,44 @@ def _state_hash() -> str: def register_routes(app): """Register all API routes on the FastAPI app.""" + @app.on_event("startup") + async def startup(): + await db_sessions.init_pool() + + @app.on_event("shutdown") + async def shutdown(): + await db_sessions.close_pool() + @app.get("/health") async def health(): return {"status": "ok"} + # --- Session CRUD --- + + @app.post("/api/sessions") + async def create_session(body: dict = None, user=Depends(require_auth)): + """Create a new session.""" + user_id = user.get("sub", "anonymous") + graph = (body or {}).get("graph", "v4-eras") + session_id = await db_sessions.create_session(user_id, graph) + return {"session_id": session_id} + + @app.get("/api/sessions") + async def list_sessions(user=Depends(require_auth)): + """List sessions for current user.""" + user_id = user.get("sub", "anonymous") + sessions = await 
async def _get_or_create_session(session_id: "str | None" = None, user_claims=None,
                                 origin="") -> "Runtime":
    """Return the runtime for *session_id*, restoring or creating it as needed.

    Resolution order:
      1. If the session has a DB record owned by a different user, the
         request is refused and a brand-new session is created instead.
      2. An in-memory runtime registered under *session_id* is reused.
      3. A persisted session is restored from the DB into a new runtime.
      4. Otherwise a new session is created for the caller.

    Args:
        session_id: Session requested by the client, or None for a new one.
        user_claims: Token claims of the caller; ``sub`` is used as user id
            (``"anonymous"`` when absent).
        origin: Origin/host string forwarded to the runtime.

    Returns:
        The runtime now registered in ``_sessions``; it is also stored in
        the legacy ``_active_runtime`` global.
    """
    # NOTE(review): MAX_ACTIVE_SESSIONS is declared at module level but
    # nothing visible here enforces it, so `_sessions` can grow without
    # bound — confirm whether eviction is handled elsewhere.
    global _active_runtime
    user_id = (user_claims or {}).get("sub", "anonymous")

    if session_id:
        # The DB row is the source of truth for ownership: create_session
        # stores the creator's user id and load_session returns it.
        saved = await db_sessions.load_session(session_id)
        if saved is not None and saved.get("user_id") != user_id:
            # Knowing a session id must not be enough to attach to (or
            # restore) somebody else's conversation.
            log.warning(f"[api] user {user_id} denied access to session {session_id}")
        else:
            # No await between this lookup and the registration below, so
            # two concurrent connects cannot both restore the same session.
            rt = _sessions.get(session_id)
            if rt is None and saved is not None:
                rt = Runtime(user_claims=user_claims, origin=origin,
                             broadcast=_broadcast_sse,
                             graph_name=saved["graph_name"],
                             session_id=session_id)
                rt.restore_state(saved)
                _sessions[session_id] = rt
                log.info(f"[api] restored session {session_id} from DB")
            if rt is not None:
                _active_runtime = rt
                return rt

    # Unknown, refused, or missing session id: start a fresh session.
    new_id = await db_sessions.create_session(user_id)
    rt = Runtime(user_claims=user_claims, origin=origin,
                 broadcast=_broadcast_sse, session_id=new_id)
    _sessions[new_id] = rt
    _active_runtime = rt
    log.info(f"[api] created new session {new_id}")
    return rt
"""Get or create the persistent runtime.""" + """Legacy: get or create singleton runtime (backward compat for MCP/tests).""" global _active_runtime if _active_runtime is None: _active_runtime = Runtime(user_claims=user_claims, origin=origin, broadcast=_broadcast_sse) - log.info("[api] created persistent runtime") + _sessions[_active_runtime.session_id] = _active_runtime + log.info("[api] created persistent runtime (legacy)") return _active_runtime @app.websocket("/ws") async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), - access_token: str | None = Query(None)): + access_token: str | None = Query(None), + session: str | None = Query(None)): user_claims = {"sub": "anonymous"} if AUTH_ENABLED and token: try: @@ -142,23 +226,30 @@ def register_routes(app): origin = ws.headers.get("origin", ws.headers.get("host", "")) await ws.accept() - # Get or create runtime, attach WS - runtime = _ensure_runtime(user_claims=user_claims, origin=origin) + # Get or create session, attach WS + runtime = await _get_or_create_session( + session_id=session, user_claims=user_claims, origin=origin) runtime.update_identity(user_claims, origin) runtime.attach_ws(ws) + # Tell client which session they're on + try: + await ws.send_text(json.dumps({ + "type": "session_info", "session_id": runtime.session_id, + "graph": runtime.graph.get("name", "unknown")})) + except Exception: + pass + try: while True: data = await ws.receive_text() msg = json.loads(data) - # Always use current runtime (may change after graph switch) - rt = _active_runtime or runtime + rt = _sessions.get(runtime.session_id, runtime) try: if msg.get("type") == "action": action = msg.get("action", "unknown") data_payload = msg.get("data") if hasattr(rt, 'use_frames') and rt.use_frames: - # Frame engine handles actions as ACTION: prefix messages action_text = f"ACTION:{action}" if data_payload: action_text += f"|data:{json.dumps(data_payload)}" @@ -169,6 +260,8 @@ def register_routes(app): 
rt.process_manager.cancel(msg.get("pid", 0)) else: await rt.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard")) + # Auto-save after each message + asyncio.create_task(_save_session(rt)) except Exception as e: import traceback log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}") @@ -177,9 +270,8 @@ def register_routes(app): except Exception: pass except WebSocketDisconnect: - if _active_runtime: - _active_runtime.detach_ws() - log.info("[api] WS disconnected — runtime stays alive") + runtime.detach_ws() + log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive") async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool: """Validate token for debug WS. Returns True if auth OK.""" @@ -313,6 +405,8 @@ def register_routes(app): "response": response, "memorizer": runtime.memorizer.state, } + # Persist session after message + await _save_session(runtime) except Exception as e: import traceback log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}") diff --git a/agent/db_sessions.py b/agent/db_sessions.py new file mode 100644 index 0000000..443028d --- /dev/null +++ b/agent/db_sessions.py @@ -0,0 +1,108 @@ +"""PostgreSQL session storage for assay runtime.""" + +import json +import logging +import os +from uuid import uuid4 + +import asyncpg + +log = logging.getLogger("db_sessions") + +_pool: asyncpg.Pool | None = None + + +async def init_pool(dsn: str = None): + """Create connection pool. 
Call once on startup.""" + global _pool + if _pool is not None: + return + dsn = dsn or os.environ.get("POSTGRES_DSN", "") + if not dsn: + log.warning("POSTGRES_DSN not set — session persistence disabled") + return + _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10) + log.info("PostgreSQL pool ready") + + +async def close_pool(): + global _pool + if _pool: + await _pool.close() + _pool = None + + +async def create_session(user_id: str, graph_name: str = "v4-eras") -> str: + """Create a new session, return session_id.""" + session_id = str(uuid4()) + if not _pool: + return session_id # no persistence, still works in-memory + await _pool.execute( + """INSERT INTO sessions (id, user_id, graph_name) + VALUES ($1, $2, $3)""", + session_id, user_id, graph_name, + ) + return session_id + + +async def load_session(session_id: str) -> dict | None: + """Load session state from DB. Returns None if not found.""" + if not _pool: + return None + row = await _pool.fetchrow( + """SELECT user_id, graph_name, memorizer_state, history, ui_state + FROM sessions WHERE id = $1""", + session_id, + ) + if not row: + return None + return { + "session_id": session_id, + "user_id": row["user_id"], + "graph_name": row["graph_name"], + "memorizer_state": json.loads(row["memorizer_state"]), + "history": json.loads(row["history"]), + "ui_state": json.loads(row["ui_state"]), + } + + +async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict, + user_id: str = "unknown", graph_name: str = "v4-eras"): + """Persist session state to DB (upsert).""" + if not _pool: + return + await _pool.execute( + """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state) + VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb) + ON CONFLICT (id) DO UPDATE SET + history = EXCLUDED.history, + memorizer_state = EXCLUDED.memorizer_state, + ui_state = EXCLUDED.ui_state, + updated_at = now(), + last_activity = now()""", + session_id, user_id, 
graph_name, + json.dumps(history, ensure_ascii=False), + json.dumps(memorizer_state, ensure_ascii=False), + json.dumps(ui_state, ensure_ascii=False), + ) + + +async def list_sessions(user_id: str) -> list[dict]: + """List sessions for a user.""" + if not _pool: + return [] + rows = await _pool.fetch( + """SELECT id, graph_name, last_activity + FROM sessions WHERE user_id = $1 + ORDER BY last_activity DESC LIMIT 50""", + user_id, + ) + return [{"id": r["id"], "graph_name": r["graph_name"], + "last_activity": r["last_activity"].isoformat()} for r in rows] + + +async def delete_session(session_id: str): + """Delete a session.""" + if not _pool: + return + await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id) diff --git a/agent/runtime.py b/agent/runtime.py index f014b3c..1e82677 100644 --- a/agent/runtime.py +++ b/agent/runtime.py @@ -6,6 +6,7 @@ import logging import time from pathlib import Path from typing import Callable +from uuid import uuid4 from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan from .process import ProcessManager @@ -81,7 +82,9 @@ class OutputSink: class Runtime: def __init__(self, user_claims: dict = None, origin: str = "", - broadcast: Callable = None, graph_name: str = None): + broadcast: Callable = None, graph_name: str = None, + session_id: str = None): + self.session_id = session_id or str(uuid4()) self.sink = OutputSink() self.history: list[dict] = [] self.MAX_HISTORY = 40 @@ -410,3 +413,27 @@ class Runtime: if len(self.history) > self.MAX_HISTORY: self.history = self.history[-self.MAX_HISTORY:] + + def to_state(self) -> dict: + """Serialize session state for DB storage.""" + return { + "history": self.history, + "memorizer_state": self.memorizer.state, + "ui_state": { + "state": self.ui_node.state, + "bindings": self.ui_node.bindings, + "machines": self.ui_node.machines, + }, + } + + def restore_state(self, state: dict): + """Restore session state from DB.""" + self.history = 
state.get("history", []) + memo = state.get("memorizer_state") + if memo: + self.memorizer.state = memo + ui = state.get("ui_state", {}) + if ui: + self.ui_node.state = ui.get("state", {}) + self.ui_node.bindings = ui.get("bindings", {}) + self.ui_node.machines = ui.get("machines", {}) diff --git a/requirements.txt b/requirements.txt index ec76b60..ae91101 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pydantic==2.12.5 PyJWT[crypto]==2.10.1 pymysql==1.1.1 mcp[sse]==1.9.3 +asyncpg==0.30.0