Wire NodePool into Runtime + API: shared stateless nodes across sessions

- Runtime accepts optional pool= param, uses shared nodes from pool
  for stateless roles, creates fresh sensor/memorizer/ui per-session
- FrameEngine sets _current_hud contextvar at start of process_message
- API creates global NodePool once, passes to all Runtime instances
- Graph switch resets pool for new graph
- Legacy _ensure_runtime also uses pool
- 15/15 engine + matrix tests green

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nico 2026-04-03 18:41:15 +02:00
parent ae2338e70a
commit 376fdb2458
3 changed files with 71 additions and 22 deletions

View File

@ -14,6 +14,7 @@ import httpx
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE
from .node_pool import NodePool
from . import db_sessions
log = logging.getLogger("runtime")
@ -22,6 +23,9 @@ log = logging.getLogger("runtime")
_sessions: dict[str, Runtime] = {}
MAX_ACTIVE_SESSIONS = 50
# Shared node pool (created once, shared across all sessions)
_node_pool: NodePool | None = None
# Legacy: for backward compat with single-session MCP/test endpoints
_active_runtime: Runtime | None = None
@ -193,6 +197,16 @@ def register_routes(app):
"projectId": ZITADEL_PROJECT_ID,
}
def _get_pool(graph_name: str | None = None) -> NodePool:
    """Return the process-wide shared NodePool, (re)creating it on demand.

    One pool is shared by all sessions/Runtimes. A fresh pool is built when
    none exists yet, or when the requested graph differs from the graph the
    current pool was built for (e.g. after a graph switch elsewhere resets
    ``_node_pool`` to ``None``).

    Args:
        graph_name: Name of the graph the pool must serve. Defaults to the
            currently active graph name from the runtime module.

    Returns:
        The shared :class:`NodePool` for the resolved graph name.
    """
    global _node_pool
    # Imported lazily: the runtime module imports from this package at load
    # time, so a top-level import would be circular.
    from .runtime import _active_graph_name
    gname = graph_name or _active_graph_name
    # Rebuild when missing or when the pool was built for a different graph.
    if _node_pool is None or _node_pool.graph_name != gname:
        _node_pool = NodePool(gname)
        log.info(f"[api] created shared node pool for '{gname}'")
    return _node_pool
# NOTE(review): this span is a rendered git-diff hunk, not runnable Python.
# Pre-change and post-change lines appear back-to-back with no +/- markers,
# and the "@ -207,..." header below marks lines elided from view. Code is
# left byte-identical; only comments are added.
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
"""Get existing session or create new one."""
global _active_runtime
@ -207,24 +221,26 @@ def register_routes(app):
# Restore path: a session_id was supplied and a saved session exists in DB.
if session_id:
saved = await db_sessions.load_session(session_id)
if saved:
# Shared pool keyed to the saved session's graph (added by this commit).
pool = _get_pool(saved["graph_name"])
rt = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse,
graph_name=saved["graph_name"],
# diff artifact: old kwarg line (pre-commit) followed by new one (pool=).
session_id=session_id)
session_id=session_id, pool=pool)
rt.restore_state(saved)
_sessions[session_id] = rt
_active_runtime = rt
# diff artifact: old log line followed by the new "(shared pool)" variant.
log.info(f"[api] restored session {session_id} from DB")
log.info(f"[api] restored session {session_id} from DB (shared pool)")
return rt
# Create new session
user_id = (user_claims or {}).get("sub", "anonymous")
new_id = await db_sessions.create_session(user_id)
pool = _get_pool()
rt = Runtime(user_claims=user_claims, origin=origin,
# diff artifact: old kwarg line (pre-commit) followed by new one (pool=).
broadcast=_broadcast_sse, session_id=new_id)
broadcast=_broadcast_sse, session_id=new_id, pool=pool)
_sessions[new_id] = rt
_active_runtime = rt
# diff artifact: old log line followed by the new "(shared pool)" variant.
log.info(f"[api] created new session {new_id}")
log.info(f"[api] created new session {new_id} (shared pool)")
return rt
async def _save_session(rt: Runtime):
@ -239,10 +255,11 @@ def register_routes(app):
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
global _active_runtime
if _active_runtime is None:
pool = _get_pool()
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
broadcast=_broadcast_sse)
broadcast=_broadcast_sse, pool=pool)
_sessions[_active_runtime.session_id] = _active_runtime
log.info("[api] created persistent runtime (legacy)")
log.info("[api] created persistent runtime (shared pool)")
return _active_runtime
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
@ -587,12 +604,13 @@ def register_routes(app):
@app.post("/api/graph/switch")
async def switch_graph(body: dict, user=Depends(require_auth)):
global _active_runtime
global _active_runtime, _node_pool
from .engine import load_graph
import agent.runtime as rt
name = body.get("name", "")
graph = load_graph(name) # validates it exists
rt._active_graph_name = name
_node_pool = None # Force pool recreation for new graph
# Preserve WS connection across graph switch
old_ws = None

