Wire NodePool into Runtime + API: shared stateless nodes across sessions

- Runtime accepts optional pool= param, uses shared nodes from pool
  for stateless roles, creates fresh sensor/memorizer/ui per-session
- FrameEngine sets _current_hud contextvar at start of process_message
- API creates global NodePool once, passes to all Runtime instances
- Graph switch resets pool for new graph
- Legacy _ensure_runtime also uses pool
- 15/15 engine + matrix tests green

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nico 2026-04-03 18:41:15 +02:00
parent ae2338e70a
commit 376fdb2458
3 changed files with 71 additions and 22 deletions

View File

@ -14,6 +14,7 @@ import httpx
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE from .runtime import Runtime, TRACE_FILE
from .node_pool import NodePool
from . import db_sessions from . import db_sessions
log = logging.getLogger("runtime") log = logging.getLogger("runtime")
@ -22,6 +23,9 @@ log = logging.getLogger("runtime")
_sessions: dict[str, Runtime] = {} _sessions: dict[str, Runtime] = {}
MAX_ACTIVE_SESSIONS = 50 MAX_ACTIVE_SESSIONS = 50
# Shared node pool (created once, shared across all sessions)
_node_pool: NodePool | None = None
# Legacy: for backward compat with single-session MCP/test endpoints # Legacy: for backward compat with single-session MCP/test endpoints
_active_runtime: Runtime | None = None _active_runtime: Runtime | None = None
@ -193,6 +197,16 @@ def register_routes(app):
"projectId": ZITADEL_PROJECT_ID, "projectId": ZITADEL_PROJECT_ID,
} }
def _get_pool(graph_name: str | None = None) -> NodePool:
    """Get or create the shared :class:`NodePool`.

    The pool is a module-level singleton (``_node_pool``) shared by every
    session. It is (re)built lazily: on first use, and again whenever the
    requested graph differs from the graph the current pool was built for
    (e.g. after a graph switch resets ``_node_pool`` to ``None``).

    Args:
        graph_name: Graph to build the pool for. Defaults to the
            currently active graph (``runtime._active_graph_name``).

    Returns:
        The shared pool for ``graph_name``.
    """
    global _node_pool
    # Imported lazily to avoid a circular import between the API and
    # runtime modules at import time.
    from .runtime import _active_graph_name
    gname = graph_name or _active_graph_name
    # Rebuild when no pool exists yet or the pool belongs to another graph.
    if _node_pool is None or _node_pool.graph_name != gname:
        _node_pool = NodePool(gname)
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        log.info("[api] created shared node pool for '%s'", gname)
    return _node_pool
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime: async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
"""Get existing session or create new one.""" """Get existing session or create new one."""
global _active_runtime global _active_runtime
@ -207,24 +221,26 @@ def register_routes(app):
if session_id: if session_id:
saved = await db_sessions.load_session(session_id) saved = await db_sessions.load_session(session_id)
if saved: if saved:
pool = _get_pool(saved["graph_name"])
rt = Runtime(user_claims=user_claims, origin=origin, rt = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse, broadcast=_broadcast_sse,
graph_name=saved["graph_name"], graph_name=saved["graph_name"],
session_id=session_id) session_id=session_id, pool=pool)
rt.restore_state(saved) rt.restore_state(saved)
_sessions[session_id] = rt _sessions[session_id] = rt
_active_runtime = rt _active_runtime = rt
log.info(f"[api] restored session {session_id} from DB") log.info(f"[api] restored session {session_id} from DB (shared pool)")
return rt return rt
# Create new session # Create new session
user_id = (user_claims or {}).get("sub", "anonymous") user_id = (user_claims or {}).get("sub", "anonymous")
new_id = await db_sessions.create_session(user_id) new_id = await db_sessions.create_session(user_id)
pool = _get_pool()
rt = Runtime(user_claims=user_claims, origin=origin, rt = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse, session_id=new_id) broadcast=_broadcast_sse, session_id=new_id, pool=pool)
_sessions[new_id] = rt _sessions[new_id] = rt
_active_runtime = rt _active_runtime = rt
log.info(f"[api] created new session {new_id}") log.info(f"[api] created new session {new_id} (shared pool)")
return rt return rt
async def _save_session(rt: Runtime): async def _save_session(rt: Runtime):
@ -239,10 +255,11 @@ def register_routes(app):
"""Legacy: get or create singleton runtime (backward compat for MCP/tests).""" """Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
global _active_runtime global _active_runtime
if _active_runtime is None: if _active_runtime is None:
pool = _get_pool()
_active_runtime = Runtime(user_claims=user_claims, origin=origin, _active_runtime = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse) broadcast=_broadcast_sse, pool=pool)
_sessions[_active_runtime.session_id] = _active_runtime _sessions[_active_runtime.session_id] = _active_runtime
log.info("[api] created persistent runtime (legacy)") log.info("[api] created persistent runtime (shared pool)")
return _active_runtime return _active_runtime
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool: async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
@ -587,12 +604,13 @@ def register_routes(app):
@app.post("/api/graph/switch") @app.post("/api/graph/switch")
async def switch_graph(body: dict, user=Depends(require_auth)): async def switch_graph(body: dict, user=Depends(require_auth)):
global _active_runtime global _active_runtime, _node_pool
from .engine import load_graph from .engine import load_graph
import agent.runtime as rt import agent.runtime as rt
name = body.get("name", "") name = body.get("name", "")
graph = load_graph(name) # validates it exists graph = load_graph(name) # validates it exists
rt._active_graph_name = name rt._active_graph_name = name
_node_pool = None # Force pool recreation for new graph
# Preserve WS connection across graph switch # Preserve WS connection across graph switch
old_ws = None old_ws = None

View File

@ -16,6 +16,7 @@ import time
from dataclasses import dataclass, field, asdict from dataclasses import dataclass, field, asdict
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
from .nodes.base import _current_hud
log = logging.getLogger("runtime") log = logging.getLogger("runtime")
@ -189,6 +190,9 @@ class FrameEngine:
saved_models[role] = node.model saved_models[role] = node.model
node.model = model node.model = model
# Set session-scoped HUD for shared nodes (contextvar, per-task)
_current_hud.set(self._send_hud)
try: try:
self._begin_trace(text) self._begin_trace(text)

View File

@ -85,29 +85,56 @@ class OutputSink:
class Runtime: class Runtime:
def __init__(self, user_claims: dict = None, origin: str = "", def __init__(self, user_claims: dict = None, origin: str = "",
broadcast: Callable = None, graph_name: str = None, broadcast: Callable = None, graph_name: str = None,
session_id: str = None): session_id: str = None, pool=None):
self.session_id = session_id or str(uuid4()) self.session_id = session_id or str(uuid4())
self.sink = OutputSink() self.sink = OutputSink()
self.history: list[dict] = [] self.history: list[dict] = []
self.MAX_HISTORY = 40 self.MAX_HISTORY = 40
self._broadcast = broadcast or (lambda e: None) self._broadcast = broadcast or (lambda e: None)
# Load graph and instantiate nodes
gname = graph_name or _active_graph_name gname = graph_name or _active_graph_name
self.graph = load_graph(gname)
self.process_manager = ProcessManager(send_hud=self._send_hud)
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
process_manager=self.process_manager)
# Bind nodes by role (pipeline code references these) if pool:
self.input_node = nodes["input"] # Phase 2: use shared node pool for stateless nodes
self.thinker = nodes.get("thinker") # v1/v2/v3 self.graph = pool.graph
self.output_node = nodes["output"] self.process_manager = ProcessManager(send_hud=self._send_hud)
self.ui_node = nodes["ui"]
self.memorizer = nodes["memorizer"] # Shared nodes from pool (stateless, serve all sessions)
self.director = nodes.get("director") # v1/v2/v3, None in v4 self.input_node = pool.shared.get("input")
self.sensor = nodes["sensor"] self.thinker = pool.shared.get("thinker")
self.interpreter = nodes.get("interpreter") # v2+ only self.output_node = pool.shared.get("output")
self.director = pool.shared.get("director")
self.interpreter = pool.shared.get("interpreter")
# Per-session stateful nodes (fresh each session)
from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode
self.ui_node = UINode(send_hud=self._send_hud)
self.memorizer = MemorizerNode(send_hud=self._send_hud)
self.sensor = SensorNode(send_hud=self._send_hud)
# Build combined nodes dict for FrameEngine
nodes = dict(pool.shared)
nodes["ui"] = self.ui_node
nodes["memorizer"] = self.memorizer
nodes["sensor"] = self.sensor
log.info(f"[runtime] using shared pool for graph '{gname}' "
f"({len(pool.shared)} shared, 3 per-session)")
else:
# Legacy: create all nodes per-session
self.graph = load_graph(gname)
self.process_manager = ProcessManager(send_hud=self._send_hud)
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
process_manager=self.process_manager)
self.input_node = nodes["input"]
self.thinker = nodes.get("thinker")
self.output_node = nodes["output"]
self.ui_node = nodes["ui"]
self.memorizer = nodes["memorizer"]
self.director = nodes.get("director")
self.sensor = nodes["sensor"]
self.interpreter = nodes.get("interpreter")
# Detect graph type # Detect graph type
self.is_v2 = self.director is not None and hasattr(self.director, "decide") self.is_v2 = self.director is not None and hasattr(self.director, "decide")