"""PostgreSQL session storage for assay runtime.""" import json import logging import os from uuid import uuid4 import asyncpg log = logging.getLogger("db_sessions") _pool: asyncpg.Pool | None = None async def init_pool(dsn: str = None): """Create connection pool. Call once on startup.""" global _pool if _pool is not None: return dsn = dsn or os.environ.get("POSTGRES_DSN", "") if not dsn: log.warning("POSTGRES_DSN not set — session persistence disabled") return _pool = await asyncpg.create_pool(dsn, min_size=2, max_size=10) log.info("PostgreSQL pool ready") async def close_pool(): global _pool if _pool: await _pool.close() _pool = None async def create_session(user_id: str, graph_name: str = "v4-eras") -> str: """Create a new session, return session_id.""" session_id = str(uuid4()) if not _pool: return session_id # no persistence, still works in-memory await _pool.execute( """INSERT INTO sessions (id, user_id, graph_name) VALUES ($1, $2, $3)""", session_id, user_id, graph_name, ) return session_id async def load_session(session_id: str) -> dict | None: """Load session state from DB. Returns None if not found.""" if not _pool: return None row = await _pool.fetchrow( """SELECT user_id, graph_name, memorizer_state, history, ui_state FROM sessions WHERE id = $1""", session_id, ) if not row: return None return { "session_id": session_id, "user_id": row["user_id"], "graph_name": row["graph_name"], "memorizer_state": json.loads(row["memorizer_state"]), "history": json.loads(row["history"]), "ui_state": json.loads(row["ui_state"]), } async def save_session(session_id: str, history: list, memorizer_state: dict, ui_state: dict, user_id: str = "unknown", graph_name: str = "v4-eras"): """Persist session state to DB (upsert).""" if not _pool: return await _pool.execute( """INSERT INTO sessions (id, user_id, graph_name, history, memorizer_state, ui_state) VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb) ON CONFLICT (id) DO UPDATE SET history = EXCLUDED.history, memorizer_state = EXCLUDED.memorizer_state, ui_state = EXCLUDED.ui_state, updated_at = now(), last_activity = now()""", session_id, user_id, graph_name, json.dumps(history, ensure_ascii=False), json.dumps(memorizer_state, ensure_ascii=False), json.dumps(ui_state, ensure_ascii=False), ) async def list_sessions(user_id: str) -> list[dict]: """List sessions for a user.""" if not _pool: return [] rows = await _pool.fetch( """SELECT id, graph_name, last_activity FROM sessions WHERE user_id = $1 ORDER BY last_activity DESC LIMIT 50""", user_id, ) return [{"id": r["id"], "graph_name": r["graph_name"], "last_activity": r["last_activity"].isoformat()} for r in rows] async def delete_session(session_id: str): """Delete a session.""" if not _pool: return await _pool.execute("DELETE FROM sessions WHERE id = $1", session_id)