View File

@ -16,6 +16,7 @@ import time
from dataclasses import dataclass, field, asdict
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
from .nodes.base import _current_hud
log = logging.getLogger("runtime")
@ -189,6 +190,9 @@ class FrameEngine:
saved_models[role] = node.model
node.model = model
# Set session-scoped HUD for shared nodes (contextvar, per-task)
_current_hud.set(self._send_hud)
try:
self._begin_trace(text)

View File

@ -85,29 +85,56 @@ class OutputSink:
class Runtime:
def __init__(self, user_claims: dict = None, origin: str = "",
broadcast: Callable = None, graph_name: str = None,
session_id: str = None):
session_id: str = None, pool=None):
self.session_id = session_id or str(uuid4())
self.sink = OutputSink()
self.history: list[dict] = []
self.MAX_HISTORY = 40
self._broadcast = broadcast or (lambda e: None)
# Load graph and instantiate nodes
gname = graph_name or _active_graph_name
self.graph = load_graph(gname)
self.process_manager = ProcessManager(send_hud=self._send_hud)
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
process_manager=self.process_manager)
# Bind nodes by role (pipeline code references these)
self.input_node = nodes["input"]
self.thinker = nodes.get("thinker") # v1/v2/v3
self.output_node = nodes["output"]
self.ui_node = nodes["ui"]
self.memorizer = nodes["memorizer"]
self.director = nodes.get("director") # v1/v2/v3, None in v4
self.sensor = nodes["sensor"]
self.interpreter = nodes.get("interpreter") # v2+ only
if pool:
# Phase 2: use shared node pool for stateless nodes
self.graph = pool.graph
self.process_manager = ProcessManager(send_hud=self._send_hud)
# Shared nodes from pool (stateless, serve all sessions)
self.input_node = pool.shared.get("input")
self.thinker = pool.shared.get("thinker")
self.output_node = pool.shared.get("output")
self.director = pool.shared.get("director")
self.interpreter = pool.shared.get("interpreter")
# Per-session stateful nodes (fresh each session)
from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode
self.ui_node = UINode(send_hud=self._send_hud)
self.memorizer = MemorizerNode(send_hud=self._send_hud)
self.sensor = SensorNode(send_hud=self._send_hud)
# Build combined nodes dict for FrameEngine
nodes = dict(pool.shared)
nodes["ui"] = self.ui_node
nodes["memorizer"] = self.memorizer
nodes["sensor"] = self.sensor
log.info(f"[runtime] using shared pool for graph '{gname}' "
f"({len(pool.shared)} shared, 3 per-session)")
else:
# Legacy: create all nodes per-session
self.graph = load_graph(gname)
self.process_manager = ProcessManager(send_hud=self._send_hud)
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
process_manager=self.process_manager)
self.input_node = nodes["input"]
self.thinker = nodes.get("thinker")
self.output_node = nodes["output"]
self.ui_node = nodes["ui"]
self.memorizer = nodes["memorizer"]
self.director = nodes.get("director")
self.sensor = nodes["sensor"]
self.interpreter = nodes.get("interpreter")
# Detect graph type
self.is_v2 = self.director is not None and hasattr(self.director, "decide")