From 376fdb2458bc58e86dd769262c26fe51fd831ffc Mon Sep 17 00:00:00 2001 From: Nico Date: Fri, 3 Apr 2026 18:41:15 +0200 Subject: [PATCH] Wire NodePool into Runtime + API: shared stateless nodes across sessions - Runtime accepts optional pool= param, uses shared nodes from pool for stateless roles, creates fresh sensor/memorizer/ui per-session - FrameEngine sets _current_hud contextvar at start of process_message - API creates global NodePool once, passes to all Runtime instances - Graph switch resets pool for new graph - Legacy _ensure_runtime also uses pool - 15/15 engine + matrix tests green Co-Authored-By: Claude Opus 4.6 (1M context) --- agent/api.py | 32 ++++++++++++++++++------ agent/frame_engine.py | 4 +++ agent/runtime.py | 57 +++++++++++++++++++++++++++++++------------ 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/agent/api.py b/agent/api.py index deda543..eee6ecf 100644 --- a/agent/api.py +++ b/agent/api.py @@ -14,6 +14,7 @@ import httpx from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth from .runtime import Runtime, TRACE_FILE +from .node_pool import NodePool from . import db_sessions log = logging.getLogger("runtime") @@ -22,6 +23,9 @@ log = logging.getLogger("runtime") _sessions: dict[str, Runtime] = {} MAX_ACTIVE_SESSIONS = 50 +# Shared node pool (created once, shared across all sessions) +_node_pool: NodePool | None = None + # Legacy: for backward compat with single-session MCP/test endpoints _active_runtime: Runtime | None = None @@ -193,6 +197,16 @@ def register_routes(app): "projectId": ZITADEL_PROJECT_ID, } + def _get_pool(graph_name: str = None) -> NodePool: + """Get or create the shared node pool.""" + global _node_pool + from .runtime import _active_graph_name + gname = graph_name or _active_graph_name + if _node_pool is None or _node_pool.graph_name != gname: + _node_pool = NodePool(gname) + log.info(f"[api] created shared node pool for '{gname}'") + return _node_pool + async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime: """Get existing session or create new one.""" global _active_runtime @@ -207,24 +221,26 @@ def register_routes(app): if session_id: saved = await db_sessions.load_session(session_id) if saved: + pool = _get_pool(saved["graph_name"]) rt = Runtime(user_claims=user_claims, origin=origin, broadcast=_broadcast_sse, graph_name=saved["graph_name"], - session_id=session_id) + session_id=session_id, pool=pool) rt.restore_state(saved) _sessions[session_id] = rt _active_runtime = rt - log.info(f"[api] restored session {session_id} from DB") + log.info(f"[api] restored session {session_id} from DB (shared pool)") return rt # Create new session user_id = (user_claims or {}).get("sub", "anonymous") new_id = await db_sessions.create_session(user_id) + pool = _get_pool() rt = Runtime(user_claims=user_claims, origin=origin, - broadcast=_broadcast_sse, session_id=new_id) + broadcast=_broadcast_sse, session_id=new_id, pool=pool) _sessions[new_id] = rt _active_runtime = rt - log.info(f"[api] created new session {new_id}") + log.info(f"[api] created new session {new_id} (shared pool)") return rt async def _save_session(rt: Runtime): @@ -239,10 +255,11 @@ def register_routes(app): """Legacy: get or create singleton runtime (backward compat for MCP/tests).""" global _active_runtime if _active_runtime is None: + pool = _get_pool() _active_runtime = Runtime(user_claims=user_claims, origin=origin, - broadcast=_broadcast_sse) + broadcast=_broadcast_sse, pool=pool) _sessions[_active_runtime.session_id] = _active_runtime - log.info("[api] created persistent runtime (legacy)") + log.info("[api] created persistent runtime (shared pool)") return _active_runtime async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool: @@ -587,12 +604,13 @@ def register_routes(app): @app.post("/api/graph/switch") async def switch_graph(body: dict, user=Depends(require_auth)): - global _active_runtime + global _active_runtime, _node_pool from .engine import load_graph import agent.runtime as rt name = body.get("name", "") graph = load_graph(name) # validates it exists rt._active_graph_name = name + _node_pool = None # Force pool recreation for new graph # Preserve WS connection across graph switch old_ws = None diff --git a/agent/frame_engine.py b/agent/frame_engine.py index 9d224a8..7aed5fb 100644 --- a/agent/frame_engine.py +++ b/agent/frame_engine.py @@ -16,6 +16,7 @@ import time from dataclasses import dataclass, field, asdict from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting +from .nodes.base import _current_hud log = logging.getLogger("runtime") @@ -189,6 +190,9 @@ class FrameEngine: saved_models[role] = node.model node.model = model + # Set session-scoped HUD for shared nodes (contextvar, per-task) + _current_hud.set(self._send_hud) + try: self._begin_trace(text) diff --git a/agent/runtime.py b/agent/runtime.py index 063bae2..361e3e1 100644 --- a/agent/runtime.py +++ b/agent/runtime.py @@ -85,29 +85,56 @@ class OutputSink: class Runtime: def __init__(self, user_claims: dict = None, origin: str = "", broadcast: Callable = None, graph_name: str = None, - session_id: str = None): + session_id: str = None, pool=None): self.session_id = session_id or str(uuid4()) self.sink = OutputSink() self.history: list[dict] = [] self.MAX_HISTORY = 40 self._broadcast = broadcast or (lambda e: None) - # Load graph and instantiate nodes gname = graph_name or _active_graph_name - self.graph = load_graph(gname) - self.process_manager = ProcessManager(send_hud=self._send_hud) - nodes = instantiate_nodes(self.graph, send_hud=self._send_hud, - process_manager=self.process_manager) - # Bind nodes by role (pipeline code references these) - self.input_node = nodes["input"] - self.thinker = nodes.get("thinker") # v1/v2/v3 - self.output_node = nodes["output"] - self.ui_node = nodes["ui"] - self.memorizer = nodes["memorizer"] - self.director = nodes.get("director") # v1/v2/v3, None in v4 - self.sensor = nodes["sensor"] - self.interpreter = nodes.get("interpreter") # v2+ only + if pool: + # Phase 2: use shared node pool for stateless nodes + self.graph = pool.graph + self.process_manager = ProcessManager(send_hud=self._send_hud) + + # Shared nodes from pool (stateless, serve all sessions) + self.input_node = pool.shared.get("input") + self.thinker = pool.shared.get("thinker") + self.output_node = pool.shared.get("output") + self.director = pool.shared.get("director") + self.interpreter = pool.shared.get("interpreter") + + # Per-session stateful nodes (fresh each session) + from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode + self.ui_node = UINode(send_hud=self._send_hud) + self.memorizer = MemorizerNode(send_hud=self._send_hud) + self.sensor = SensorNode(send_hud=self._send_hud) + + # Build combined nodes dict for FrameEngine + nodes = dict(pool.shared) + nodes["ui"] = self.ui_node + nodes["memorizer"] = self.memorizer + nodes["sensor"] = self.sensor + + log.info(f"[runtime] using shared pool for graph '{gname}' " + f"({len(pool.shared)} shared, 3 per-session)") + else: + # Legacy: create all nodes per-session + self.graph = load_graph(gname) + self.process_manager = ProcessManager(send_hud=self._send_hud) + nodes = instantiate_nodes(self.graph, send_hud=self._send_hud, + process_manager=self.process_manager) + + self.input_node = nodes["input"] + self.thinker = nodes.get("thinker") + self.output_node = nodes["output"] + self.ui_node = nodes["ui"] + self.memorizer = nodes["memorizer"] + self.director = nodes.get("director") + self.sensor = nodes["sensor"] + self.interpreter = nodes.get("interpreter") # Detect graph type self.is_v2 = self.director is not None and hasattr(self.director, "decide")