Wire NodePool into Runtime + API: shared stateless nodes across sessions
- Runtime accepts optional pool= param, uses shared nodes from pool for stateless roles, creates fresh sensor/memorizer/ui per-session - FrameEngine sets _current_hud contextvar at start of process_message - API creates global NodePool once, passes to all Runtime instances - Graph switch resets pool for new graph - Legacy _ensure_runtime also uses pool - 15/15 engine + matrix tests green Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
ae2338e70a
commit
376fdb2458
32
agent/api.py
32
agent/api.py
@ -14,6 +14,7 @@ import httpx
|
|||||||
|
|
||||||
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
||||||
from .runtime import Runtime, TRACE_FILE
|
from .runtime import Runtime, TRACE_FILE
|
||||||
|
from .node_pool import NodePool
|
||||||
from . import db_sessions
|
from . import db_sessions
|
||||||
|
|
||||||
log = logging.getLogger("runtime")
|
log = logging.getLogger("runtime")
|
||||||
@ -22,6 +23,9 @@ log = logging.getLogger("runtime")
|
|||||||
_sessions: dict[str, Runtime] = {}
|
_sessions: dict[str, Runtime] = {}
|
||||||
MAX_ACTIVE_SESSIONS = 50
|
MAX_ACTIVE_SESSIONS = 50
|
||||||
|
|
||||||
|
# Shared node pool (created once, shared across all sessions)
|
||||||
|
_node_pool: NodePool | None = None
|
||||||
|
|
||||||
# Legacy: for backward compat with single-session MCP/test endpoints
|
# Legacy: for backward compat with single-session MCP/test endpoints
|
||||||
_active_runtime: Runtime | None = None
|
_active_runtime: Runtime | None = None
|
||||||
|
|
||||||
@ -193,6 +197,16 @@ def register_routes(app):
|
|||||||
"projectId": ZITADEL_PROJECT_ID,
|
"projectId": ZITADEL_PROJECT_ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_pool(graph_name: str = None) -> NodePool:
|
||||||
|
"""Get or create the shared node pool."""
|
||||||
|
global _node_pool
|
||||||
|
from .runtime import _active_graph_name
|
||||||
|
gname = graph_name or _active_graph_name
|
||||||
|
if _node_pool is None or _node_pool.graph_name != gname:
|
||||||
|
_node_pool = NodePool(gname)
|
||||||
|
log.info(f"[api] created shared node pool for '{gname}'")
|
||||||
|
return _node_pool
|
||||||
|
|
||||||
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
|
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
|
||||||
"""Get existing session or create new one."""
|
"""Get existing session or create new one."""
|
||||||
global _active_runtime
|
global _active_runtime
|
||||||
@ -207,24 +221,26 @@ def register_routes(app):
|
|||||||
if session_id:
|
if session_id:
|
||||||
saved = await db_sessions.load_session(session_id)
|
saved = await db_sessions.load_session(session_id)
|
||||||
if saved:
|
if saved:
|
||||||
|
pool = _get_pool(saved["graph_name"])
|
||||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||||
broadcast=_broadcast_sse,
|
broadcast=_broadcast_sse,
|
||||||
graph_name=saved["graph_name"],
|
graph_name=saved["graph_name"],
|
||||||
session_id=session_id)
|
session_id=session_id, pool=pool)
|
||||||
rt.restore_state(saved)
|
rt.restore_state(saved)
|
||||||
_sessions[session_id] = rt
|
_sessions[session_id] = rt
|
||||||
_active_runtime = rt
|
_active_runtime = rt
|
||||||
log.info(f"[api] restored session {session_id} from DB")
|
log.info(f"[api] restored session {session_id} from DB (shared pool)")
|
||||||
return rt
|
return rt
|
||||||
|
|
||||||
# Create new session
|
# Create new session
|
||||||
user_id = (user_claims or {}).get("sub", "anonymous")
|
user_id = (user_claims or {}).get("sub", "anonymous")
|
||||||
new_id = await db_sessions.create_session(user_id)
|
new_id = await db_sessions.create_session(user_id)
|
||||||
|
pool = _get_pool()
|
||||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||||
broadcast=_broadcast_sse, session_id=new_id)
|
broadcast=_broadcast_sse, session_id=new_id, pool=pool)
|
||||||
_sessions[new_id] = rt
|
_sessions[new_id] = rt
|
||||||
_active_runtime = rt
|
_active_runtime = rt
|
||||||
log.info(f"[api] created new session {new_id}")
|
log.info(f"[api] created new session {new_id} (shared pool)")
|
||||||
return rt
|
return rt
|
||||||
|
|
||||||
async def _save_session(rt: Runtime):
|
async def _save_session(rt: Runtime):
|
||||||
@ -239,10 +255,11 @@ def register_routes(app):
|
|||||||
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
||||||
global _active_runtime
|
global _active_runtime
|
||||||
if _active_runtime is None:
|
if _active_runtime is None:
|
||||||
|
pool = _get_pool()
|
||||||
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
||||||
broadcast=_broadcast_sse)
|
broadcast=_broadcast_sse, pool=pool)
|
||||||
_sessions[_active_runtime.session_id] = _active_runtime
|
_sessions[_active_runtime.session_id] = _active_runtime
|
||||||
log.info("[api] created persistent runtime (legacy)")
|
log.info("[api] created persistent runtime (shared pool)")
|
||||||
return _active_runtime
|
return _active_runtime
|
||||||
|
|
||||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||||
@ -587,12 +604,13 @@ def register_routes(app):
|
|||||||
|
|
||||||
@app.post("/api/graph/switch")
|
@app.post("/api/graph/switch")
|
||||||
async def switch_graph(body: dict, user=Depends(require_auth)):
|
async def switch_graph(body: dict, user=Depends(require_auth)):
|
||||||
global _active_runtime
|
global _active_runtime, _node_pool
|
||||||
from .engine import load_graph
|
from .engine import load_graph
|
||||||
import agent.runtime as rt
|
import agent.runtime as rt
|
||||||
name = body.get("name", "")
|
name = body.get("name", "")
|
||||||
graph = load_graph(name) # validates it exists
|
graph = load_graph(name) # validates it exists
|
||||||
rt._active_graph_name = name
|
rt._active_graph_name = name
|
||||||
|
_node_pool = None # Force pool recreation for new graph
|
||||||
|
|
||||||
# Preserve WS connection across graph switch
|
# Preserve WS connection across graph switch
|
||||||
old_ws = None
|
old_ws = None
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import time
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
|
|
||||||
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
|
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
|
||||||
|
from .nodes.base import _current_hud
|
||||||
|
|
||||||
log = logging.getLogger("runtime")
|
log = logging.getLogger("runtime")
|
||||||
|
|
||||||
@ -189,6 +190,9 @@ class FrameEngine:
|
|||||||
saved_models[role] = node.model
|
saved_models[role] = node.model
|
||||||
node.model = model
|
node.model = model
|
||||||
|
|
||||||
|
# Set session-scoped HUD for shared nodes (contextvar, per-task)
|
||||||
|
_current_hud.set(self._send_hud)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._begin_trace(text)
|
self._begin_trace(text)
|
||||||
|
|
||||||
|
|||||||
@ -85,29 +85,56 @@ class OutputSink:
|
|||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(self, user_claims: dict = None, origin: str = "",
|
def __init__(self, user_claims: dict = None, origin: str = "",
|
||||||
broadcast: Callable = None, graph_name: str = None,
|
broadcast: Callable = None, graph_name: str = None,
|
||||||
session_id: str = None):
|
session_id: str = None, pool=None):
|
||||||
self.session_id = session_id or str(uuid4())
|
self.session_id = session_id or str(uuid4())
|
||||||
self.sink = OutputSink()
|
self.sink = OutputSink()
|
||||||
self.history: list[dict] = []
|
self.history: list[dict] = []
|
||||||
self.MAX_HISTORY = 40
|
self.MAX_HISTORY = 40
|
||||||
self._broadcast = broadcast or (lambda e: None)
|
self._broadcast = broadcast or (lambda e: None)
|
||||||
|
|
||||||
# Load graph and instantiate nodes
|
|
||||||
gname = graph_name or _active_graph_name
|
gname = graph_name or _active_graph_name
|
||||||
self.graph = load_graph(gname)
|
|
||||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
|
||||||
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
|
||||||
process_manager=self.process_manager)
|
|
||||||
|
|
||||||
# Bind nodes by role (pipeline code references these)
|
if pool:
|
||||||
self.input_node = nodes["input"]
|
# Phase 2: use shared node pool for stateless nodes
|
||||||
self.thinker = nodes.get("thinker") # v1/v2/v3
|
self.graph = pool.graph
|
||||||
self.output_node = nodes["output"]
|
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||||
self.ui_node = nodes["ui"]
|
|
||||||
self.memorizer = nodes["memorizer"]
|
# Shared nodes from pool (stateless, serve all sessions)
|
||||||
self.director = nodes.get("director") # v1/v2/v3, None in v4
|
self.input_node = pool.shared.get("input")
|
||||||
self.sensor = nodes["sensor"]
|
self.thinker = pool.shared.get("thinker")
|
||||||
self.interpreter = nodes.get("interpreter") # v2+ only
|
self.output_node = pool.shared.get("output")
|
||||||
|
self.director = pool.shared.get("director")
|
||||||
|
self.interpreter = pool.shared.get("interpreter")
|
||||||
|
|
||||||
|
# Per-session stateful nodes (fresh each session)
|
||||||
|
from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode
|
||||||
|
self.ui_node = UINode(send_hud=self._send_hud)
|
||||||
|
self.memorizer = MemorizerNode(send_hud=self._send_hud)
|
||||||
|
self.sensor = SensorNode(send_hud=self._send_hud)
|
||||||
|
|
||||||
|
# Build combined nodes dict for FrameEngine
|
||||||
|
nodes = dict(pool.shared)
|
||||||
|
nodes["ui"] = self.ui_node
|
||||||
|
nodes["memorizer"] = self.memorizer
|
||||||
|
nodes["sensor"] = self.sensor
|
||||||
|
|
||||||
|
log.info(f"[runtime] using shared pool for graph '{gname}' "
|
||||||
|
f"({len(pool.shared)} shared, 3 per-session)")
|
||||||
|
else:
|
||||||
|
# Legacy: create all nodes per-session
|
||||||
|
self.graph = load_graph(gname)
|
||||||
|
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||||
|
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
||||||
|
process_manager=self.process_manager)
|
||||||
|
|
||||||
|
self.input_node = nodes["input"]
|
||||||
|
self.thinker = nodes.get("thinker")
|
||||||
|
self.output_node = nodes["output"]
|
||||||
|
self.ui_node = nodes["ui"]
|
||||||
|
self.memorizer = nodes["memorizer"]
|
||||||
|
self.director = nodes.get("director")
|
||||||
|
self.sensor = nodes["sensor"]
|
||||||
|
self.interpreter = nodes.get("interpreter")
|
||||||
|
|
||||||
# Detect graph type
|
# Detect graph type
|
||||||
self.is_v2 = self.director is not None and hasattr(self.director, "decide")
|
self.is_v2 = self.director is not None and hasattr(self.director, "decide")
|
||||||
|
|||||||
Reference in New Issue
Block a user