Wire NodePool into Runtime + API: shared stateless nodes across sessions
- Runtime accepts optional pool= param, uses shared nodes from pool for stateless roles, creates fresh sensor/memorizer/ui per-session - FrameEngine sets _current_hud contextvar at start of process_message - API creates global NodePool once, passes to all Runtime instances - Graph switch resets pool for new graph - Legacy _ensure_runtime also uses pool - 15/15 engine + matrix tests green Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
ae2338e70a
commit
376fdb2458
32
agent/api.py
32
agent/api.py
@ -14,6 +14,7 @@ import httpx
|
||||
|
||||
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
||||
from .runtime import Runtime, TRACE_FILE
|
||||
from .node_pool import NodePool
|
||||
from . import db_sessions
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
@ -22,6 +23,9 @@ log = logging.getLogger("runtime")
|
||||
_sessions: dict[str, Runtime] = {}
|
||||
MAX_ACTIVE_SESSIONS = 50
|
||||
|
||||
# Shared node pool (created once, shared across all sessions)
|
||||
_node_pool: NodePool | None = None
|
||||
|
||||
# Legacy: for backward compat with single-session MCP/test endpoints
|
||||
_active_runtime: Runtime | None = None
|
||||
|
||||
@ -193,6 +197,16 @@ def register_routes(app):
|
||||
"projectId": ZITADEL_PROJECT_ID,
|
||||
}
|
||||
|
||||
def _get_pool(graph_name: str = None) -> NodePool:
|
||||
"""Get or create the shared node pool."""
|
||||
global _node_pool
|
||||
from .runtime import _active_graph_name
|
||||
gname = graph_name or _active_graph_name
|
||||
if _node_pool is None or _node_pool.graph_name != gname:
|
||||
_node_pool = NodePool(gname)
|
||||
log.info(f"[api] created shared node pool for '{gname}'")
|
||||
return _node_pool
|
||||
|
||||
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
|
||||
"""Get existing session or create new one."""
|
||||
global _active_runtime
|
||||
@ -207,24 +221,26 @@ def register_routes(app):
|
||||
if session_id:
|
||||
saved = await db_sessions.load_session(session_id)
|
||||
if saved:
|
||||
pool = _get_pool(saved["graph_name"])
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse,
|
||||
graph_name=saved["graph_name"],
|
||||
session_id=session_id)
|
||||
session_id=session_id, pool=pool)
|
||||
rt.restore_state(saved)
|
||||
_sessions[session_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] restored session {session_id} from DB")
|
||||
log.info(f"[api] restored session {session_id} from DB (shared pool)")
|
||||
return rt
|
||||
|
||||
# Create new session
|
||||
user_id = (user_claims or {}).get("sub", "anonymous")
|
||||
new_id = await db_sessions.create_session(user_id)
|
||||
pool = _get_pool()
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse, session_id=new_id)
|
||||
broadcast=_broadcast_sse, session_id=new_id, pool=pool)
|
||||
_sessions[new_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] created new session {new_id}")
|
||||
log.info(f"[api] created new session {new_id} (shared pool)")
|
||||
return rt
|
||||
|
||||
async def _save_session(rt: Runtime):
|
||||
@ -239,10 +255,11 @@ def register_routes(app):
|
||||
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
||||
global _active_runtime
|
||||
if _active_runtime is None:
|
||||
pool = _get_pool()
|
||||
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse)
|
||||
broadcast=_broadcast_sse, pool=pool)
|
||||
_sessions[_active_runtime.session_id] = _active_runtime
|
||||
log.info("[api] created persistent runtime (legacy)")
|
||||
log.info("[api] created persistent runtime (shared pool)")
|
||||
return _active_runtime
|
||||
|
||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||
@ -587,12 +604,13 @@ def register_routes(app):
|
||||
|
||||
@app.post("/api/graph/switch")
|
||||
async def switch_graph(body: dict, user=Depends(require_auth)):
|
||||
global _active_runtime
|
||||
global _active_runtime, _node_pool
|
||||
from .engine import load_graph
|
||||
import agent.runtime as rt
|
||||
name = body.get("name", "")
|
||||
graph = load_graph(name) # validates it exists
|
||||
rt._active_graph_name = name
|
||||
_node_pool = None # Force pool recreation for new graph
|
||||
|
||||
# Preserve WS connection across graph switch
|
||||
old_ws = None
|
||||
|
||||
@ -16,6 +16,7 @@ import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
|
||||
from .nodes.base import _current_hud
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
|
||||
@ -189,6 +190,9 @@ class FrameEngine:
|
||||
saved_models[role] = node.model
|
||||
node.model = model
|
||||
|
||||
# Set session-scoped HUD for shared nodes (contextvar, per-task)
|
||||
_current_hud.set(self._send_hud)
|
||||
|
||||
try:
|
||||
self._begin_trace(text)
|
||||
|
||||
|
||||
@ -85,29 +85,56 @@ class OutputSink:
|
||||
class Runtime:
|
||||
def __init__(self, user_claims: dict = None, origin: str = "",
|
||||
broadcast: Callable = None, graph_name: str = None,
|
||||
session_id: str = None):
|
||||
session_id: str = None, pool=None):
|
||||
self.session_id = session_id or str(uuid4())
|
||||
self.sink = OutputSink()
|
||||
self.history: list[dict] = []
|
||||
self.MAX_HISTORY = 40
|
||||
self._broadcast = broadcast or (lambda e: None)
|
||||
|
||||
# Load graph and instantiate nodes
|
||||
gname = graph_name or _active_graph_name
|
||||
self.graph = load_graph(gname)
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
||||
process_manager=self.process_manager)
|
||||
|
||||
# Bind nodes by role (pipeline code references these)
|
||||
self.input_node = nodes["input"]
|
||||
self.thinker = nodes.get("thinker") # v1/v2/v3
|
||||
self.output_node = nodes["output"]
|
||||
self.ui_node = nodes["ui"]
|
||||
self.memorizer = nodes["memorizer"]
|
||||
self.director = nodes.get("director") # v1/v2/v3, None in v4
|
||||
self.sensor = nodes["sensor"]
|
||||
self.interpreter = nodes.get("interpreter") # v2+ only
|
||||
if pool:
|
||||
# Phase 2: use shared node pool for stateless nodes
|
||||
self.graph = pool.graph
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
|
||||
# Shared nodes from pool (stateless, serve all sessions)
|
||||
self.input_node = pool.shared.get("input")
|
||||
self.thinker = pool.shared.get("thinker")
|
||||
self.output_node = pool.shared.get("output")
|
||||
self.director = pool.shared.get("director")
|
||||
self.interpreter = pool.shared.get("interpreter")
|
||||
|
||||
# Per-session stateful nodes (fresh each session)
|
||||
from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode
|
||||
self.ui_node = UINode(send_hud=self._send_hud)
|
||||
self.memorizer = MemorizerNode(send_hud=self._send_hud)
|
||||
self.sensor = SensorNode(send_hud=self._send_hud)
|
||||
|
||||
# Build combined nodes dict for FrameEngine
|
||||
nodes = dict(pool.shared)
|
||||
nodes["ui"] = self.ui_node
|
||||
nodes["memorizer"] = self.memorizer
|
||||
nodes["sensor"] = self.sensor
|
||||
|
||||
log.info(f"[runtime] using shared pool for graph '{gname}' "
|
||||
f"({len(pool.shared)} shared, 3 per-session)")
|
||||
else:
|
||||
# Legacy: create all nodes per-session
|
||||
self.graph = load_graph(gname)
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
||||
process_manager=self.process_manager)
|
||||
|
||||
self.input_node = nodes["input"]
|
||||
self.thinker = nodes.get("thinker")
|
||||
self.output_node = nodes["output"]
|
||||
self.ui_node = nodes["ui"]
|
||||
self.memorizer = nodes["memorizer"]
|
||||
self.director = nodes.get("director")
|
||||
self.sensor = nodes["sensor"]
|
||||
self.interpreter = nodes.get("interpreter")
|
||||
|
||||
# Detect graph type
|
||||
self.is_v2 = self.director is not None and hasattr(self.director, "decide")
|
||||
|
||||
Reference in New Issue
Block a user