Compare commits
No commits in common. "44f611685503b54319ac81316a63b7858a4ca617" and "1e64b0a58c159f2339fc8998a876aebbe379d8ac" have entirely different histories.
44f6116855
...
1e64b0a58c
3
.gitignore
vendored
3
.gitignore
vendored
@ -2,7 +2,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
.env.local
|
||||
trace.jsonl
|
||||
docker/mariadb/*.sql.gz
|
||||
docker/mariadb/dump_*.sql
|
||||
|
||||
179
agent/api.py
179
agent/api.py
@ -14,7 +14,6 @@ import httpx
|
||||
|
||||
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
||||
from .runtime import Runtime, TRACE_FILE
|
||||
from .node_pool import NodePool
|
||||
from . import db_sessions
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
@ -23,9 +22,6 @@ log = logging.getLogger("runtime")
|
||||
_sessions: dict[str, Runtime] = {}
|
||||
MAX_ACTIVE_SESSIONS = 50
|
||||
|
||||
# Shared node pool (created once, shared across all sessions)
|
||||
_node_pool: NodePool | None = None
|
||||
|
||||
# Legacy: for backward compat with single-session MCP/test endpoints
|
||||
_active_runtime: Runtime | None = None
|
||||
|
||||
@ -36,16 +32,6 @@ _sse_subscribers: list[Queue] = []
|
||||
_test_ws_clients: list[WebSocket] = [] # /ws/test subscribers
|
||||
_trace_ws_clients: list[WebSocket] = [] # /ws/trace subscribers
|
||||
|
||||
# Debug command channel: AI → assay → nyx browser (via SSE) → assay → AI
|
||||
_debug_queues: list[Queue] = [] # per-client SSE queues for debug commands
|
||||
_debug_results: dict[str, asyncio.Event] = {} # cmd_id → event (set when result arrives)
|
||||
_debug_result_data: dict[str, dict] = {} # cmd_id → result payload
|
||||
|
||||
# Test results (in-memory, fed by test-runner Job in assay-test namespace)
|
||||
_test_results: list[dict] = []
|
||||
_test_results_subscribers: list[Queue] = []
|
||||
_test_run_id: str = ""
|
||||
|
||||
|
||||
async def _broadcast_test(event: dict):
|
||||
"""Push to all /ws/test subscribers."""
|
||||
@ -131,32 +117,14 @@ def register_routes(app):
|
||||
|
||||
@app.get("/api/health-stream")
|
||||
async def health_stream(user=Depends(require_auth)):
|
||||
"""SSE heartbeat + debug command stream."""
|
||||
q: Queue = Queue(maxsize=100)
|
||||
_debug_queues.append(q)
|
||||
|
||||
"""SSE heartbeat stream — client uses this for presence detection."""
|
||||
async def generate():
|
||||
try:
|
||||
while True:
|
||||
# Drain any pending debug commands first
|
||||
while not q.empty():
|
||||
try:
|
||||
cmd = q.get_nowait()
|
||||
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
yield f"event: heartbeat\ndata: {json.dumps({'ts': int(asyncio.get_event_loop().time()), 'sessions': len(_sessions)})}\n\n"
|
||||
# Wait up to 1s for debug commands, then loop for heartbeat
|
||||
try:
|
||||
cmd = await asyncio.wait_for(q.get(), timeout=1.0)
|
||||
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
await asyncio.sleep(15)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if q in _debug_queues:
|
||||
_debug_queues.remove(q)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
@ -197,16 +165,6 @@ def register_routes(app):
|
||||
"projectId": ZITADEL_PROJECT_ID,
|
||||
}
|
||||
|
||||
def _get_pool(graph_name: str = None) -> NodePool:
|
||||
"""Get or create the shared node pool."""
|
||||
global _node_pool
|
||||
from .runtime import _active_graph_name
|
||||
gname = graph_name or _active_graph_name
|
||||
if _node_pool is None or _node_pool.graph_name != gname:
|
||||
_node_pool = NodePool(gname)
|
||||
log.info(f"[api] created shared node pool for '{gname}'")
|
||||
return _node_pool
|
||||
|
||||
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
|
||||
"""Get existing session or create new one."""
|
||||
global _active_runtime
|
||||
@ -221,26 +179,24 @@ def register_routes(app):
|
||||
if session_id:
|
||||
saved = await db_sessions.load_session(session_id)
|
||||
if saved:
|
||||
pool = _get_pool(saved["graph_name"])
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse,
|
||||
graph_name=saved["graph_name"],
|
||||
session_id=session_id, pool=pool)
|
||||
session_id=session_id)
|
||||
rt.restore_state(saved)
|
||||
_sessions[session_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] restored session {session_id} from DB (shared pool)")
|
||||
log.info(f"[api] restored session {session_id} from DB")
|
||||
return rt
|
||||
|
||||
# Create new session
|
||||
user_id = (user_claims or {}).get("sub", "anonymous")
|
||||
new_id = await db_sessions.create_session(user_id)
|
||||
pool = _get_pool()
|
||||
rt = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse, session_id=new_id, pool=pool)
|
||||
broadcast=_broadcast_sse, session_id=new_id)
|
||||
_sessions[new_id] = rt
|
||||
_active_runtime = rt
|
||||
log.info(f"[api] created new session {new_id} (shared pool)")
|
||||
log.info(f"[api] created new session {new_id}")
|
||||
return rt
|
||||
|
||||
async def _save_session(rt: Runtime):
|
||||
@ -255,11 +211,10 @@ def register_routes(app):
|
||||
"""Legacy: get or create singleton runtime (backward compat for MCP/tests)."""
|
||||
global _active_runtime
|
||||
if _active_runtime is None:
|
||||
pool = _get_pool()
|
||||
_active_runtime = Runtime(user_claims=user_claims, origin=origin,
|
||||
broadcast=_broadcast_sse, pool=pool)
|
||||
broadcast=_broadcast_sse)
|
||||
_sessions[_active_runtime.session_id] = _active_runtime
|
||||
log.info("[api] created persistent runtime (shared pool)")
|
||||
log.info("[api] created persistent runtime (legacy)")
|
||||
return _active_runtime
|
||||
|
||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||
@ -323,8 +278,6 @@ def register_routes(app):
|
||||
action = body.get("action")
|
||||
action_data = body.get("action_data")
|
||||
dashboard = body.get("dashboard")
|
||||
# Model overrides: {"models": {"input": "x/y", "pa": "a/b"}}
|
||||
model_overrides = body.get("models")
|
||||
|
||||
if not text and not action:
|
||||
raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")
|
||||
@ -349,8 +302,7 @@ def register_routes(app):
|
||||
else:
|
||||
await rt.handle_action(action, action_data)
|
||||
else:
|
||||
await rt.handle_message(text, dashboard=dashboard,
|
||||
model_overrides=model_overrides)
|
||||
await rt.handle_message(text, dashboard=dashboard)
|
||||
# Auto-save
|
||||
await _save_session(rt)
|
||||
except Exception as e:
|
||||
@ -604,13 +556,12 @@ def register_routes(app):
|
||||
|
||||
@app.post("/api/graph/switch")
|
||||
async def switch_graph(body: dict, user=Depends(require_auth)):
|
||||
global _active_runtime, _node_pool
|
||||
global _active_runtime
|
||||
from .engine import load_graph
|
||||
import agent.runtime as rt
|
||||
name = body.get("name", "")
|
||||
graph = load_graph(name) # validates it exists
|
||||
rt._active_graph_name = name
|
||||
_node_pool = None # Force pool recreation for new graph
|
||||
|
||||
# Preserve WS connection across graph switch
|
||||
old_ws = None
|
||||
@ -691,113 +642,3 @@ def register_routes(app):
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"lines": parsed}
|
||||
|
||||
# --- Debug command channel ---
|
||||
# Flow: AI POSTs cmd → queued → nyx picks up via poll → executes → POSTs result → AI gets response
|
||||
|
||||
@app.post("/api/debug/cmd")
|
||||
async def debug_cmd(body: dict, user=Depends(require_auth)):
|
||||
"""AI sends a command to execute in the browser. Waits up to 5s for result."""
|
||||
import uuid
|
||||
cmd_id = str(uuid.uuid4())[:8]
|
||||
cmd = body.get("cmd", "")
|
||||
args = body.get("args", {})
|
||||
if not cmd:
|
||||
raise HTTPException(400, "Missing 'cmd'")
|
||||
|
||||
if not _debug_queues:
|
||||
return {"cmd_id": cmd_id, "error": "no browser connected"}
|
||||
|
||||
evt = asyncio.Event()
|
||||
_debug_results[cmd_id] = evt
|
||||
# Push to all connected SSE clients
|
||||
payload = {"cmd_id": cmd_id, "cmd": cmd, "args": args}
|
||||
for q in _debug_queues:
|
||||
try:
|
||||
q.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
# Wait for nyx to execute and POST result back
|
||||
try:
|
||||
await asyncio.wait_for(evt.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
_debug_results.pop(cmd_id, None)
|
||||
return {"cmd_id": cmd_id, "error": "timeout — command took too long"}
|
||||
|
||||
result = _debug_result_data.pop(cmd_id, {})
|
||||
_debug_results.pop(cmd_id, None)
|
||||
return {"cmd_id": cmd_id, **result}
|
||||
|
||||
@app.post("/api/debug/result")
|
||||
async def debug_result(body: dict, user=Depends(require_auth)):
|
||||
"""nyx posts command results back here."""
|
||||
cmd_id = body.get("cmd_id", "")
|
||||
if not cmd_id:
|
||||
raise HTTPException(400, "Missing 'cmd_id'")
|
||||
_debug_result_data[cmd_id] = {"result": body.get("result"), "error": body.get("error")}
|
||||
evt = _debug_results.get(cmd_id)
|
||||
if evt:
|
||||
evt.set()
|
||||
return {"ok": True}
|
||||
|
||||
# ── Test results (fed by test-runner Job) ────────────────────────────────
|
||||
|
||||
@app.post("/api/test-results")
|
||||
async def post_test_result(request: Request):
|
||||
"""Receive a single test result from the test-runner Job. No auth (internal)."""
|
||||
global _test_run_id
|
||||
body = await request.json()
|
||||
run_id = body.get("run_id", "")
|
||||
|
||||
# New run → clear old results
|
||||
if run_id and run_id != _test_run_id:
|
||||
_test_results.clear()
|
||||
_test_run_id = run_id
|
||||
|
||||
# Replace existing entry for same test+suite, or append
|
||||
key = (body.get("test"), body.get("suite"))
|
||||
for i, existing in enumerate(_test_results):
|
||||
if (existing.get("test"), existing.get("suite")) == key:
|
||||
_test_results[i] = body
|
||||
break
|
||||
else:
|
||||
_test_results.append(body)
|
||||
|
||||
# Push to SSE subscribers
|
||||
for q in list(_test_results_subscribers):
|
||||
try:
|
||||
q.put_nowait(body)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/test-results")
|
||||
async def get_test_results_sse(request: Request):
|
||||
"""SSE stream: sends current results on connect, then live updates."""
|
||||
q: Queue = Queue(maxsize=100)
|
||||
_test_results_subscribers.append(q)
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
# Send all existing results first
|
||||
for r in list(_test_results):
|
||||
yield f"data: {json.dumps(r)}\n\n"
|
||||
# Then stream new ones
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
result = await asyncio.wait_for(q.get(), timeout=15)
|
||||
yield f"data: {json.dumps(result)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield f"event: heartbeat\ndata: {{}}\n\n"
|
||||
finally:
|
||||
_test_results_subscribers.remove(q)
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
@app.get("/api/test-results/latest")
|
||||
async def get_test_results_latest():
|
||||
"""JSON snapshot of the latest test run."""
|
||||
return {"run_id": _test_run_id, "results": _test_results}
|
||||
|
||||
@ -63,7 +63,6 @@ def _graph_from_module(mod) -> dict:
|
||||
"conditions": getattr(mod, "CONDITIONS", {}),
|
||||
"audit": getattr(mod, "AUDIT", {}),
|
||||
"engine": getattr(mod, "ENGINE", "imperative"),
|
||||
"models": getattr(mod, "MODELS", {}),
|
||||
}
|
||||
|
||||
|
||||
@ -80,13 +79,7 @@ def instantiate_nodes(graph: dict, send_hud, process_manager: ProcessManager = N
|
||||
nodes[role] = cls(send_hud=send_hud, process_manager=process_manager)
|
||||
else:
|
||||
nodes[role] = cls(send_hud=send_hud)
|
||||
# Apply model from graph config (overrides class default)
|
||||
model = graph.get("models", {}).get(role)
|
||||
if model and hasattr(nodes[role], "model"):
|
||||
nodes[role].model = model
|
||||
log.info(f"[engine] {role} = {impl_name} ({cls.__name__}) model={model}")
|
||||
else:
|
||||
log.info(f"[engine] {role} = {impl_name} ({cls.__name__})")
|
||||
log.info(f"[engine] {role} = {impl_name} ({cls.__name__})")
|
||||
return nodes
|
||||
|
||||
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
"""Frame Engine: edge-walking deterministic pipeline execution.
|
||||
"""Frame Engine: tick-based deterministic pipeline execution.
|
||||
|
||||
Walks the graph's data edges to determine pipeline flow. Each step
|
||||
dispatches a node and evaluates conditions on outgoing edges to pick
|
||||
the next node. Frames advance on node completion (not on a timer).
|
||||
Replaces the imperative handle_message() with a frame-stepping model:
|
||||
- Each frame dispatches all nodes that have pending input
|
||||
- Frames advance on completion (not on a timer)
|
||||
- 0ms when idle, engine awaits external input
|
||||
- Deterministic ordering: reflex=2 frames, thinker=3-4, interpreter=5
|
||||
|
||||
Edge types:
|
||||
data — typed objects flowing between nodes (walked by engine)
|
||||
context — text injected into LLM prompts (aggregated via _build_context)
|
||||
state — shared mutable state reads (consumed by sensor/runtime)
|
||||
|
||||
Works with any graph definition (v1-v4). Node implementations unchanged.
|
||||
Works with any graph definition (v1, v2, v3). Node implementations unchanged.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@ -19,7 +16,6 @@ import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
from .types import Envelope, Command, InputAnalysis, ThoughtResult, DirectorPlan, PARouting
|
||||
from .nodes.base import _current_hud
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
|
||||
@ -75,22 +71,8 @@ class FrameTrace:
|
||||
}
|
||||
|
||||
|
||||
# --- Walk context: carries state through the edge walk ---
|
||||
|
||||
@dataclass
|
||||
class _WalkCtx:
|
||||
"""Mutable context passed through the edge walker."""
|
||||
command: Command = None
|
||||
routing: PARouting = None
|
||||
plan: DirectorPlan = None
|
||||
thought: ThoughtResult = None
|
||||
mem_ctx: str = ""
|
||||
dashboard: list = None
|
||||
path_nodes: list = field(default_factory=list) # visited node names for path label
|
||||
|
||||
|
||||
class FrameEngine:
|
||||
"""Edge-walking engine that steps through graph nodes frame by frame."""
|
||||
"""Tick-based engine that steps through graph nodes frame by frame."""
|
||||
|
||||
def __init__(self, graph: dict, nodes: dict, sink, history: list,
|
||||
send_hud, sensor, memorizer, ui_node, identity: str = "unknown",
|
||||
@ -110,7 +92,7 @@ class FrameEngine:
|
||||
self.frame = 0
|
||||
self.bus = {}
|
||||
self.conditions = graph.get("conditions", {})
|
||||
self.data_edges = [e for e in graph.get("edges", []) if e.get("type") == "data"]
|
||||
self.edges = [e for e in graph.get("edges", []) if e.get("type") == "data"]
|
||||
|
||||
self.has_director = "director" in nodes and hasattr(nodes.get("director"), "decide")
|
||||
self.has_interpreter = "interpreter" in nodes
|
||||
@ -189,405 +171,61 @@ class FrameEngine:
|
||||
"trace": t.to_dict(),
|
||||
})
|
||||
|
||||
# --- Condition evaluation ---
|
||||
|
||||
def _eval_condition(self, name: str, ctx: _WalkCtx) -> bool:
|
||||
"""Evaluate a named condition against walk context."""
|
||||
if name == "reflex":
|
||||
return (ctx.command and ctx.command.analysis.intent == "social"
|
||||
and ctx.command.analysis.complexity == "trivial")
|
||||
if name == "has_tool_output":
|
||||
return bool(ctx.thought and ctx.thought.tool_used and ctx.thought.tool_output)
|
||||
# PA routing conditions
|
||||
if name == "expert_is_none":
|
||||
return ctx.routing is not None and ctx.routing.expert == "none"
|
||||
if name.startswith("expert_is_"):
|
||||
expert = name[len("expert_is_"):]
|
||||
return ctx.routing is not None and ctx.routing.expert == expert
|
||||
return False
|
||||
|
||||
def _check_condition(self, name: str, command: Command = None,
|
||||
thought: ThoughtResult = None) -> bool:
|
||||
"""Legacy wrapper for _eval_condition (used by tests)."""
|
||||
ctx = _WalkCtx(command=command, thought=thought)
|
||||
return self._eval_condition(name, ctx)
|
||||
|
||||
# --- Edge resolution ---
|
||||
|
||||
def _resolve_edge(self, outgoing: list, ctx: _WalkCtx) -> dict | None:
|
||||
"""Pick the active edge from a node's outgoing data edges.
|
||||
Conditional edges take priority when they match."""
|
||||
conditional = [e for e in outgoing if e.get("condition")]
|
||||
unconditional = [e for e in outgoing if not e.get("condition")]
|
||||
|
||||
for edge in conditional:
|
||||
if self._eval_condition(edge["condition"], ctx):
|
||||
return edge
|
||||
|
||||
return unconditional[0] if unconditional else None
|
||||
|
||||
# --- Node dispatch adapters ---
|
||||
|
||||
async def _dispatch_pa(self, ctx: _WalkCtx) -> str:
|
||||
"""Dispatch PA node. Returns route summary."""
|
||||
a = ctx.command.analysis
|
||||
rec = self._begin_frame(self.frame + 1, "pa",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}")
|
||||
routing = await self.nodes["pa"].route(
|
||||
ctx.command, self.history, memory_context=ctx.mem_ctx,
|
||||
identity=self.identity, channel=self.channel)
|
||||
ctx.routing = routing
|
||||
route_summary = f"expert={routing.expert} job={routing.job[:60]}"
|
||||
self._end_frame(rec, output_summary=route_summary,
|
||||
route=f"expert_{routing.expert}" if routing.expert != "none" else "output")
|
||||
|
||||
# Stream thinking message
|
||||
if routing.thinking_message:
|
||||
await self.sink.send_delta(routing.thinking_message + "\n\n")
|
||||
|
||||
return routing.expert
|
||||
|
||||
async def _dispatch_expert(self, expert_role: str, ctx: _WalkCtx):
|
||||
"""Dispatch an expert node with progress wrapping and PA retry."""
|
||||
expert = self._experts.get(expert_role.replace("expert_", ""))
|
||||
if not expert:
|
||||
expert_name = expert_role.replace("expert_", "")
|
||||
log.error(f"[frame] expert '{expert_name}' not found")
|
||||
ctx.thought = ThoughtResult(response=f"Expert '{expert_name}' not available.")
|
||||
return
|
||||
|
||||
expert_name = expert_role.replace("expert_", "")
|
||||
rec = self._begin_frame(self.frame + 1, expert_role,
|
||||
input_summary=f"job: {ctx.routing.job[:80]}")
|
||||
|
||||
# Wrap expert HUD for progress streaming
|
||||
original_hud = expert.send_hud
|
||||
expert.send_hud = self._make_progress_wrapper(original_hud, ctx.routing.language)
|
||||
try:
|
||||
thought = await expert.execute(ctx.routing.job, ctx.routing.language)
|
||||
finally:
|
||||
expert.send_hud = original_hud
|
||||
|
||||
ctx.thought = thought
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"actions={len(thought.actions)} errors={len(thought.errors)}")
|
||||
has_tool = bool(thought.tool_used and thought.tool_output)
|
||||
|
||||
# --- PA retry: expert failed or skipped tools ---
|
||||
expectation = self.memorizer.state.get("user_expectation", "conversational")
|
||||
job_needs_data = any(k in (ctx.routing.job or "").lower()
|
||||
for k in ["query", "select", "tabelle", "table", "daten", "data",
|
||||
"cost", "kosten", "count", "anzahl", "average", "schnitt",
|
||||
"find", "finde", "show", "zeig", "list", "beschreib"])
|
||||
expert_skipped_tools = not has_tool and not thought.errors and job_needs_data
|
||||
if (thought.errors or expert_skipped_tools) and not has_tool and expectation in ("delegated", "waiting_input", "conversational"):
|
||||
retry_reason = f"{len(thought.errors)} errors" if thought.errors else "no tool calls for data job"
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="pa_retry", condition=f"expert_failed ({retry_reason}), expectation={expectation}")
|
||||
await self._send_hud({"node": "runtime", "event": "pa_retry",
|
||||
"detail": f"expert failed: {retry_reason}, retrying via PA"})
|
||||
|
||||
retry_msg = "Anderer Ansatz..." if ctx.routing.language == "de" else "Trying a different approach..."
|
||||
await self.sink.send_delta(retry_msg + "\n")
|
||||
|
||||
retry_errors = thought.errors if thought.errors else [
|
||||
{"query": "(none)", "error": "Expert produced no database queries. The job requires data lookup but the expert answered without querying. Reformulate with explicit query instructions."}
|
||||
]
|
||||
error_summary = "; ".join(e.get("error", "")[:80] for e in retry_errors[-2:])
|
||||
rec = self._begin_frame(self.frame + 1, "pa_retry",
|
||||
input_summary=f"errors: {error_summary[:100]}")
|
||||
routing2 = await self.nodes["pa"].route_retry(
|
||||
ctx.command, self.history, memory_context=ctx.mem_ctx,
|
||||
identity=self.identity, channel=self.channel,
|
||||
original_job=ctx.routing.job, errors=retry_errors)
|
||||
self._end_frame(rec, output_summary=f"retry_job: {(routing2.job or '')[:60]}",
|
||||
route=f"expert_{routing2.expert}" if routing2.expert != "none" else "output")
|
||||
|
||||
if routing2.expert != "none":
|
||||
expert2 = self._experts.get(routing2.expert, expert)
|
||||
rec = self._begin_frame(self.frame + 1, f"expert_{routing2.expert}_retry",
|
||||
input_summary=f"retry job: {(routing2.job or '')[:80]}")
|
||||
original_hud2 = expert2.send_hud
|
||||
expert2.send_hud = self._make_progress_wrapper(original_hud2, routing2.language)
|
||||
try:
|
||||
thought = await expert2.execute(routing2.job, routing2.language)
|
||||
finally:
|
||||
expert2.send_hud = original_hud2
|
||||
ctx.thought = thought
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"errors={len(thought.errors)}")
|
||||
has_tool = bool(thought.tool_used and thought.tool_output)
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="interpreter" if has_tool else "output+ui")
|
||||
ctx.routing = routing2
|
||||
return
|
||||
|
||||
# Normal completion (no retry)
|
||||
# Don't end frame yet — caller checks interpreter condition and sets route
|
||||
|
||||
async def _dispatch_director(self, ctx: _WalkCtx):
|
||||
"""Dispatch Director node."""
|
||||
a = ctx.command.analysis
|
||||
rec = self._begin_frame(self.frame + 1, "director",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}")
|
||||
plan = await self.nodes["director"].decide(ctx.command, self.history, memory_context=ctx.mem_ctx)
|
||||
ctx.plan = plan
|
||||
plan_summary = f"goal={plan.goal} tools={len(plan.tool_sequence)} hint={plan.response_hint[:50]}"
|
||||
self._end_frame(rec, output_summary=plan_summary, route="thinker")
|
||||
|
||||
async def _dispatch_thinker(self, ctx: _WalkCtx):
|
||||
"""Dispatch Thinker node (v1 or v2)."""
|
||||
a = ctx.command.analysis
|
||||
rec = self._begin_frame(self.frame + 1, "thinker",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}" if not ctx.plan
|
||||
else f"goal={ctx.plan.goal} tools={len(ctx.plan.tool_sequence)}")
|
||||
|
||||
# v1 hybrid: optional director pre-planning
|
||||
director = self.nodes.get("director")
|
||||
if director and hasattr(director, "plan") and not ctx.plan:
|
||||
is_complex = ctx.command.analysis.complexity == "complex"
|
||||
text = ctx.command.source_text
|
||||
is_data_request = (ctx.command.analysis.intent in ("request", "action")
|
||||
and any(k in text.lower()
|
||||
for k in ["daten", "data", "database", "db", "tabelle", "table",
|
||||
"query", "abfrage", "untersuche", "investigate",
|
||||
"analyse", "analyze", "customer", "kunde"]))
|
||||
if is_complex or (is_data_request and len(text.split()) > 8):
|
||||
await director.plan(self.history, self.memorizer.state, text)
|
||||
ctx.mem_ctx = self._build_context(ctx.dashboard)
|
||||
|
||||
if ctx.plan:
|
||||
thought = await self.nodes["thinker"].process(
|
||||
ctx.command, ctx.plan, self.history, memory_context=ctx.mem_ctx)
|
||||
else:
|
||||
thought = await self.nodes["thinker"].process(
|
||||
ctx.command, self.history, memory_context=ctx.mem_ctx)
|
||||
|
||||
if director and hasattr(director, "current_plan"):
|
||||
director.current_plan = ""
|
||||
|
||||
ctx.thought = thought
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"actions={len(thought.actions)}")
|
||||
self._end_frame(rec, output_summary=thought_summary, route="output+ui")
|
||||
|
||||
async def _dispatch_interpreter(self, ctx: _WalkCtx):
|
||||
"""Dispatch Interpreter node."""
|
||||
rec = self._begin_frame(self.frame + 1, "interpreter",
|
||||
input_summary=f"tool={ctx.thought.tool_used} output[{len(ctx.thought.tool_output)}]")
|
||||
# Use routing.job for expert pipeline, source_text for director pipeline
|
||||
job = ctx.routing.job if ctx.routing else ctx.command.source_text
|
||||
interpreted = await self.nodes["interpreter"].interpret(
|
||||
ctx.thought.tool_used, ctx.thought.tool_output, job)
|
||||
ctx.thought.response = interpreted.summary
|
||||
self._end_frame(rec, output_summary=f"summary[{len(interpreted.summary)}]", route="output+ui")
|
||||
|
||||
async def _finish_pipeline(self, ctx: _WalkCtx) -> dict:
|
||||
"""Common tail: output+ui parallel, memorizer update, trace."""
|
||||
# If no thought yet (pa_direct path), create from routing
|
||||
if not ctx.thought and ctx.routing:
|
||||
ctx.thought = ThoughtResult(response=ctx.routing.response_hint, actions=[])
|
||||
|
||||
rec = self._begin_frame(self.frame + 1, "output+ui",
|
||||
input_summary=f"response: {(ctx.thought.response or '')[:80]}")
|
||||
|
||||
self.sink.reset()
|
||||
response = await self._run_output_and_ui(ctx.thought, ctx.mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
await self.memorizer.update(self.history)
|
||||
|
||||
# v1 director post-processing
|
||||
director = self.nodes.get("director")
|
||||
if director and hasattr(director, "update") and not self.has_pa:
|
||||
await director.update(self.history, self.memorizer.state)
|
||||
|
||||
self._trim_history()
|
||||
|
||||
controls_count = len(self.ui_node.current_controls)
|
||||
self._end_frame(rec, output_summary=f"response[{len(response)}] controls={controls_count}")
|
||||
|
||||
# Build path label from visited nodes
|
||||
path = self._build_path_label(ctx)
|
||||
self._end_trace(path)
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
def _build_path_label(self, ctx: _WalkCtx) -> str:
|
||||
"""Build trace path label from visited nodes."""
|
||||
nodes = ctx.path_nodes
|
||||
if not nodes:
|
||||
return "unknown"
|
||||
# Map visited nodes to path labels
|
||||
has_interpreter = "interpreter" in nodes
|
||||
if any(n.startswith("expert_") for n in nodes):
|
||||
return "expert+interpreter" if has_interpreter else "expert"
|
||||
if "director" in nodes:
|
||||
return "director+interpreter" if has_interpreter else "director"
|
||||
if "thinker" in nodes:
|
||||
return "thinker"
|
||||
if "pa" in nodes and not any(n.startswith("expert_") for n in nodes):
|
||||
return "pa_direct"
|
||||
return "unknown"
|
||||
|
||||
# --- Edge walker ---
|
||||
|
||||
async def _walk_edges(self, ctx: _WalkCtx) -> dict:
|
||||
"""Walk data edges from input node through the graph.
|
||||
Returns the pipeline result dict."""
|
||||
current = "input" # just finished frame 1
|
||||
|
||||
while True:
|
||||
# Find outgoing data edges from current node
|
||||
outgoing = [e for e in self.data_edges if e["from"] == current]
|
||||
if not outgoing:
|
||||
break
|
||||
|
||||
# Resolve which edge to follow
|
||||
edge = self._resolve_edge(outgoing, ctx)
|
||||
if not edge:
|
||||
break
|
||||
|
||||
target = edge["to"]
|
||||
|
||||
# Parallel target [output, ui] or terminal output → finish
|
||||
if isinstance(target, list) or target == "output" or target == "memorizer":
|
||||
return await self._finish_pipeline(ctx)
|
||||
|
||||
# Dispatch the target node
|
||||
ctx.path_nodes.append(target)
|
||||
|
||||
if target == "pa":
|
||||
await self._dispatch_pa(ctx)
|
||||
current = "pa"
|
||||
|
||||
elif target.startswith("expert_"):
|
||||
await self._dispatch_expert(target, ctx)
|
||||
# After expert, check interpreter condition
|
||||
has_tool = bool(ctx.thought and ctx.thought.tool_used and ctx.thought.tool_output)
|
||||
if self.has_interpreter and has_tool:
|
||||
# End expert frame with interpreter route
|
||||
last_rec = self.last_trace.frames[-1]
|
||||
if not last_rec.route: # not already ended by retry
|
||||
self._end_frame(last_rec,
|
||||
output_summary=f"response[{len(ctx.thought.response)}] tool={ctx.thought.tool_used}",
|
||||
route="interpreter", condition="has_tool_output=True")
|
||||
ctx.path_nodes.append("interpreter")
|
||||
await self._dispatch_interpreter(ctx)
|
||||
else:
|
||||
# End expert frame with output route
|
||||
last_rec = self.last_trace.frames[-1]
|
||||
if not last_rec.route:
|
||||
thought_summary = (f"response[{len(ctx.thought.response)}] tool={ctx.thought.tool_used or 'none'}")
|
||||
self._end_frame(last_rec, output_summary=thought_summary,
|
||||
route="output+ui",
|
||||
condition="has_tool_output=False" if not has_tool else "")
|
||||
return await self._finish_pipeline(ctx)
|
||||
|
||||
elif target == "director":
|
||||
await self._dispatch_director(ctx)
|
||||
current = "director"
|
||||
|
||||
elif target == "thinker":
|
||||
await self._dispatch_thinker(ctx)
|
||||
# After thinker, check interpreter condition
|
||||
has_tool = bool(ctx.thought and ctx.thought.tool_used and ctx.thought.tool_output)
|
||||
if self.has_interpreter and has_tool:
|
||||
last_rec = self.last_trace.frames[-1]
|
||||
self._end_frame(last_rec,
|
||||
output_summary=f"response[{len(ctx.thought.response)}] tool={ctx.thought.tool_used}",
|
||||
route="interpreter", condition="has_tool_output=True")
|
||||
ctx.path_nodes.append("interpreter")
|
||||
await self._dispatch_interpreter(ctx)
|
||||
return await self._finish_pipeline(ctx)
|
||||
|
||||
elif target == "interpreter":
|
||||
ctx.path_nodes.append("interpreter")
|
||||
await self._dispatch_interpreter(ctx)
|
||||
return await self._finish_pipeline(ctx)
|
||||
|
||||
else:
|
||||
log.warning(f"[frame] unknown target node: {target}")
|
||||
break
|
||||
|
||||
return await self._finish_pipeline(ctx)
|
||||
|
||||
# --- Main entry point ---
|
||||
|
||||
async def process_message(self, text: str, dashboard: list = None,
|
||||
model_overrides: dict = None) -> dict:
|
||||
async def process_message(self, text: str, dashboard: list = None) -> dict:
|
||||
"""Process a message through the frame pipeline.
|
||||
Returns {response, controls, memorizer, frames, trace}.
|
||||
Returns {response, controls, memorizer, frames, trace}."""
|
||||
|
||||
model_overrides: optional {role: model} to override node models for this request only.
|
||||
"""
|
||||
# Apply per-request model overrides (restored after processing)
|
||||
saved_models = {}
|
||||
if model_overrides:
|
||||
for role, model in model_overrides.items():
|
||||
node = self.nodes.get(role)
|
||||
if node and hasattr(node, "model"):
|
||||
saved_models[role] = node.model
|
||||
node.model = model
|
||||
self._begin_trace(text)
|
||||
|
||||
# Set session-scoped HUD for shared nodes (contextvar, per-task)
|
||||
_current_hud.set(self._send_hud)
|
||||
# Handle ACTION: prefix
|
||||
if text.startswith("ACTION:"):
|
||||
return await self._handle_action(text, dashboard)
|
||||
|
||||
try:
|
||||
self._begin_trace(text)
|
||||
# Setup
|
||||
envelope = Envelope(
|
||||
text=text, user_id=self.identity,
|
||||
session_id="test", timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
self.sensor.note_user_activity()
|
||||
if dashboard is not None:
|
||||
self.sensor.update_browser_dashboard(dashboard)
|
||||
self.history.append({"role": "user", "content": text})
|
||||
|
||||
# Handle ACTION: prefix
|
||||
if text.startswith("ACTION:"):
|
||||
return await self._handle_action(text, dashboard)
|
||||
# --- Frame 1: Input ---
|
||||
mem_ctx = self._build_context(dashboard)
|
||||
rec = self._begin_frame(1, "input", input_summary=text[:100])
|
||||
|
||||
# Setup
|
||||
envelope = Envelope(
|
||||
text=text, user_id=self.identity,
|
||||
session_id="test", timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
self.sensor.note_user_activity()
|
||||
if dashboard is not None:
|
||||
self.sensor.update_browser_dashboard(dashboard)
|
||||
self.history.append({"role": "user", "content": text})
|
||||
command = await self.nodes["input"].process(
|
||||
envelope, self.history, memory_context=mem_ctx,
|
||||
identity=self.identity, channel=self.channel)
|
||||
|
||||
# --- Frame 1: Input ---
|
||||
mem_ctx = self._build_context(dashboard)
|
||||
rec = self._begin_frame(1, "input", input_summary=text[:100])
|
||||
a = command.analysis
|
||||
cmd_summary = f"intent={a.intent} language={a.language} tone={a.tone} complexity={a.complexity}"
|
||||
|
||||
command = await self.nodes["input"].process(
|
||||
envelope, self.history, memory_context=mem_ctx,
|
||||
identity=self.identity, channel=self.channel)
|
||||
|
||||
a = command.analysis
|
||||
cmd_summary = f"intent={a.intent} language={a.language} tone={a.tone} complexity={a.complexity}"
|
||||
|
||||
# Build walk context
|
||||
ctx = _WalkCtx(command=command, mem_ctx=mem_ctx, dashboard=dashboard)
|
||||
|
||||
# Check reflex condition
|
||||
is_reflex = self._eval_condition("reflex", ctx)
|
||||
if is_reflex:
|
||||
self._end_frame(rec, output_summary=cmd_summary,
|
||||
route="output (reflex)", condition="reflex=True")
|
||||
await self._send_hud({"node": "runtime", "event": "reflex_path",
|
||||
"detail": f"{a.intent}/{a.complexity}"})
|
||||
return await self._run_reflex(command, mem_ctx)
|
||||
|
||||
# Find next node from edges
|
||||
outgoing = [e for e in self.data_edges if e["from"] == "input" and not e.get("condition")]
|
||||
next_node = outgoing[0]["to"] if outgoing else "unknown"
|
||||
# Check reflex condition
|
||||
is_reflex = self._check_condition("reflex", command=command)
|
||||
if is_reflex:
|
||||
self._end_frame(rec, output_summary=cmd_summary,
|
||||
route=next_node, condition="reflex=False")
|
||||
route="output (reflex)", condition="reflex=True")
|
||||
await self._send_hud({"node": "runtime", "event": "reflex_path",
|
||||
"detail": f"{a.intent}/{a.complexity}"})
|
||||
return await self._run_reflex(command, mem_ctx)
|
||||
else:
|
||||
next_node = "pa" if self.has_pa else ("director" if self.has_director else "thinker")
|
||||
self._end_frame(rec, output_summary=cmd_summary,
|
||||
route=next_node, condition=f"reflex=False")
|
||||
|
||||
# Walk remaining edges
|
||||
return await self._walk_edges(ctx)
|
||||
finally:
|
||||
# Restore original models after per-request overrides
|
||||
for role, original_model in saved_models.items():
|
||||
node = self.nodes.get(role)
|
||||
if node:
|
||||
node.model = original_model
|
||||
# --- Frame 2+: Pipeline ---
|
||||
if self.has_pa:
|
||||
return await self._run_expert_pipeline(command, mem_ctx, dashboard)
|
||||
elif self.has_director:
|
||||
return await self._run_director_pipeline(command, mem_ctx, dashboard)
|
||||
else:
|
||||
return await self._run_thinker_pipeline(command, mem_ctx, dashboard)
|
||||
|
||||
# --- Reflex (kept simple — 2 frames, no edge walking needed) ---
|
||||
# --- Pipeline variants ---
|
||||
|
||||
async def _run_reflex(self, command: Command, mem_ctx: str) -> dict:
|
||||
"""Reflex: Input(F1) → Output(F2)."""
|
||||
@ -605,7 +243,279 @@ class FrameEngine:
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
# --- Action handling ---
|
||||
async def _run_expert_pipeline(self, command: Command, mem_ctx: str,
|
||||
dashboard: list = None) -> dict:
|
||||
"""Expert pipeline: Input(F1) → PA(F2) → Expert(F3) → [Interpreter(F4)] → Output."""
|
||||
a = command.analysis
|
||||
|
||||
# Frame 2: PA routes
|
||||
rec = self._begin_frame(2, "pa",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}")
|
||||
routing = await self.nodes["pa"].route(
|
||||
command, self.history, memory_context=mem_ctx,
|
||||
identity=self.identity, channel=self.channel)
|
||||
route_summary = f"expert={routing.expert} job={routing.job[:60]}"
|
||||
self._end_frame(rec, output_summary=route_summary,
|
||||
route=f"expert_{routing.expert}" if routing.expert != "none" else "output")
|
||||
|
||||
# Stream thinking message to user
|
||||
if routing.thinking_message:
|
||||
await self.sink.send_delta(routing.thinking_message + "\n\n")
|
||||
|
||||
# Direct PA response (no expert needed)
|
||||
if routing.expert == "none":
|
||||
rec = self._begin_frame(3, "output+ui",
|
||||
input_summary=f"pa_direct: {routing.response_hint[:80]}")
|
||||
thought = ThoughtResult(response=routing.response_hint, actions=[])
|
||||
response = await self._run_output_and_ui(thought, mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
await self.memorizer.update(self.history)
|
||||
self._trim_history()
|
||||
self._end_frame(rec, output_summary=f"response[{len(response)}]")
|
||||
self._end_trace("pa_direct")
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
# Frame 3: Expert executes
|
||||
expert = self._experts.get(routing.expert)
|
||||
if not expert:
|
||||
log.error(f"[frame] expert '{routing.expert}' not found")
|
||||
thought = ThoughtResult(response=f"Expert '{routing.expert}' not available.")
|
||||
rec = self._begin_frame(3, "output+ui", input_summary="expert_not_found")
|
||||
response = await self._run_output_and_ui(thought, mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
self._end_frame(rec, output_summary="error", error=f"expert '{routing.expert}' not found")
|
||||
self._end_trace("expert_error")
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
rec = self._begin_frame(3, f"expert_{routing.expert}",
|
||||
input_summary=f"job: {routing.job[:80]}")
|
||||
|
||||
# Wrap expert's HUD to stream progress to user
|
||||
original_hud = expert.send_hud
|
||||
expert.send_hud = self._make_progress_wrapper(original_hud, routing.language)
|
||||
|
||||
try:
|
||||
thought = await expert.execute(routing.job, routing.language)
|
||||
finally:
|
||||
expert.send_hud = original_hud
|
||||
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"actions={len(thought.actions)} errors={len(thought.errors)}")
|
||||
has_tool = bool(thought.tool_used and thought.tool_output)
|
||||
|
||||
# PA retry: if expert failed OR skipped tools when data was needed
|
||||
expectation = self.memorizer.state.get("user_expectation", "conversational")
|
||||
# Detect hallucination: expert returned no tool output for a data job
|
||||
job_needs_data = any(k in (routing.job or "").lower()
|
||||
for k in ["query", "select", "tabelle", "table", "daten", "data",
|
||||
"cost", "kosten", "count", "anzahl", "average", "schnitt",
|
||||
"find", "finde", "show", "zeig", "list", "beschreib"])
|
||||
expert_skipped_tools = not has_tool and not thought.errors and job_needs_data
|
||||
if (thought.errors or expert_skipped_tools) and not has_tool and expectation in ("delegated", "waiting_input", "conversational"):
|
||||
retry_reason = f"{len(thought.errors)} errors" if thought.errors else "no tool calls for data job"
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="pa_retry", condition=f"expert_failed ({retry_reason}), expectation={expectation}")
|
||||
await self._send_hud({"node": "runtime", "event": "pa_retry",
|
||||
"detail": f"expert failed: {retry_reason}, retrying via PA"})
|
||||
|
||||
# Stream retry notice to user
|
||||
retry_msg = "Anderer Ansatz..." if routing.language == "de" else "Trying a different approach..."
|
||||
await self.sink.send_delta(retry_msg + "\n")
|
||||
|
||||
# PA reformulates with error context
|
||||
retry_errors = thought.errors if thought.errors else [
|
||||
{"query": "(none)", "error": "Expert produced no database queries. The job requires data lookup but the expert answered without querying. Reformulate with explicit query instructions."}
|
||||
]
|
||||
error_summary = "; ".join(e.get("error", "")[:80] for e in retry_errors[-2:])
|
||||
rec = self._begin_frame(self.frame + 1, "pa_retry",
|
||||
input_summary=f"errors: {error_summary[:100]}")
|
||||
routing2 = await self.nodes["pa"].route_retry(
|
||||
command, self.history, memory_context=mem_ctx,
|
||||
identity=self.identity, channel=self.channel,
|
||||
original_job=routing.job, errors=retry_errors)
|
||||
self._end_frame(rec, output_summary=f"retry_job: {(routing2.job or '')[:60]}",
|
||||
route=f"expert_{routing2.expert}" if routing2.expert != "none" else "output")
|
||||
|
||||
if routing2.expert != "none":
|
||||
expert2 = self._experts.get(routing2.expert, expert)
|
||||
rec = self._begin_frame(self.frame + 1, f"expert_{routing2.expert}_retry",
|
||||
input_summary=f"retry job: {(routing2.job or '')[:80]}")
|
||||
original_hud2 = expert2.send_hud
|
||||
expert2.send_hud = self._make_progress_wrapper(original_hud2, routing2.language)
|
||||
try:
|
||||
thought = await expert2.execute(routing2.job, routing2.language)
|
||||
finally:
|
||||
expert2.send_hud = original_hud2
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"errors={len(thought.errors)}")
|
||||
has_tool = bool(thought.tool_used and thought.tool_output)
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="interpreter" if has_tool else "output+ui")
|
||||
routing = routing2 # use retry routing for rest of pipeline
|
||||
|
||||
# Interpreter (conditional)
|
||||
if self.has_interpreter and has_tool:
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="interpreter", condition="has_tool_output=True")
|
||||
rec = self._begin_frame(4, "interpreter",
|
||||
input_summary=f"tool={thought.tool_used} output[{len(thought.tool_output)}]")
|
||||
interpreted = await self.nodes["interpreter"].interpret(
|
||||
thought.tool_used, thought.tool_output, routing.job)
|
||||
thought.response = interpreted.summary
|
||||
self._end_frame(rec, output_summary=f"summary[{len(interpreted.summary)}]", route="output+ui")
|
||||
|
||||
rec = self._begin_frame(5, "output+ui",
|
||||
input_summary=f"interpreted: {interpreted.summary[:80]}")
|
||||
path = "expert+interpreter"
|
||||
else:
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="output+ui",
|
||||
condition="has_tool_output=False" if not has_tool else "")
|
||||
rec = self._begin_frame(4, "output+ui",
|
||||
input_summary=f"response: {thought.response[:80]}")
|
||||
path = "expert"
|
||||
|
||||
# Clear progress text, render final response
|
||||
self.sink.reset()
|
||||
response = await self._run_output_and_ui(thought, mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
await self.memorizer.update(self.history)
|
||||
self._trim_history()
|
||||
|
||||
controls_count = len(self.ui_node.current_controls)
|
||||
self._end_frame(rec, output_summary=f"response[{len(response)}] controls={controls_count}")
|
||||
self._end_trace(path)
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
def _make_progress_wrapper(self, original_hud, language: str):
|
||||
"""Wrap an expert's send_hud to also stream progress deltas to the user."""
|
||||
sink = self.sink
|
||||
progress_map = {
|
||||
"tool_call": {"query_db": "Daten werden abgerufen..." if language == "de" else "Fetching data...",
|
||||
"emit_actions": "UI wird erstellt..." if language == "de" else "Building UI...",
|
||||
"create_machine": "Maschine wird erstellt..." if language == "de" else "Creating machine...",
|
||||
"_default": "Verarbeite..." if language == "de" else "Processing..."},
|
||||
"tool_result": {"_default": ""}, # silent on result
|
||||
"planned": {"_default": "Plan erstellt..." if language == "de" else "Plan ready..."},
|
||||
}
|
||||
|
||||
async def wrapper(data: dict):
|
||||
await original_hud(data)
|
||||
event = data.get("event", "")
|
||||
if event in progress_map:
|
||||
tool = data.get("tool", "_default")
|
||||
msg = progress_map[event].get(tool, progress_map[event].get("_default", ""))
|
||||
if msg:
|
||||
await sink.send_delta(msg + "\n")
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _run_director_pipeline(self, command: Command, mem_ctx: str,
|
||||
dashboard: list = None) -> dict:
|
||||
"""Director: Input(F1) → Director(F2) → Thinker(F3) → [Interpreter(F4)] → Output."""
|
||||
a = command.analysis
|
||||
|
||||
# Frame 2: Director
|
||||
rec = self._begin_frame(2, "director",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}")
|
||||
plan = await self.nodes["director"].decide(command, self.history, memory_context=mem_ctx)
|
||||
plan_summary = f"goal={plan.goal} tools={len(plan.tool_sequence)} hint={plan.response_hint[:50]}"
|
||||
self._end_frame(rec, output_summary=plan_summary, route="thinker")
|
||||
|
||||
# Frame 3: Thinker
|
||||
rec = self._begin_frame(3, "thinker",
|
||||
input_summary=plan_summary[:100])
|
||||
thought = await self.nodes["thinker"].process(
|
||||
command, plan, self.history, memory_context=mem_ctx)
|
||||
thought_summary = (f"response[{len(thought.response)}] tool={thought.tool_used or 'none'} "
|
||||
f"actions={len(thought.actions)} machines={len(thought.machine_ops)}")
|
||||
has_tool = bool(thought.tool_used and thought.tool_output)
|
||||
|
||||
# Check interpreter condition
|
||||
if self.has_interpreter and has_tool:
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="interpreter", condition="has_tool_output=True")
|
||||
|
||||
# Frame 4: Interpreter
|
||||
rec = self._begin_frame(4, "interpreter",
|
||||
input_summary=f"tool={thought.tool_used} output[{len(thought.tool_output)}]")
|
||||
interpreted = await self.nodes["interpreter"].interpret(
|
||||
thought.tool_used, thought.tool_output, command.source_text)
|
||||
thought.response = interpreted.summary
|
||||
interp_summary = f"summary[{len(interpreted.summary)}] facts={interpreted.key_facts}"
|
||||
self._end_frame(rec, output_summary=interp_summary, route="output+ui")
|
||||
|
||||
# Frame 5: Output
|
||||
rec = self._begin_frame(5, "output+ui",
|
||||
input_summary=f"interpreted: {interpreted.summary[:80]}")
|
||||
path = "director+interpreter"
|
||||
else:
|
||||
self._end_frame(rec, output_summary=thought_summary,
|
||||
route="output+ui",
|
||||
condition="has_tool_output=False" if not has_tool else "")
|
||||
|
||||
# Frame 4: Output
|
||||
rec = self._begin_frame(4, "output+ui",
|
||||
input_summary=f"response: {thought.response[:80]}")
|
||||
path = "director"
|
||||
|
||||
response = await self._run_output_and_ui(thought, mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
await self.memorizer.update(self.history)
|
||||
self._trim_history()
|
||||
|
||||
controls_count = len(self.ui_node.current_controls)
|
||||
self._end_frame(rec, output_summary=f"response[{len(response)}] controls={controls_count}")
|
||||
self._end_trace(path)
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
async def _run_thinker_pipeline(self, command: Command, mem_ctx: str,
|
||||
dashboard: list = None) -> dict:
|
||||
"""v1: Input(F1) → Thinker(F2) → Output(F3)."""
|
||||
a = command.analysis
|
||||
|
||||
# Frame 2: Thinker
|
||||
rec = self._begin_frame(2, "thinker",
|
||||
input_summary=f"intent={a.intent} topic={a.topic}")
|
||||
|
||||
director = self.nodes.get("director")
|
||||
if director and hasattr(director, "plan"):
|
||||
is_complex = command.analysis.complexity == "complex"
|
||||
text = command.source_text
|
||||
is_data_request = (command.analysis.intent in ("request", "action")
|
||||
and any(k in text.lower()
|
||||
for k in ["daten", "data", "database", "db", "tabelle", "table",
|
||||
"query", "abfrage", "untersuche", "investigate",
|
||||
"analyse", "analyze", "customer", "kunde"]))
|
||||
if is_complex or (is_data_request and len(text.split()) > 8):
|
||||
await director.plan(self.history, self.memorizer.state, text)
|
||||
mem_ctx = self._build_context(dashboard)
|
||||
|
||||
thought = await self.nodes["thinker"].process(command, self.history, memory_context=mem_ctx)
|
||||
if director and hasattr(director, "current_plan"):
|
||||
director.current_plan = ""
|
||||
|
||||
thought_summary = f"response[{len(thought.response)}] tool={thought.tool_used or 'none'}"
|
||||
self._end_frame(rec, output_summary=thought_summary, route="output+ui")
|
||||
|
||||
# Frame 3: Output
|
||||
rec = self._begin_frame(3, "output+ui",
|
||||
input_summary=f"response: {thought.response[:80]}")
|
||||
response = await self._run_output_and_ui(thought, mem_ctx)
|
||||
self.history.append({"role": "assistant", "content": response})
|
||||
await self.memorizer.update(self.history)
|
||||
if director and hasattr(director, "update"):
|
||||
await director.update(self.history, self.memorizer.state)
|
||||
self._trim_history()
|
||||
|
||||
self._end_frame(rec, output_summary=f"response[{len(response)}]")
|
||||
self._end_trace("thinker")
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(response)
|
||||
|
||||
async def _handle_action(self, text: str, dashboard: list = None) -> dict:
|
||||
"""Handle ACTION: messages (button clicks)."""
|
||||
@ -662,8 +572,8 @@ class FrameEngine:
|
||||
await self._emit_trace_hud()
|
||||
return self._make_result(result)
|
||||
|
||||
# Complex action — needs full pipeline via edge walking
|
||||
self._end_frame(rec, output_summary="no local handler", route="edge_walk")
|
||||
# Complex action — needs full pipeline
|
||||
self._end_frame(rec, output_summary="no local handler", route="pa/director/thinker")
|
||||
|
||||
action_desc = f"ACTION: {action}"
|
||||
if data:
|
||||
@ -675,8 +585,12 @@ class FrameEngine:
|
||||
analysis=InputAnalysis(intent="action", topic=action, complexity="simple"),
|
||||
source_text=action_desc)
|
||||
|
||||
ctx = _WalkCtx(command=command, mem_ctx=mem_ctx, dashboard=dashboard)
|
||||
return await self._walk_edges(ctx)
|
||||
if self.has_pa:
|
||||
return await self._run_expert_pipeline(command, mem_ctx, dashboard)
|
||||
elif self.has_director:
|
||||
return await self._run_director_pipeline(command, mem_ctx, dashboard)
|
||||
else:
|
||||
return await self._run_thinker_pipeline(command, mem_ctx, dashboard)
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
@ -733,29 +647,6 @@ class FrameEngine:
|
||||
lines.append(f" - {ctype}: {ctrl.get('label', ctrl.get('text', '?'))}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _make_progress_wrapper(self, original_hud, language: str):
|
||||
"""Wrap an expert's send_hud to also stream progress deltas to the user."""
|
||||
sink = self.sink
|
||||
progress_map = {
|
||||
"tool_call": {"query_db": "Daten werden abgerufen..." if language == "de" else "Fetching data...",
|
||||
"emit_actions": "UI wird erstellt..." if language == "de" else "Building UI...",
|
||||
"create_machine": "Maschine wird erstellt..." if language == "de" else "Creating machine...",
|
||||
"_default": "Verarbeite..." if language == "de" else "Processing..."},
|
||||
"tool_result": {"_default": ""}, # silent on result
|
||||
"planned": {"_default": "Plan erstellt..." if language == "de" else "Plan ready..."},
|
||||
}
|
||||
|
||||
async def wrapper(data: dict):
|
||||
await original_hud(data)
|
||||
event = data.get("event", "")
|
||||
if event in progress_map:
|
||||
tool = data.get("tool", "_default")
|
||||
msg = progress_map[event].get(tool, progress_map[event].get("_default", ""))
|
||||
if msg:
|
||||
await sink.send_delta(msg + "\n")
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _run_output_and_ui(self, thought: ThoughtResult, mem_ctx: str) -> str:
|
||||
"""Run Output and UI nodes in parallel. Returns response text."""
|
||||
self.sink.reset()
|
||||
@ -772,6 +663,16 @@ class FrameEngine:
|
||||
await self.sink.send_artifacts(artifacts)
|
||||
return response
|
||||
|
||||
def _check_condition(self, name: str, command: Command = None,
|
||||
thought: ThoughtResult = None) -> bool:
|
||||
"""Evaluate a named condition."""
|
||||
if name == "reflex" and command:
|
||||
return (command.analysis.intent == "social"
|
||||
and command.analysis.complexity == "trivial")
|
||||
if name == "has_tool_output" and thought:
|
||||
return bool(thought.tool_used and thought.tool_output)
|
||||
return False
|
||||
|
||||
def _make_result(self, response: str) -> dict:
|
||||
"""Build the result dict returned to callers."""
|
||||
return {
|
||||
|
||||
@ -52,14 +52,6 @@ CONDITIONS = {
|
||||
"plan_first": "complexity==complex OR is_data_request",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"thinker": "openai/gpt-4o-mini",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
"director": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
AUDIT = {
|
||||
"code_without_tools": True,
|
||||
"intent_without_action": True,
|
||||
|
||||
@ -61,14 +61,5 @@ CONDITIONS = {
|
||||
"has_tool_output": "thinker.tool_used is not empty",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"director": "anthropic/claude-haiku-4.5",
|
||||
"thinker": "google/gemini-2.0-flash-001",
|
||||
"interpreter": "google/gemini-2.0-flash-001",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
# No audits — Director controls tool usage, no need for S3* corrections
|
||||
AUDIT = {}
|
||||
|
||||
@ -62,13 +62,4 @@ CONDITIONS = {
|
||||
"has_tool_output": "thinker.tool_used is not empty",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"director": "anthropic/claude-haiku-4.5",
|
||||
"thinker": "google/gemini-2.0-flash-001",
|
||||
"interpreter": "google/gemini-2.0-flash-001",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
AUDIT = {}
|
||||
|
||||
@ -68,13 +68,4 @@ CONDITIONS = {
|
||||
"has_tool_output": "expert.tool_used is not empty",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"pa": "anthropic/claude-haiku-4.5",
|
||||
"expert_eras": "google/gemini-2.0-flash-001",
|
||||
"interpreter": "google/gemini-2.0-flash-001",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
AUDIT = {}
|
||||
|
||||
249
agent/mcp_app.py
Normal file
249
agent/mcp_app.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Standalone MCP app — proxies tool calls to assay-runtime. Supports Streamable HTTP + SSE."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path(__file__).parent.parent / ".env")
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request, Depends
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.streamable_http import StreamableHTTPServerTransport
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
|
||||
log = logging.getLogger("mcp-proxy")
|
||||
|
||||
# Config
|
||||
RUNTIME_URL = os.environ.get("RUNTIME_URL", "http://assay-runtime")
|
||||
SERVICE_TOKENS = set(filter(None, os.environ.get("SERVICE_TOKENS", "").split(",")))
|
||||
SERVICE_TOKEN = os.environ.get("SERVICE_TOKENS", "").split(",")[0] if os.environ.get("SERVICE_TOKENS") else ""
|
||||
|
||||
app = FastAPI(title="assay-mcp")
|
||||
_security = HTTPBearer()
|
||||
|
||||
|
||||
async def require_auth(creds: HTTPAuthorizationCredentials = Depends(_security)):
|
||||
if creds.credentials not in SERVICE_TOKENS:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
return {"sub": "service", "source": "service_token"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "service": "mcp-proxy"}
|
||||
|
||||
|
||||
# --- MCP Server ---
|
||||
|
||||
mcp_server = Server("assay")
|
||||
_mcp_transport = SseServerTransport("/mcp/messages/")
|
||||
|
||||
|
||||
async def _proxy_get(path: str, params: dict = None) -> dict:
|
||||
"""GET request to runtime."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.get(
|
||||
f"{RUNTIME_URL}{path}",
|
||||
params=params,
|
||||
headers={"Authorization": f"Bearer {SERVICE_TOKEN}"},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
return resp.json()
|
||||
try:
|
||||
return {"error": resp.json().get("detail", resp.text)}
|
||||
except Exception:
|
||||
return {"error": resp.text}
|
||||
except Exception as e:
|
||||
return {"error": f"Runtime unreachable: {e}"}
|
||||
|
||||
|
||||
async def _proxy_post(path: str, body: dict = None) -> dict:
|
||||
"""POST request to runtime."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
f"{RUNTIME_URL}{path}",
|
||||
json=body or {},
|
||||
headers={"Authorization": f"Bearer {SERVICE_TOKEN}"},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
return resp.json()
|
||||
try:
|
||||
return {"error": resp.json().get("detail", resp.text)}
|
||||
except Exception:
|
||||
return {"error": resp.text}
|
||||
except Exception as e:
|
||||
return {"error": f"Runtime unreachable: {e}"}
|
||||
|
||||
|
||||
@mcp_server.list_tools()
|
||||
async def list_tools():
|
||||
return [
|
||||
Tool(name="assay_send", description="Send a message to the cognitive agent and get a response.",
|
||||
inputSchema={"type": "object", "properties": {
|
||||
"text": {"type": "string", "description": "Message text to send"},
|
||||
"database": {"type": "string", "description": "Optional: database name for query_db context"},
|
||||
}, "required": ["text"]}),
|
||||
Tool(name="assay_trace", description="Get recent trace events from the pipeline (HUD events, tool calls, audit).",
|
||||
inputSchema={"type": "object", "properties": {
|
||||
"last": {"type": "integer", "description": "Number of recent events (default 20)", "default": 20},
|
||||
"filter": {"type": "string", "description": "Comma-separated event types to filter (e.g. 'tool_call,controls')"},
|
||||
}}),
|
||||
Tool(name="assay_history", description="Get recent chat messages from the active session.",
|
||||
inputSchema={"type": "object", "properties": {
|
||||
"last": {"type": "integer", "description": "Number of recent messages (default 20)", "default": 20},
|
||||
}}),
|
||||
Tool(name="assay_state", description="Get the current memorizer state (mood, topic, language, facts).",
|
||||
inputSchema={"type": "object", "properties": {}}),
|
||||
Tool(name="assay_clear", description="Clear the active session (history, state, controls).",
|
||||
inputSchema={"type": "object", "properties": {}}),
|
||||
Tool(name="assay_graph", description="Get the active graph definition (nodes, edges, description).",
|
||||
inputSchema={"type": "object", "properties": {}}),
|
||||
Tool(name="assay_graph_list", description="List all available graph definitions.",
|
||||
inputSchema={"type": "object", "properties": {}}),
|
||||
Tool(name="assay_graph_switch", description="Switch the active graph for new sessions.",
|
||||
inputSchema={"type": "object", "properties": {
|
||||
"name": {"type": "string", "description": "Graph name to switch to"},
|
||||
}, "required": ["name"]}),
|
||||
]
|
||||
|
||||
|
||||
@mcp_server.call_tool()
async def call_tool(name: str, arguments: dict):
    """MCP tool dispatcher — proxies every tool to the runtime's HTTP API.

    Each branch returns a single-element [TextContent] list; errors are
    reported as "ERROR: ..." text rather than raised, per MCP convention.
    """
    if name == "assay_send":
        text = arguments.get("text", "")
        if not text:
            return [TextContent(type="text", text="ERROR: Missing 'text' argument.")]

        # Step 1: check runtime is ready
        check = await _proxy_post("/api/send/check")
        if "error" in check:
            return [TextContent(type="text", text=f"ERROR: {check['error']}")]
        if not check.get("ready"):
            return [TextContent(type="text", text=f"ERROR: {check.get('reason', 'unknown')}: {check.get('detail', '')}")]

        # Step 2: queue message
        send = await _proxy_post("/api/send", {"text": text})
        if "error" in send:
            return [TextContent(type="text", text=f"ERROR: {send['error']}")]
        # NOTE(review): msg_id is currently unused — /api/result appears to be
        # session-scoped; confirm before relying on per-message correlation.
        msg_id = send.get("id", "")

        # Step 3: poll for result (max 30s = 60 polls x 0.5s)
        import asyncio
        for _ in range(60):
            await asyncio.sleep(0.5)
            result = await _proxy_get("/api/result")
            if "error" in result:
                return [TextContent(type="text", text=f"ERROR: {result['error']}")]
            status = result.get("status", "")
            if status == "done":
                return [TextContent(type="text", text=result.get("response", "[no response]"))]
            if status == "error":
                return [TextContent(type="text", text=f"ERROR: {result.get('detail', 'pipeline failed')}")]
        return [TextContent(type="text", text="ERROR: Pipeline timeout (30s)")]

    elif name == "assay_trace":
        last = arguments.get("last", 20)
        event_filter = arguments.get("filter", "")
        params = {"last": last}
        if event_filter:
            params["filter"] = event_filter
        result = await _proxy_get("/api/trace", params)
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        # Format trace events compactly: fixed-width node/event columns + detail
        events = result.get("lines", [])
        lines = []
        for e in events:
            node = e.get("node", "?")
            event = e.get("event", "?")
            detail = e.get("detail", "")
            line = f"{node:12s} {event:20s} {detail}"
            lines.append(line.rstrip())
        return [TextContent(type="text", text="\n".join(lines) if lines else "(no events)")]

    elif name == "assay_history":
        last = arguments.get("last", 20)
        result = await _proxy_get("/api/history", {"last": last})
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text=json.dumps(result.get("messages", []), indent=2))]

    elif name == "assay_state":
        result = await _proxy_get("/api/state")
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text=json.dumps(result, indent=2))]

    elif name == "assay_clear":
        result = await _proxy_post("/api/clear")
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text="Session cleared.")]

    elif name == "assay_graph":
        result = await _proxy_get("/api/graph/active")
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text=json.dumps(result, indent=2))]

    elif name == "assay_graph_list":
        result = await _proxy_get("/api/graph/list")
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text=json.dumps(result.get("graphs", []), indent=2))]

    elif name == "assay_graph_switch":
        gname = arguments.get("name", "")
        if not gname:
            return [TextContent(type="text", text="ERROR: Missing 'name' argument.")]
        result = await _proxy_post("/api/graph/switch", {"name": gname})
        if "error" in result:
            return [TextContent(type="text", text=f"ERROR: {result['error']}")]
        return [TextContent(type="text", text=f"Switched to graph '{result.get('name', gname)}'. New sessions will use this graph.")]

    else:
        return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
|
||||
# Mount MCP Streamable HTTP endpoint (primary — stateless, survives pod restarts)
# One transport per MCP session id; the pumping task is retained so it is not
# garbage-collected mid-session.
_http_transports: dict[str, StreamableHTTPServerTransport] = {}
_http_tasks: dict[str, asyncio.Task] = {}  # was dict[str, any] — `any` is the builtin, not a type


@app.api_route("/mcp", methods=["GET", "POST", "DELETE"])
async def mcp_http(request: Request, user=Depends(require_auth)):
    """Streamable-HTTP MCP endpoint.

    Lazily creates one transport + MCP server task per ``mcp-session-id``
    header (falling back to a shared "default" session) and forwards the raw
    ASGI request to that transport.
    """
    session_id = request.headers.get("mcp-session-id", "default")
    if session_id not in _http_transports:
        transport = StreamableHTTPServerTransport(mcp_session_id=session_id)
        _http_transports[session_id] = transport

        async def _run():
            # Pump this transport's streams through the MCP server until the
            # session ends.
            async with transport.connect() as streams:
                await mcp_server.run(streams[0], streams[1], mcp_server.create_initialization_options())
        _http_tasks[session_id] = asyncio.create_task(_run())

    transport = _http_transports[session_id]
    # NOTE(review): request._send is a private Starlette attribute, required
    # because the transport speaks raw ASGI — confirm on Starlette upgrades.
    await transport.handle_request(request.scope, request.receive, request._send)
|
||||
|
||||
|
||||
# Mount MCP SSE endpoints (legacy fallback)
@app.get("/mcp/sse")
async def mcp_sse(request: Request, user=Depends(require_auth)):
    """Legacy SSE transport: bridge this HTTP request into the MCP server loop."""
    # request._send is private Starlette API — the SSE transport speaks raw ASGI.
    async with _mcp_transport.connect_sse(request.scope, request.receive, request._send) as streams:
        await mcp_server.run(streams[0], streams[1], mcp_server.create_initialization_options())
|
||||
|
||||
|
||||
@app.post("/mcp/messages/")
|
||||
async def mcp_messages(request: Request, user=Depends(require_auth)):
|
||||
await _mcp_transport.handle_post_message(request.scope, request.receive, request._send)
|
||||
@ -1,50 +0,0 @@
|
||||
"""NodePool: shared stateless node instances across all sessions.
|
||||
|
||||
Stateless nodes (InputNode, PANode, ExpertNode, etc.) hold no per-session
|
||||
state — only config (model, system prompt). They can safely serve multiple
|
||||
concurrent sessions. Session-specific HUD routing uses contextvars.
|
||||
|
||||
Stateful nodes (SensorNode, MemorizerNode, UINode) hold conversational
|
||||
state and must be created per-session.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .engine import load_graph, instantiate_nodes
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
|
||||
# Roles that hold per-session state — always created fresh per Runtime
|
||||
STATEFUL_ROLES = frozenset({"sensor", "memorizer", "ui"})
|
||||
|
||||
|
||||
async def _noop_hud(data: dict):
|
||||
"""Placeholder HUD — shared nodes use contextvars for session routing."""
|
||||
pass
|
||||
|
||||
|
||||
class NodePool:
    """Shared node instances for stateless LLM nodes.

    Stateless nodes hold only config (model, system prompt), so one instance
    can serve every concurrent session; per-session HUD routing is handled via
    contextvars. Stateful roles (STATEFUL_ROLES) are excluded here and must be
    created fresh per Runtime.

    Usage:
        pool = NodePool("v4-eras")
        # Shared nodes (one instance, all sessions):
        input_node = pool.shared["input"]
        # Stateful nodes must be created per-session (not in pool)
    """

    def __init__(self, graph_name: str = "v4-eras"):
        self.graph = load_graph(graph_name)
        self.graph_name = graph_name

        # Instantiate all nodes with noop HUD (shared nodes use contextvars)
        all_nodes = instantiate_nodes(self.graph, send_hud=_noop_hud)

        # Split: shared (stateless) vs excluded (stateful)
        self.shared = {
            role: node for role, node in all_nodes.items()
            if role not in STATEFUL_ROLES
        }

        # Lazy %-args: no string formatting unless INFO is enabled.
        log.info("[pool] created for graph '%s': %d shared, %d stateful",
                 graph_name, len(self.shared), len(STATEFUL_ROLES))
|
||||
@ -1,16 +1,11 @@
|
||||
"""Base Node class with context management."""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
|
||||
from ..llm import estimate_tokens, fit_context
|
||||
|
||||
log = logging.getLogger("runtime")
|
||||
|
||||
# Per-task HUD callback — set by FrameEngine/Runtime before calling shared nodes.
|
||||
# Isolates HUD events between concurrent sessions (asyncio.Task-scoped).
|
||||
_current_hud = contextvars.ContextVar('send_hud', default=None)
|
||||
|
||||
|
||||
class Node:
|
||||
name: str = "node"
|
||||
@ -23,11 +18,10 @@ class Node:
|
||||
self.context_fill_pct = 0
|
||||
|
||||
async def hud(self, event: str, **data):
|
||||
# Use task-scoped HUD if set (shared node pool), else instance callback
|
||||
hud_fn = _current_hud.get() or self.send_hud
|
||||
# Always include model on context events so frontend knows what model each node uses
|
||||
if event == "context" and self.model:
|
||||
data["model"] = self.model
|
||||
await hud_fn({"node": self.name, "event": event, **data})
|
||||
await self.send_hud({"node": self.name, "event": event, **data})
|
||||
|
||||
def trim_context(self, messages: list[dict]) -> list[dict]:
|
||||
"""Fit messages within this node's token budget."""
|
||||
|
||||
113
agent/runtime.py
113
agent/runtime.py
@ -22,11 +22,10 @@ _active_graph_name = "v4-eras"
|
||||
|
||||
|
||||
class OutputSink:
|
||||
"""Collects output. Streams to attached WebSocket or SSE queue."""
|
||||
"""Collects output. Optionally streams to attached WebSocket."""
|
||||
|
||||
def __init__(self):
|
||||
self.ws = None
|
||||
self.queue: asyncio.Queue | None = None # SSE streaming queue
|
||||
self.response: str = ""
|
||||
self.controls: list = []
|
||||
self.done: bool = False
|
||||
@ -37,104 +36,76 @@ class OutputSink:
|
||||
def detach(self):
|
||||
self.ws = None
|
||||
|
||||
def attach_queue(self, queue: asyncio.Queue):
|
||||
"""Attach an asyncio.Queue for SSE streaming (HTTP mode)."""
|
||||
self.queue = queue
|
||||
|
||||
def detach_queue(self):
|
||||
self.queue = None
|
||||
|
||||
def reset(self):
|
||||
self.response = ""
|
||||
self.controls = []
|
||||
self.done = False
|
||||
|
||||
async def _emit(self, event: dict):
|
||||
"""Send event to WS or SSE queue."""
|
||||
msg = json.dumps(event)
|
||||
if self.queue:
|
||||
try:
|
||||
self.queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
async def send_delta(self, text: str):
|
||||
self.response += text
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(msg)
|
||||
await self.ws.send_text(json.dumps({"type": "delta", "content": text}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_delta(self, text: str):
|
||||
self.response += text
|
||||
await self._emit({"type": "delta", "content": text})
|
||||
|
||||
async def send_controls(self, controls: list):
|
||||
self.controls = controls
|
||||
await self._emit({"type": "controls", "controls": controls})
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "controls", "controls": controls}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_artifacts(self, artifacts: list):
|
||||
await self._emit({"type": "artifacts", "artifacts": artifacts})
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "artifacts", "artifacts": artifacts}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_hud(self, data: dict):
|
||||
await self._emit({"type": "hud", **data})
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "hud", **data}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_done(self):
|
||||
self.done = True
|
||||
await self._emit({"type": "done"})
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "done"}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, user_claims: dict = None, origin: str = "",
|
||||
broadcast: Callable = None, graph_name: str = None,
|
||||
session_id: str = None, pool=None):
|
||||
session_id: str = None):
|
||||
self.session_id = session_id or str(uuid4())
|
||||
self.sink = OutputSink()
|
||||
self.history: list[dict] = []
|
||||
self.MAX_HISTORY = 40
|
||||
self._broadcast = broadcast or (lambda e: None)
|
||||
|
||||
# Load graph and instantiate nodes
|
||||
gname = graph_name or _active_graph_name
|
||||
self.graph = load_graph(gname)
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
||||
process_manager=self.process_manager)
|
||||
|
||||
if pool:
|
||||
# Phase 2: use shared node pool for stateless nodes
|
||||
self.graph = pool.graph
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
|
||||
# Shared nodes from pool (stateless, serve all sessions)
|
||||
self.input_node = pool.shared.get("input")
|
||||
self.thinker = pool.shared.get("thinker")
|
||||
self.output_node = pool.shared.get("output")
|
||||
self.director = pool.shared.get("director")
|
||||
self.interpreter = pool.shared.get("interpreter")
|
||||
|
||||
# Per-session stateful nodes (fresh each session)
|
||||
from .nodes import UINode, MemorizerNodeV1 as MemorizerNode, SensorNode
|
||||
self.ui_node = UINode(send_hud=self._send_hud)
|
||||
self.memorizer = MemorizerNode(send_hud=self._send_hud)
|
||||
self.sensor = SensorNode(send_hud=self._send_hud)
|
||||
|
||||
# Build combined nodes dict for FrameEngine
|
||||
nodes = dict(pool.shared)
|
||||
nodes["ui"] = self.ui_node
|
||||
nodes["memorizer"] = self.memorizer
|
||||
nodes["sensor"] = self.sensor
|
||||
|
||||
log.info(f"[runtime] using shared pool for graph '{gname}' "
|
||||
f"({len(pool.shared)} shared, 3 per-session)")
|
||||
else:
|
||||
# Legacy: create all nodes per-session
|
||||
self.graph = load_graph(gname)
|
||||
self.process_manager = ProcessManager(send_hud=self._send_hud)
|
||||
nodes = instantiate_nodes(self.graph, send_hud=self._send_hud,
|
||||
process_manager=self.process_manager)
|
||||
|
||||
self.input_node = nodes["input"]
|
||||
self.thinker = nodes.get("thinker")
|
||||
self.output_node = nodes["output"]
|
||||
self.ui_node = nodes["ui"]
|
||||
self.memorizer = nodes["memorizer"]
|
||||
self.director = nodes.get("director")
|
||||
self.sensor = nodes["sensor"]
|
||||
self.interpreter = nodes.get("interpreter")
|
||||
# Bind nodes by role (pipeline code references these)
|
||||
self.input_node = nodes["input"]
|
||||
self.thinker = nodes.get("thinker") # v1/v2/v3
|
||||
self.output_node = nodes["output"]
|
||||
self.ui_node = nodes["ui"]
|
||||
self.memorizer = nodes["memorizer"]
|
||||
self.director = nodes.get("director") # v1/v2/v3, None in v4
|
||||
self.sensor = nodes["sensor"]
|
||||
self.interpreter = nodes.get("interpreter") # v2+ only
|
||||
|
||||
# Detect graph type
|
||||
self.is_v2 = self.director is not None and hasattr(self.director, "decide")
|
||||
@ -326,12 +297,10 @@ class Runtime:
|
||||
lines.append(f" - {ctype}: {ctrl.get('label', ctrl.get('text', '?'))}")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def handle_message(self, text: str, dashboard: list = None,
|
||||
model_overrides: dict = None):
|
||||
async def handle_message(self, text: str, dashboard: list = None):
|
||||
# Frame engine: delegate entirely
|
||||
if self.use_frames:
|
||||
result = await self.frame_engine.process_message(
|
||||
text, dashboard, model_overrides=model_overrides)
|
||||
result = await self.frame_engine.process_message(text, dashboard)
|
||||
return result
|
||||
|
||||
# Detect ACTION: prefix from API/test runner
|
||||
|
||||
@ -1,367 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test orchestrator — runs test suites and posts results to dev assay.
|
||||
|
||||
Usage:
|
||||
python tests/run_tests.py # all suites
|
||||
python tests/run_tests.py api # one suite
|
||||
python tests/run_tests.py matrix/eras_query[haiku] # single test
|
||||
python tests/run_tests.py matrix --repeat=3 # each test 3x, report avg/p50/p95
|
||||
python tests/run_tests.py testcases --parallel=3 # 3 testcases concurrently
|
||||
python tests/run_tests.py api/health roundtrip/full_chat # multiple tests
|
||||
|
||||
Test names: suite/name (without the suite prefix in the test registry).
|
||||
engine tests: graph_load, node_instantiation, edge_types_complete,
|
||||
condition_reflex, condition_tool_output,
|
||||
frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter
|
||||
api tests: health, eras_umsatz_api, eras_umsatz_artifact
|
||||
matrix tests: eras_query[variant], eras_artifact[variant], social_reflex[variant]
|
||||
variants: gemini-flash, haiku, gpt-4o-mini
|
||||
testcases: fast, reflex_path, expert_eras, domain_context, ... (from testcases/*.md)
|
||||
roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
RESULTS_ENDPOINT = os.environ.get('RESULTS_ENDPOINT', '')
|
||||
RUN_ID = os.environ.get('RUN_ID', str(uuid.uuid4())[:8])
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
@dataclass
class TestResult:
    """One test outcome, printed as JSON and posted to the dev assay endpoint."""
    run_id: str      # shared id for all results of one orchestrator run
    test: str        # test name within the suite
    suite: str       # suite name ('engine', 'api', 'matrix', ...)
    status: str  # 'pass', 'fail', 'running', 'error'
    duration_ms: float = 0
    error: str = ''  # failure message, or stats summary in repeat mode
    ts: str = ''     # ISO-8601 UTC timestamp of the last status change
    stats: dict = field(default_factory=dict)  # {runs, min_ms, avg_ms, p50_ms, p95_ms, max_ms, pass_rate}
|
||||
|
||||
|
||||
def post_result(result: TestResult):
    """Print a result as JSON and POST it to the dev assay endpoint (best-effort).

    Always prints to stdout; the POST is skipped when RESULTS_ENDPOINT is
    unset, and any failure is reported to stderr without failing the run.
    """
    print(json.dumps(asdict(result)), flush=True)
    if not RESULTS_ENDPOINT:
        return
    try:
        payload = json.dumps(asdict(result)).encode()
        req = urllib.request.Request(
            RESULTS_ENDPOINT,
            data=payload,
            headers={'Content-Type': 'application/json'},
        )
        # Close the response explicitly — a bare urlopen() leaks the socket.
        with urllib.request.urlopen(req, timeout=5):
            pass
    except Exception as e:
        # Deliberately broad: posting is best-effort telemetry.
        print(f' [warn] failed to post result: {e}', file=sys.stderr)
|
||||
|
||||
|
||||
def run_test(name: str, suite: str, fn) -> TestResult:
    """Execute one test callable, posting 'running' first and the outcome after."""
    outcome = TestResult(run_id=RUN_ID, test=name, suite=suite, status='running', ts=_now_iso())
    post_result(outcome)

    started = time.time()
    try:
        fn()
    except AssertionError as exc:
        # Assertion failures are test failures, not infrastructure errors.
        outcome.status = 'fail'
        outcome.error = str(exc)
    except Exception as exc:
        outcome.status = 'error'
        outcome.error = f'{type(exc).__name__}: {exc}'
    else:
        outcome.status = 'pass'
    outcome.duration_ms = round((time.time() - started) * 1000)
    outcome.ts = _now_iso()

    post_result(outcome)
    return outcome
||||
|
||||
|
||||
def get_api_tests() -> dict:
    """Load API tests from e2e_harness.py (browser-dependent tests excluded)."""
    sys.path.insert(0, os.path.dirname(__file__))
    import e2e_harness
    # removesuffix, not rstrip: rstrip('/api') strips any run of trailing
    # '/', 'a', 'p', 'i' characters, mangling URLs that merely end in those
    # letters — removesuffix drops exactly one trailing '/api'.
    e2e_harness.ASSAY_BASE = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000').removesuffix('/api')
    # Skip browser-dependent tests
    return {k: v for k, v in e2e_harness.TESTS.items() if 'takeover' not in k and 'panes' not in k}
|
||||
|
||||
|
||||
def get_roundtrip_tests() -> dict:
    """Load Playwright roundtrip tests (browser + backend, end-to-end)."""
    # Import lazily so Playwright is only required when this suite runs.
    sys.path.insert(0, os.path.dirname(__file__))
    from test_roundtrip import TESTS
    return TESTS
|
||||
|
||||
|
||||
def get_engine_tests() -> dict:
    """Load engine-level tests (no LLM, no network)."""
    sys.path.insert(0, os.path.dirname(__file__))
    from test_engine import TESTS
    return TESTS
|
||||
|
||||
|
||||
def get_matrix_tests() -> dict:
    """Load model matrix tests (real LLM calls, test×variant combos)."""
    sys.path.insert(0, os.path.dirname(__file__))
    # The local import deliberately shadows this wrapper's own name inside the body.
    from test_matrix import get_matrix_tests
    return get_matrix_tests()
|
||||
|
||||
|
||||
def get_testcase_tests() -> dict:
    """Load markdown testcases from testcases/ (integration tests, real LLM)."""
    sys.path.insert(0, os.path.dirname(__file__))
    # Shadows this wrapper's own name inside the body, like get_matrix_tests.
    from test_testcases import get_testcase_tests
    return get_testcase_tests()
|
||||
|
||||
|
||||
def get_node_tests() -> dict:
    """Load node-level tests (direct node instantiation, real LLM + DB, no HTTP)."""
    sys.path.insert(0, os.path.dirname(__file__))
    from test_node_eras import TESTS
    return TESTS
|
||||
|
||||
|
||||
def get_ui_tests() -> dict:
    """Load UI tests — toolbar, navigation, scroll (Playwright, no backend needed)."""
    sys.path.insert(0, os.path.dirname(__file__))
    from test_ui import TESTS
    return TESTS
|
||||
|
||||
|
||||
# Suite name → loader returning {test_name: callable}. Loaders import lazily so
# heavyweight dependencies (Playwright, LLM clients) only load when a suite runs.
SUITES = {
    'engine': get_engine_tests,
    'api': get_api_tests,
    'node': get_node_tests,
    'matrix': get_matrix_tests,
    'testcases': get_testcase_tests,
    'roundtrip': get_roundtrip_tests,
    'ui': get_ui_tests,
}
|
||||
|
||||
|
||||
def _compute_stats(durations: list[float], passed: int, total: int) -> dict:
|
||||
"""Compute timing stats from a list of durations."""
|
||||
if not durations:
|
||||
return {}
|
||||
durations.sort()
|
||||
n = len(durations)
|
||||
return {
|
||||
'runs': total,
|
||||
'passed': passed,
|
||||
'pass_rate': round(100 * passed / total) if total else 0,
|
||||
'min_ms': round(durations[0]),
|
||||
'avg_ms': round(sum(durations) / n),
|
||||
'p50_ms': round(durations[n // 2]),
|
||||
'p95_ms': round(durations[min(int(n * 0.95), n - 1)]),
|
||||
'max_ms': round(durations[-1]),
|
||||
}
|
||||
|
||||
|
||||
def run_test_repeated(name: str, suite: str, fn, repeat: int) -> TestResult:
    """Run a test *repeat* times, aggregating timing stats into one result.

    Status is 'pass' when every run passed, 'fail' when some did, 'error'
    when none did. The error field carries the stats summary, plus the last
    failure message when any run failed. Assumes repeat >= 1 (the stats
    summary indexes into non-empty stats).
    """
    # Post running status
    result = TestResult(run_id=RUN_ID, test=name, suite=suite, status='running', ts=_now_iso())
    post_result(result)

    durations = []
    passed_count = 0
    last_error = ''

    for _ in range(repeat):
        start = time.time()
        try:
            fn()
            durations.append(round((time.time() - start) * 1000))
            passed_count += 1
        except Exception as e:
            # AssertionError is already an Exception — the original
            # `except (AssertionError, Exception)` tuple was redundant.
            durations.append(round((time.time() - start) * 1000))
            last_error = str(e)[:200]

    stats = _compute_stats(durations, passed_count, repeat)
    result.stats = stats
    result.duration_ms = stats.get('avg_ms', 0)
    result.status = 'pass' if passed_count == repeat else ('fail' if passed_count > 0 else 'error')
    result.error = f'{stats["pass_rate"]}% pass, avg={stats["avg_ms"]}ms p50={stats["p50_ms"]}ms p95={stats["p95_ms"]}ms'
    if last_error and passed_count < repeat:
        result.error += f' | last err: {last_error}'
    result.ts = _now_iso()
    post_result(result)
    return result
||||
|
||||
|
||||
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int, int]:
|
||||
"""Parse CLI args into (suite_filter, test_filter, repeat, parallel).
|
||||
|
||||
Supports: --repeat=N, --parallel=N
|
||||
|
||||
Returns:
|
||||
suite_filter: set of suite names, or None for all suites
|
||||
test_filter: set of 'suite/test' names (empty = run all in suite)
|
||||
repeat: number of times to run each test (default 1)
|
||||
parallel: max concurrent tests (default 1 = sequential)
|
||||
"""
|
||||
repeat = 1
|
||||
parallel = 1
|
||||
filtered_args = []
|
||||
skip_next = False
|
||||
for i, arg in enumerate(args):
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if arg.startswith('--repeat='):
|
||||
repeat = int(arg.split('=', 1)[1])
|
||||
elif arg == '--repeat' and i + 1 < len(args):
|
||||
repeat = int(args[i + 1])
|
||||
skip_next = True
|
||||
elif arg.startswith('--parallel='):
|
||||
parallel = int(arg.split('=', 1)[1])
|
||||
elif arg == '--parallel' and i + 1 < len(args):
|
||||
parallel = int(args[i + 1])
|
||||
skip_next = True
|
||||
else:
|
||||
filtered_args.append(arg)
|
||||
|
||||
if not filtered_args:
|
||||
return None, set(), repeat, parallel
|
||||
|
||||
suites = set()
|
||||
tests = set()
|
||||
for arg in filtered_args:
|
||||
if '/' in arg:
|
||||
tests.add(arg)
|
||||
suites.add(arg.split('/')[0])
|
||||
else:
|
||||
suites.add(arg)
|
||||
return suites, tests, repeat, parallel
|
||||
|
||||
|
||||
def _run_one(name: str, suite_name: str, fn, repeat: int) -> TestResult:
    """Run a single test (with optional repeat). Thread-safe."""
    # Repeat mode aggregates timing stats into one result; single mode posts
    # a plain pass/fail.
    if repeat > 1:
        return run_test_repeated(name, suite_name, fn, repeat)
    return run_test(name, suite_name, fn)
|
||||
|
||||
|
||||
def _print_result(suite_name: str, name: str, r: TestResult, repeat: int):
    """Print one result line: stats summary in repeat mode, plain timing otherwise."""
    label = 'PASS' if r.status == 'pass' else 'FAIL'
    if repeat > 1:
        s = r.stats
        line = (f'  [{label}] {suite_name}/{name} ×{repeat} '
                f'(avg={s.get("avg_ms", 0)}ms p50={s.get("p50_ms", 0)}ms '
                f'p95={s.get("p95_ms", 0)}ms pass={s.get("pass_rate", 0)}%)')
        print(line, flush=True)
    else:
        print(f'  [{label}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
    # Error detail is only shown for single runs; repeat mode folds it into stats.
    if r.error and repeat == 1:
        print(f'    {r.error[:200]}', flush=True)
||||
|
||||
|
||||
def run_suite(suite_name: str, tests: dict, test_filter: set[str],
              repeat: int = 1, parallel: int = 1) -> list[TestResult]:
    """Run tests from a suite, optionally filtered, repeated, and parallelized.

    A test is selected when test_filter is empty, or when it contains either
    'suite/name' or 'suite/name-without-suite-prefix'. In parallel mode the
    returned list is in completion order, not submission order.
    """
    # Build filtered test list
    filtered = []
    for name, fn in tests.items():
        full_name = f'{suite_name}/{name}'
        # Also match names with the 'suite_' prefix stripped, so both
        # 'matrix/eras_query' and 'matrix/matrix_eras_query' styles work.
        short_name = name.replace(f'{suite_name}_', '')
        if test_filter and full_name not in test_filter and f'{suite_name}/{short_name}' not in test_filter:
            continue
        filtered.append((name, fn))

    if not filtered:
        return []

    # Sequential execution
    if parallel <= 1 or len(filtered) <= 1:
        results = []
        for name, fn in filtered:
            r = _run_one(name, suite_name, fn, repeat)
            _print_result(suite_name, name, r, repeat)
            results.append(r)
        return results

    # Parallel execution
    results = []
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        futures = {}
        for name, fn in filtered:
            f = pool.submit(_run_one, name, suite_name, fn, repeat)
            futures[f] = name

        for future in as_completed(futures):
            name = futures[future]
            try:
                r = future.result()
            except Exception as e:
                # A crash in the worker thread itself (not the test body).
                r = TestResult(run_id=RUN_ID, test=name, suite=suite_name,
                               status='error', error=f'ThreadError: {e}', ts=_now_iso())
            _print_result(suite_name, name, r, repeat)
            results.append(r)

    return results
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run selected suites/tests, print a summary, exit 1 on failure."""
    suite_filter, test_filter, repeat, parallel = parse_args(sys.argv[1:])

    print(f'=== Test Run {RUN_ID} ===', flush=True)
    if suite_filter:
        print(f'Filter: suites={suite_filter}, tests={test_filter or "all"}', flush=True)
    if repeat > 1:
        print(f'Repeat: {repeat}x per test', flush=True)
    if parallel > 1:
        print(f'Parallel: {parallel} concurrent tests', flush=True)
    print(f'ASSAY_API: {os.environ.get("ASSAY_API", "not set")}', flush=True)
    print(f'NYX_URL: {os.environ.get("NYX_URL", "not set")}', flush=True)
    print(flush=True)

    all_results = []

    for suite_name, loader in SUITES.items():
        if suite_filter and suite_name not in suite_filter:
            continue
        label = suite_name
        if repeat > 1:
            label += f' ×{repeat}'
        if parallel > 1:
            label += f' ∥{parallel}'
        print(f'--- {label} ---', flush=True)
        # Loader imports the suite's dependencies lazily at this point.
        tests = loader()
        all_results.extend(run_suite(suite_name, tests, test_filter, repeat, parallel))
        print(flush=True)

    # Summary
    passed = sum(1 for r in all_results if r.status == 'pass')
    failed = sum(1 for r in all_results if r.status in ('fail', 'error'))
    total_ms = sum(r.duration_ms for r in all_results)
    print(f'=== {passed} passed, {failed} failed, {len(all_results)} total ({total_ms:.0f}ms) ===', flush=True)

    # Post a final synthetic "__summary__" result so the dashboard can close the run.
    if RESULTS_ENDPOINT:
        summary = TestResult(
            run_id=RUN_ID, test='__summary__', suite='summary',
            status='pass' if failed == 0 else 'fail',
            duration_ms=total_ms,
            error=f'{passed} passed, {failed} failed',
        )
        post_result(summary)

    sys.exit(1 if failed else 0)
|
||||
@ -1,620 +0,0 @@
|
||||
"""Engine test suite — tests graph loading, node instantiation, frame engine
|
||||
routing, conditions, and trace structure. No LLM calls — all nodes mocked.
|
||||
|
||||
Tests:
|
||||
graph_load — load_graph returns correct structure for all graphs
|
||||
node_instantiation — instantiate_nodes creates all roles from registry
|
||||
edge_types_complete — all 3 edge types present, no orphan nodes
|
||||
condition_reflex — reflex condition fires on social+trivial only
|
||||
condition_tool_output — has_tool_output condition fires when tool data present
|
||||
frame_trace_reflex — reflex path produces 2-frame trace
|
||||
frame_trace_expert — expert path produces correct frame sequence
|
||||
frame_trace_director — director path produces correct frame sequence
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from agent.engine import load_graph, instantiate_nodes, _graph_from_module
|
||||
from agent.frame_engine import FrameEngine, FrameTrace, FrameRecord
|
||||
from agent.types import (
|
||||
Envelope, Command, InputAnalysis, ThoughtResult,
|
||||
DirectorPlan, PARouting, InterpretedResult, Artifact,
|
||||
)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
class MockSink:
    """Test double that records everything streamed through the sink API."""

    def __init__(self):
        self.deltas = list()
        self.controls = list()
        self.artifacts = list()
        self.done_count = 0

    async def send_delta(self, text):
        # Accumulate chunks in arrival order.
        self.deltas.append(text)

    async def send_controls(self, controls):
        # Last write wins, matching the real sink.
        self.controls = controls

    async def send_artifacts(self, artifacts):
        self.artifacts = artifacts

    async def send_done(self):
        self.done_count = self.done_count + 1

    def reset(self):
        # Only deltas are cleared; controls/artifacts persist across turns.
        del self.deltas[:]
|
||||
|
||||
class MockHud:
    """Callable test double that records HUD events for later inspection."""

    def __init__(self):
        self.events = []

    async def __call__(self, data):
        self.events.append(data)

    def find(self, event):
        # All recorded events whose "event" field matches.
        matches = []
        for item in self.events:
            if item.get("event") == event:
                matches.append(item)
        return matches
|
||||
|
||||
|
||||
class MockMemorizer:
    """Stand-in memorizer: fixed state, constant context block, no-op update."""

    def __init__(self):
        self.state = dict(
            user_name="test",
            user_mood="neutral",
            topic="testing",
            topic_history=[],
            language="en",
            style_hint="casual",
            facts=[],
            user_expectation="conversational",
        )

    def get_context_block(self, sensor_lines=None, ui_state=None):
        # Constant block — inputs are accepted but ignored.
        return "Memory: test context"

    async def update(self, history):
        return None
|
||||
|
||||
|
||||
class MockSensor:
    """Stand-in sensor: fixed context line, flags drained on read."""

    def __init__(self):
        self._flags = []

    def note_user_activity(self):
        return None

    def update_browser_dashboard(self, dashboard):
        return None

    def get_context_lines(self):
        return ["Sensors: test"]

    def consume_flags(self):
        # Drain semantics: hand back a copy, then empty the backing list.
        drained = list(self._flags)
        self._flags.clear()
        return drained
|
||||
|
||||
|
||||
class MockUINode:
    """Stand-in UI node: controls pass straight through, state machine is inert."""

    def __init__(self):
        self.thinker_controls = []
        self.state = {}
        self._artifacts = []

    @property
    def current_controls(self):
        # Alias for thinker_controls so both access paths stay in sync.
        return self.thinker_controls

    @current_controls.setter
    def current_controls(self, value):
        self.thinker_controls = value

    async def process(self, thought, history, memory_context=""):
        # Echo whatever controls are currently set.
        return self.current_controls

    def get_machine_summary(self):
        return ""

    def get_machine_controls(self):
        return []

    def get_artifacts(self):
        return self._artifacts

    def try_machine_transition(self, action):
        # No state machine: every transition is rejected.
        return (False, "")

    async def process_local_action(self, action, data):
        # No local actions supported.
        return (None, [])
|
||||
|
||||
class MockInputNode:
    """Stub input node: always yields a Command with preset analysis fields."""

    def __init__(self, intent="request", complexity="simple", topic="test", language="en"):
        self._intent = intent
        self._complexity = complexity
        self._topic = topic
        self._language = language

    async def process(self, envelope, history, memory_context="", identity="", channel=""):
        """Build the canned Command; only source_text comes from the envelope."""
        analysis = InputAnalysis(
            intent=self._intent,
            topic=self._topic,
            complexity=self._complexity,
            language=self._language,
            tone="casual",
        )
        return Command(analysis=analysis, source_text=envelope.text)
|
||||
|
||||
|
||||
class MockOutputNode:
    """Stub output node: streams thought.response through the sink in 12-char chunks."""

    async def process(self, thought, history, sink, memory_context=""):
        # Falsy response (empty string / None) falls back to "ok".
        text = thought.response or "ok"
        chunk = 12
        for start in range(0, len(text), chunk):
            await sink.send_delta(text[start:start + chunk])
        await sink.send_done()
        return text
|
||||
|
||||
|
||||
class MockPANode:
    """Stub PA router: always routes to the preconfigured expert/job."""

    def __init__(self, expert="eras", job="test query", thinking_msg="Working..."):
        self._expert = expert
        self._job = job
        self._thinking_msg = thinking_msg

    def set_available_experts(self, experts):
        """No-op: routing is fixed in tests."""

    async def route(self, command, history, memory_context="", identity="", channel=""):
        """Ignore all context; return the canned routing."""
        routing = PARouting(
            expert=self._expert,
            job=self._job,
            thinking_message=self._thinking_msg,
            language="en",
        )
        return routing

    async def route_retry(self, command, history, memory_context="", identity="",
                          channel="", original_job="", errors=None):
        """Retry routing: same expert, job prefixed with 'retry: '."""
        return PARouting(expert=self._expert, job=f"retry: {self._job}", language="en")
|
||||
|
||||
|
||||
class MockExpertNode:
    """Stub expert: execute() returns a fixed ThoughtResult, HUD is captured."""

    def __init__(self, response="expert result", tool_used="", tool_output="", errors=None):
        self._response = response
        self._tool_used = tool_used
        self._tool_output = tool_output
        # Avoid sharing a mutable default across instances.
        self._errors = errors or []
        self.send_hud = MockHud()  # captures HUD events for assertions

    async def execute(self, job, language):
        """Ignore job/language; return the canned result."""
        result = ThoughtResult(
            response=self._response,
            tool_used=self._tool_used,
            tool_output=self._tool_output,
            errors=self._errors,
        )
        return result
|
||||
|
||||
|
||||
class MockDirectorNode:
    """Stub director: decide() returns a fixed DirectorPlan."""

    def __init__(self, goal="test", tools=None, hint=""):
        self._goal = goal
        self._tools = tools or []  # avoid a shared mutable default
        self._hint = hint

    async def decide(self, command, history, memory_context=""):
        """Ignore inputs; emit the preconfigured plan."""
        plan = DirectorPlan(
            goal=self._goal,
            tool_sequence=self._tools,
            response_hint=self._hint,
        )
        return plan

    def get_context_line(self):
        """No director context in tests."""
        return ""
|
||||
|
||||
|
||||
class MockThinkerNode:
    """Stub thinker: process() returns a fixed ThoughtResult."""

    def __init__(self, response="thought result", tool_used="", tool_output=""):
        self._response = response
        self._tool_used = tool_used
        self._tool_output = tool_output

    async def process(self, command, plan=None, history=None, memory_context=""):
        """Ignore inputs; emit the preconfigured result."""
        result = ThoughtResult(
            response=self._response,
            tool_used=self._tool_used,
            tool_output=self._tool_output,
        )
        return result
|
||||
|
||||
|
||||
class MockInterpreterNode:
    """Stub interpreter: summarizes any tool output the same canned way."""

    async def interpret(self, tool_used, tool_output, job):
        """Return a fixed InterpretedResult naming the tool that ran."""
        summary = f"Interpreted: {tool_used} returned data"
        return InterpretedResult(summary=summary, row_count=5, key_facts=["5 rows"])
|
||||
|
||||
|
||||
def make_frame_engine(nodes, graph_name="v4-eras"):
    """Build a FrameEngine wired to mock sink/hud/memorizer/sensor/ui.

    Returns (engine, sink, hud) so callers can inspect streamed output and
    captured HUD events after processing a message.
    """
    sink = MockSink()
    hud = MockHud()
    engine = FrameEngine(
        graph=load_graph(graph_name),
        nodes=nodes,
        sink=sink,
        history=[],
        send_hud=hud,
        sensor=MockSensor(),
        memorizer=MockMemorizer(),
        ui_node=MockUINode(),
        identity="test_user",
        channel="test",
    )
    return engine, sink, hud
|
||||
|
||||
|
||||
# --- Tests ---
|
||||
|
||||
def test_graph_load():
    """load_graph returns correct structure for all frame-based graphs."""
    # Frame-engine graphs must declare nodes, edges and conditions.
    for name in ["v3-framed", "v4-eras"]:
        g = load_graph(name)
        assert g["name"] == name, f"graph name mismatch: {g['name']} != {name}"
        assert g["engine"] == "frames", f"{name} should use frames engine"
        assert "nodes" in g and len(g["nodes"]) > 0, f"{name} has no nodes"
        assert "edges" in g and len(g["edges"]) > 0, f"{name} has no edges"
        assert "conditions" in g, f"{name} has no conditions"
    # v1 should be imperative
    g1 = load_graph("v1-current")
    assert g1["engine"] == "imperative", "v1 should be imperative"
|
||||
|
||||
|
||||
def test_node_instantiation():
    """instantiate_nodes creates all roles from registry."""
    hud = MockHud()
    for name in ["v3-framed", "v4-eras"]:
        g = load_graph(name)
        nodes = instantiate_nodes(g, hud)
        # Every role declared in the graph must get a concrete instance.
        for role in g["nodes"]:
            assert role in nodes, f"missing node role '{role}' in {name}"
        # Check specific node types exist
        assert "input" in nodes
        assert "output" in nodes
        assert "memorizer" in nodes
        assert "sensor" in nodes
|
||||
|
||||
|
||||
def test_edge_types_complete():
    """All 3 edge types present in graph definitions, no orphan nodes."""
    for name in ["v3-framed", "v4-eras"]:
        g = load_graph(name)
        edges = g["edges"]
        edge_types = {e.get("type") for e in edges}
        for required in ("data", "context", "state"):
            assert required in edge_types, f"{name} missing {required} edges"

        # Collect every endpoint referenced by any edge ('to' may be a list).
        edge_nodes = set()
        for e in edges:
            edge_nodes.add(e["from"])
            to = e["to"]
            edge_nodes.update(to if isinstance(to, list) else [to])
        # 'runtime' is a virtual target, not a declared node role.
        edge_nodes.discard("runtime")
        # Every declared node must appear in at least one edge (from or to).
        missing = set(g["nodes"].keys()) - edge_nodes
        assert not missing, f"{name} has orphan nodes: {missing}"
|
||||
|
||||
|
||||
def test_condition_reflex():
    """_check_condition('reflex') fires on social+trivial only."""
    # The reflex shortcut requires BOTH intent=social AND complexity=trivial.
    engine, _, _ = make_frame_engine({
        "input": MockInputNode(),
        "output": MockOutputNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }, "v4-eras")

    # Should fire
    cmd_social = Command(
        analysis=InputAnalysis(intent="social", complexity="trivial"),
        source_text="hi",
    )
    assert engine._check_condition("reflex", command=cmd_social), \
        "reflex should fire for social+trivial"

    # Should NOT fire
    cmd_request = Command(
        analysis=InputAnalysis(intent="request", complexity="simple"),
        source_text="show data",
    )
    assert not engine._check_condition("reflex", command=cmd_request), \
        "reflex should not fire for request+simple"

    # social but not trivial → still no reflex
    cmd_social_complex = Command(
        analysis=InputAnalysis(intent="social", complexity="complex"),
        source_text="tell me a long story",
    )
    assert not engine._check_condition("reflex", command=cmd_social_complex), \
        "reflex should not fire for social+complex"
|
||||
|
||||
|
||||
def test_condition_tool_output():
    """_check_condition('has_tool_output') fires when tool data present."""
    # Requires BOTH tool_used and a non-empty tool_output.
    engine, _, _ = make_frame_engine({
        "input": MockInputNode(),
        "output": MockOutputNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }, "v4-eras")

    thought_with = ThoughtResult(
        response="data", tool_used="query_db", tool_output="rows here",
    )
    assert engine._check_condition("has_tool_output", thought=thought_with), \
        "should fire when tool_used and tool_output both set"

    thought_without = ThoughtResult(response="just text")
    assert not engine._check_condition("has_tool_output", thought=thought_without), \
        "should not fire when no tool output"

    # tool_used set but output empty → treated as "no output"
    thought_partial = ThoughtResult(response="x", tool_used="query_db", tool_output="")
    assert not engine._check_condition("has_tool_output", thought=thought_partial), \
        "should not fire when tool_output is empty string"
|
||||
|
||||
|
||||
def test_frame_trace_reflex():
    """Reflex path: 2 frames (input → output), path='reflex'."""
    nodes = {
        "input": MockInputNode(intent="social", complexity="trivial"),
        "output": MockOutputNode(),
        "pa": MockPANode(),
        "expert_eras": MockExpertNode(),
        "interpreter": MockInterpreterNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }
    engine, sink, hud = make_frame_engine(nodes, "v4-eras")

    # asyncio.run creates AND closes its event loop; the previous
    # new_event_loop().run_until_complete() never closed the loop (resource leak).
    result = asyncio.run(engine.process_message("hello"))

    trace = result["trace"]
    assert trace["path"] == "reflex", f"expected reflex path, got {trace['path']}"
    assert trace["total_frames"] == 2, f"expected 2 frames, got {trace['total_frames']}"
    assert len(trace["frames"]) == 2
    assert trace["frames"][0]["node"] == "input"
    assert trace["frames"][1]["node"] == "output"
    assert "reflex=True" in trace["frames"][0]["condition"]
|
||||
|
||||
|
||||
def test_frame_trace_expert():
    """Expert path without tool output: F1(input)→F2(pa)→F3(expert)→F4(output+ui)."""
    nodes = {
        "input": MockInputNode(intent="request", complexity="simple"),
        "output": MockOutputNode(),
        "pa": MockPANode(expert="eras", job="get top customers"),
        "expert_eras": MockExpertNode(response="Here are the customers"),
        "interpreter": MockInterpreterNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }
    engine, sink, hud = make_frame_engine(nodes, "v4-eras")

    # asyncio.run closes its loop when done; new_event_loop() leaked it.
    result = asyncio.run(engine.process_message("show top customers"))

    trace = result["trace"]
    assert trace["path"] == "expert", f"expected expert path, got {trace['path']}"
    assert trace["total_frames"] >= 4, f"expected >=4 frames, got {trace['total_frames']}"
    nodes_in_trace = [f["node"] for f in trace["frames"]]
    assert nodes_in_trace[0] == "input"
    assert nodes_in_trace[1] == "pa"
    assert "expert_eras" in nodes_in_trace[2]
|
||||
|
||||
|
||||
def test_frame_trace_expert_with_interpreter():
    """Expert path with tool output: includes interpreter frame, path='expert+interpreter'."""
    nodes = {
        "input": MockInputNode(intent="request", complexity="simple"),
        "output": MockOutputNode(),
        "pa": MockPANode(expert="eras", job="query customers"),
        # Tool output present → engine must route through the interpreter.
        "expert_eras": MockExpertNode(
            response="raw data",
            tool_used="query_db",
            tool_output="customer_name,revenue\nAcme,1000",
        ),
        "interpreter": MockInterpreterNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }
    engine, sink, hud = make_frame_engine(nodes, "v4-eras")

    # asyncio.run closes its loop when done; new_event_loop() leaked it.
    result = asyncio.run(engine.process_message("show customer revenue"))

    trace = result["trace"]
    assert trace["path"] == "expert+interpreter", \
        f"expected expert+interpreter path, got {trace['path']}"
    nodes_in_trace = [f["node"] for f in trace["frames"]]
    assert "interpreter" in nodes_in_trace, "interpreter frame missing"
    assert trace["total_frames"] >= 5, f"expected >=5 frames, got {trace['total_frames']}"
|
||||
|
||||
|
||||
# --- Phase 1: Config-driven models (RED — will fail until implemented) ---
|
||||
|
||||
def test_graph_has_models():
    """All graph definitions include a MODELS dict mapping role → model."""
    for name in ["v1-current", "v2-director-drives", "v3-framed", "v4-eras"]:
        g = load_graph(name)
        assert "models" in g, f"{name}: graph should have a 'models' key"
        models = g["models"]
        assert isinstance(models, dict), f"{name}: models should be a dict"
        assert len(models) > 0, f"{name}: models should not be empty"
        # Model ids use the "provider/model" convention (e.g. "openai/gpt-4o-mini").
        for role, model in models.items():
            assert isinstance(model, str) and "/" in model, \
                f"{name}: model for '{role}' should be provider/model, got {model}"
|
||||
|
||||
|
||||
def test_instantiate_applies_graph_models():
    """instantiate_nodes applies model from graph config, overriding class default."""
    hud = MockHud()
    g = load_graph("v4-eras")
    # Override a model in graph config
    # (g.get guards against graphs that don't declare 'models' yet — RED phase).
    g["models"] = g.get("models", {})
    g["models"]["input"] = "test/override-model"
    nodes = instantiate_nodes(g, hud)
    assert nodes["input"].model == "test/override-model", \
        f"input node model should be 'test/override-model', got {nodes['input'].model}"
|
||||
|
||||
|
||||
def test_model_override_per_request():
    """Engine accepts model overrides that are applied to nodes for one request."""
    nodes = {
        "input": MockInputNode(intent="social", complexity="trivial"),
        "output": MockOutputNode(),
        "pa": MockPANode(),
        "expert_eras": MockExpertNode(),
        "interpreter": MockInterpreterNode(),
        "memorizer": MockMemorizer(),
        "sensor": MockSensor(),
        "ui": MockUINode(),
    }
    engine, sink, hud = make_frame_engine(nodes, "v4-eras")

    # process_message should accept model_overrides param.
    # asyncio.run closes its loop when done; new_event_loop() leaked it.
    result = asyncio.run(
        engine.process_message("hello", model_overrides={"input": "test/fast-model"})
    )
    # Should complete without error (overrides applied internally)
    assert result["trace"]["path"] == "reflex"
|
||||
|
||||
|
||||
# --- Phase 2: Shared Node Pool (RED — will fail until implemented) ---
|
||||
|
||||
def test_pool_creates_shared_nodes():
    """NodePool creates shared instances for stateless nodes."""
    # Imported inside the test so collection succeeds while the module
    # is still unimplemented (RED phase).
    from agent.node_pool import NodePool
    pool = NodePool("v4-eras")
    # Shared nodes should exist
    assert "input" in pool.shared, "input should be shared"
    assert "output" in pool.shared, "output should be shared"
    assert "pa" in pool.shared, "pa should be shared"
    assert "expert_eras" in pool.shared, "expert_eras should be shared"
    assert "interpreter" in pool.shared, "interpreter should be shared"
|
||||
|
||||
|
||||
def test_pool_excludes_stateful():
    """NodePool excludes stateful nodes (sensor, memorizer, ui)."""
    from agent.node_pool import NodePool
    pool = NodePool("v4-eras")
    # Stateful nodes must stay per-session, never shared across runtimes.
    assert "sensor" not in pool.shared, "sensor should NOT be shared"
    assert "memorizer" not in pool.shared, "memorizer should NOT be shared"
    assert "ui" not in pool.shared, "ui should NOT be shared"
|
||||
|
||||
|
||||
def test_pool_reuses_instances():
    """Two Runtimes using the same pool share node objects."""
    from agent.node_pool import NodePool
    pool = NodePool("v4-eras")
    # Same pool → same node instances (identity, not just equality)
    input1 = pool.shared["input"]
    input2 = pool.shared["input"]
    assert input1 is input2, "pool should return same instance"
|
||||
|
||||
|
||||
def test_contextvar_hud_isolation():
    """Contextvars isolate HUD events between concurrent tasks."""
    from agent.nodes.base import _current_hud

    results_a = []
    results_b = []

    async def hud_a(data):
        results_a.append(data)

    async def hud_b(data):
        results_b.append(data)

    async def task_a():
        _current_hud.set(hud_a)
        # Simulate work with a yield point so the tasks interleave.
        await asyncio.sleep(0)
        hud_fn = _current_hud.get()
        await hud_fn({"from": "a"})

    async def task_b():
        _current_hud.set(hud_b)
        await asyncio.sleep(0)
        hud_fn = _current_hud.get()
        await hud_fn({"from": "b"})

    async def run_both():
        await asyncio.gather(task_a(), task_b())

    # asyncio.run creates AND closes its event loop; the previous
    # new_event_loop().run_until_complete() never closed the loop (resource leak).
    asyncio.run(run_both())

    assert len(results_a) == 1 and results_a[0]["from"] == "a", \
        f"task_a HUD leaked: {results_a}"
    assert len(results_b) == 1 and results_b[0]["from"] == "b", \
        f"task_b HUD leaked: {results_b}"
|
||||
|
||||
|
||||
# --- Test registry (for run_tests.py) ---
|
||||
|
||||
# Name → callable registry consumed by tests/run_tests.py.
TESTS = {
    # Green — engine mechanics
    'graph_load': test_graph_load,
    'node_instantiation': test_node_instantiation,
    'edge_types_complete': test_edge_types_complete,
    'condition_reflex': test_condition_reflex,
    'condition_tool_output': test_condition_tool_output,
    'frame_trace_reflex': test_frame_trace_reflex,
    'frame_trace_expert': test_frame_trace_expert,
    'frame_trace_expert_with_interpreter': test_frame_trace_expert_with_interpreter,
    # Phase 1: config-driven models
    'graph_has_models': test_graph_has_models,
    'instantiate_applies_graph_models': test_instantiate_applies_graph_models,
    'model_override_per_request': test_model_override_per_request,
    # Phase 2: shared node pool
    'pool_creates_shared_nodes': test_pool_creates_shared_nodes,
    'pool_excludes_stateful': test_pool_excludes_stateful,
    'pool_reuses_instances': test_pool_reuses_instances,
    'contextvar_hud_isolation': test_contextvar_hud_isolation,
}
|
||||
@ -1,149 +0,0 @@
|
||||
"""Model matrix tests — run the same API test with different LLM model configs.
|
||||
|
||||
Each variant defines model overrides for specific node roles. The test runner
|
||||
generates one test per (base_test × variant) combination. Results are posted
|
||||
with the variant name in brackets, e.g. "eras_umsatz_api[haiku]".
|
||||
|
||||
Usage via run_tests.py:
|
||||
python tests/run_tests.py matrix # all variants × all tests
|
||||
python tests/run_tests.py matrix/eras_query[haiku] # single combo
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
# Base URL of the assay runtime; ASSAY_API may or may not include the /api suffix.
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
# str.removesuffix is already a no-op when the suffix is absent — no endswith guard needed.
ASSAY_BASE = _api_url.removesuffix('/api')
# SECURITY: prefer an env-provided token; the hard-coded value remains only as a
# fallback so existing test deployments keep working.
SERVICE_TOKEN = os.environ.get(
    'ASSAY_SERVICE_TOKEN',
    '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g',
)
|
||||
|
||||
|
||||
# --- Model variants ---
|
||||
# Each variant overrides specific node models. Omitted roles keep graph defaults.
|
||||
|
||||
VARIANTS = {
    # Keys become the bracket suffix in generated test names,
    # e.g. "eras_query[haiku]" (see get_matrix_tests).
    'gemini-flash': {
        # Graph default — this is the baseline
    },
    'haiku': {
        'pa': 'anthropic/claude-haiku-4.5',
        'expert_eras': 'anthropic/claude-haiku-4.5',
        'interpreter': 'anthropic/claude-haiku-4.5',
    },
    'gpt-4o-mini': {
        'pa': 'openai/gpt-4o-mini',
        'expert_eras': 'openai/gpt-4o-mini',
        'interpreter': 'openai/gpt-4o-mini',
    },
}
|
||||
|
||||
|
||||
# --- API helper with model overrides ---
|
||||
|
||||
def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]:
    """Send *text* via /api/chat and return the parsed SSE (event, data) pairs.

    models: optional role → model overrides forwarded in the request body.
    timeout: socket timeout in seconds for the whole request.
    """
    body = {'content': text}
    if models:
        body['models'] = models

    payload = json.dumps(body).encode()
    req = urllib.request.Request(
        f'{ASSAY_BASE}/api/chat',
        data=payload,
        method='POST',
        headers={
            'Authorization': f'Bearer {SERVICE_TOKEN}',
            'Content-Type': 'application/json',
        },
    )
    # Context manager closes the HTTP response even if .read() raises
    # (the original leaked the connection object).
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        output = resp.read().decode('utf-8')

    # SSE frames are blank-line separated; each carries "event:" and "data:" lines.
    events = []
    for block in output.split('\n\n'):
        event_type, data = '', ''
        for line in block.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:]
            elif line.startswith('data: '):
                data = line[6:]
        if event_type and data:
            events.append((event_type, data))
    return events
|
||||
|
||||
|
||||
# --- Base tests (same logic, different models) ---
|
||||
|
||||
def _test_eras_query(models: dict):
    """ERAS query produces correct SQL path (artikelposition, not geraeteverbraeuche)."""
    events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    # hud events whose payload mentions a query_db tool_call carry the generated SQL.
    tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1] and 'query_db' in e[1]]
    assert tool_calls, 'no query_db tool call found'

    query_data = json.loads(tool_calls[0][1])
    query = query_data.get('args', {}).get('query', '')

    # Revenue must come from line items, not the device-consumption table.
    assert 'artikelposition' in query.lower(), f'query missing artikelposition: {query}'
    assert 'geraeteverbraeuche' not in query.lower(), f'query uses wrong table: {query}'

    # Check it completed
    done_events = [e for e in events if e[0] == 'done']
    assert done_events, 'no done event'
|
||||
|
||||
|
||||
def _test_eras_artifact(models: dict):
    """ERAS query produces artifact with data rows."""
    events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    artifact_events = [e for e in events if e[0] == 'artifacts']
    assert artifact_events, 'no artifact event'

    artifacts = json.loads(artifact_events[0][1]).get('artifacts', [])
    assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}'

    # At least one artifact must carry tabular data (fields or rows).
    has_data = any(
        art.get('data', {}).get('fields') or art.get('data', {}).get('rows')
        for art in artifacts
    )
    assert has_data, 'no artifact contains data'
|
||||
|
||||
|
||||
def _test_social_reflex(models: dict):
    """Social greeting takes reflex path (fast, no expert)."""
    events = api_chat('Hallo!', models=models)

    # Should get a response (delta events)
    deltas = [e for e in events if e[0] == 'delta']
    assert deltas, 'no delta events'

    # Should complete
    done = [e for e in events if e[0] == 'done']
    assert done, 'no done event'

    # Should NOT call any tools — the reflex path bypasses experts entirely.
    tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1]]
    assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls'
|
||||
|
||||
|
||||
# --- Test registry: base tests that get multiplied by variants ---
|
||||
|
||||
BASE_TESTS = {
    # Each entry is multiplied by every VARIANTS entry in get_matrix_tests().
    'eras_query': _test_eras_query,
    'eras_artifact': _test_eras_artifact,
    'social_reflex': _test_social_reflex,
}
|
||||
|
||||
|
||||
def get_matrix_tests() -> dict:
    """Generate test×variant matrix. Returns {name: callable} dict for run_tests.py."""
    matrix = {}
    for variant_name, models in VARIANTS.items():
        for test_name, test_fn in BASE_TESTS.items():
            # Default args bind the loop variables at definition time,
            # sidestepping the late-binding closure pitfall.
            matrix[f'{test_name}[{variant_name}]'] = lambda fn=test_fn, m=models: fn(m)
    return matrix
|
||||
@ -1,134 +0,0 @@
|
||||
"""Node-level tests for ErasExpertNode.
|
||||
|
||||
Tests the expert node directly — no HTTP, no pipeline, no session.
|
||||
Instantiates ErasExpertNode, calls execute(), asserts on HUD events + ThoughtResult.
|
||||
|
||||
Two LLM calls per test (plan + response) vs 4+ for full matrix tests.
|
||||
Runs against MariaDB directly (DB_HOST from .env — WireGuard on local, ClusterIP in K3s).
|
||||
|
||||
Usage:
|
||||
python tests/run_tests.py node
|
||||
python tests/run_tests.py node/umsatz_uses_artikelposition
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Load .env then .env.local (override) so DB_HOST, OPENROUTER_API_KEY etc. are set.
|
||||
# .env.local is gitignored — use it to point at a local tenant DB:
|
||||
# DB_HOST=localhost
|
||||
# DB_PORT=30310 (mariadb NodePort, dev namespace)
|
||||
# DB_PORT=30311 (mariadb-test NodePort, test namespace)
|
||||
from dotenv import load_dotenv
|
||||
_root = Path(__file__).parent.parent
|
||||
load_dotenv(_root / ".env")
|
||||
load_dotenv(_root / ".env.local", override=True)
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.nodes.eras_expert import ErasExpertNode
|
||||
|
||||
|
||||
def _run(job: str):
    """Instantiate expert, run job, return (result, hud_events)."""
    captured = []

    async def hud(event):
        captured.append(event)

    async def _execute():
        expert = ErasExpertNode(send_hud=hud)
        result = await expert.execute(job, language="de")
        return result, captured

    return asyncio.run(_execute())
|
||||
|
||||
|
||||
def _tool_calls(events: list) -> list[dict]:
|
||||
return [e for e in events if e.get("event") == "tool_call"]
|
||||
|
||||
|
||||
def _query_db_calls(events: list) -> list[str]:
|
||||
"""Extract SQL strings from all query_db tool_call events."""
|
||||
return [
|
||||
e["args"]["query"]
|
||||
for e in _tool_calls(events)
|
||||
if e.get("tool") == "query_db" and "args" in e
|
||||
]
|
||||
|
||||
|
||||
# --- Tests ---
|
||||
|
||||
def test_umsatz_uses_artikelposition():
    """Umsatz query must use artikelposition, not geraeteverbraeuche."""
    result, events = _run("Zeig mir die 5 größten Kunden nach Umsatz")

    queries = _query_db_calls(events)
    assert queries, "no query_db call made"

    # Join all issued queries: the LLM may retry with several statements.
    combined = " ".join(queries).lower()
    assert "artikelposition" in combined, \
        f"expected artikelposition in query, got: {queries[0][:300]}"
|
||||
|
||||
|
||||
def test_umsatz_not_geraeteverbraeuche():
    """Umsatz query must not touch geraeteverbraeuche (consumption table)."""
    result, events = _run("Zeig mir die 5 größten Kunden nach Umsatz")

    queries = _query_db_calls(events)
    # NOTE(review): passes vacuously when no query_db call was made at all;
    # the companion test above guards the non-empty case.
    combined = " ".join(queries).lower()
    assert "geraeteverbraeuche" not in combined, \
        f"used wrong table geraeteverbraeuche: {queries[0][:300]}"
|
||||
|
||||
|
||||
def test_umsatz_has_result():
    """Umsatz query returns non-empty result and completes without errors."""
    result, events = _run("Zeig mir die 5 größten Kunden nach Umsatz")

    # A clean run has no expert errors, real tool output, and response text.
    assert not result.errors, \
        f"expert had errors: {result.errors}"
    assert result.tool_output, "no tool output (query returned nothing)"
    assert result.response, "no response text generated"
|
||||
|
||||
|
||||
def test_kunden_count_uses_kunden_table():
    """Simple count query uses the kunden table."""
    result, events = _run("Wie viele Kunden gibt es?")

    queries = _query_db_calls(events)
    assert queries, "no query_db call made"

    # Substring check on the joined SQL — also matches e.g. "objektkunde",
    # but a customer count should always reference "kunden" somewhere.
    combined = " ".join(queries).lower()
    assert "kunden" in combined, f"expected kunden table: {queries}"
|
||||
|
||||
|
||||
def test_objekte_joins_objektkunde():
    """Objekte-per-Kunde query uses the objektkunde junction table."""
    result, events = _run("Welcher Kunde hat die meisten Objekte?")

    queries = _query_db_calls(events)
    assert queries, "no query_db call made"

    # Counting objects per customer requires going through the junction table.
    combined = " ".join(queries).lower()
    assert "objektkunde" in combined, \
        f"expected objektkunde junction: {queries[0][:300] if queries else '(none)'}"
|
||||
|
||||
|
||||
def test_no_sql_exposed_in_response():
    """Response text must not contain raw SQL (domain language only)."""
    result, events = _run("Zeig mir die 5 größten Kunden nach Umsatz")

    text = result.response.lower()
    # Trailing space in "select " avoids false positives on words like "selection".
    assert "select " not in text, f"SQL leaked into response: {result.response[:200]}"
    assert "from kunden" not in text, f"table name leaked: {result.response[:200]}"
|
||||
|
||||
|
||||
# Name → callable registry consumed by tests/run_tests.py ("node" suite).
TESTS = {
    "umsatz_uses_artikelposition": test_umsatz_uses_artikelposition,
    "umsatz_not_geraeteverbraeuche": test_umsatz_not_geraeteverbraeuche,
    "umsatz_has_result": test_umsatz_has_result,
    "kunden_count_uses_kunden_table": test_kunden_count_uses_kunden_table,
    "objekte_joins_objektkunde": test_objekte_joins_objektkunde,
    "no_sql_exposed_in_response": test_no_sql_exposed_in_response,
}
|
||||
@ -1,253 +0,0 @@
|
||||
"""Testcases suite — runs markdown testcases from testcases/ via /api/chat SSE.
|
||||
|
||||
Each testcase gets its own session (session_id), enabling future parallel runs.
|
||||
Results are posted to /api/test-results for real-time dashboard visibility.
|
||||
|
||||
Usage via run_tests.py:
|
||||
python tests/run_tests.py testcases # all testcases
|
||||
python tests/run_tests.py testcases/fast # single testcase
|
||||
python tests/run_tests.py testcases/reflex_path # by name
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent to path for runtime_test imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from runtime_test import (
|
||||
parse_testcase, check_response, check_actions, check_state, check_trace,
|
||||
)
|
||||
|
||||
# Base URL of the assay runtime under test; ASSAY_API may include the /api suffix.
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
# str.removesuffix is a no-op when the suffix is absent — no endswith guard needed.
ASSAY_BASE = _api_url.removesuffix('/api')
API = f'{ASSAY_BASE}/api'
# SECURITY: prefer an env-provided token; the hard-coded value remains only as a
# fallback so existing test deployments keep working.
SERVICE_TOKEN = os.environ.get(
    'ASSAY_SERVICE_TOKEN',
    '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g',
)
HEADERS = {'Authorization': f'Bearer {SERVICE_TOKEN}', 'Content-Type': 'application/json'}
TESTCASES_DIR = Path(__file__).parent.parent / 'testcases'
|
||||
|
||||
|
||||
# --- SSE client using /api/chat ---
|
||||
|
||||
class ChatClient:
|
||||
"""Sends messages via /api/chat SSE. Each instance has its own session."""
|
||||
|
||||
def __init__(self):
    # Short random session id so parallel clients get isolated sessions.
    self.session_id = str(uuid.uuid4())[:12]
    # Per-request result caches, refreshed by send()/send_action().
    self.last_response = ""
    self.last_memo = {}
    self.last_actions = []
    self.last_buttons = []
    self.last_trace = []  # HUD events from this request
|
||||
|
||||
def send(self, text: str, dashboard: list = None) -> dict:
|
||||
"""Send message via /api/chat, parse SSE stream."""
|
||||
body = {'content': text, 'session_id': self.session_id}
|
||||
if dashboard is not None:
|
||||
body['dashboard'] = dashboard
|
||||
|
||||
payload = json.dumps(body).encode()
|
||||
req = urllib.request.Request(
|
||||
f'{API}/chat', data=payload, method='POST',
|
||||
headers=HEADERS,
|
||||
)
|
||||
resp = urllib.request.urlopen(req, timeout=120)
|
||||
output = resp.read().decode('utf-8')
|
||||
|
||||
# Parse SSE events
|
||||
deltas = []
|
||||
hud_events = []
|
||||
controls = []
|
||||
artifacts = []
|
||||
|
||||
for block in output.split('\n\n'):
|
||||
event_type, data_str = '', ''
|
||||
for line in block.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:]
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:]
|
||||
if not event_type or not data_str:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if event_type == 'delta':
|
||||
deltas.append(data.get('content', ''))
|
||||
elif event_type == 'hud':
|
||||
hud_events.append(data)
|
||||
elif event_type == 'controls':
|
||||
controls = data.get('controls', [])
|
||||
elif event_type == 'artifacts':
|
||||
artifacts = data.get('artifacts', [])
|
||||
|
||||
self.last_response = ''.join(deltas)
|
||||
self.last_trace = hud_events
|
||||
|
||||
# Extract controls from HUD if not sent as separate event
|
||||
if not controls:
|
||||
for h in reversed(hud_events):
|
||||
if h.get('event') == 'controls':
|
||||
controls = h.get('controls', [])
|
||||
break
|
||||
|
||||
self.last_actions = controls
|
||||
self.last_buttons = [c for c in controls if isinstance(c, dict) and c.get('type') == 'button']
|
||||
|
||||
return {'response': self.last_response, 'controls': controls, 'artifacts': artifacts}
|
||||
|
||||
def send_action(self, action: str) -> dict:
|
||||
"""Send an action via /api/chat as ACTION: format."""
|
||||
body = {
|
||||
'content': f'ACTION:{action}',
|
||||
'session_id': self.session_id,
|
||||
}
|
||||
payload = json.dumps(body).encode()
|
||||
req = urllib.request.Request(
|
||||
f'{API}/chat', data=payload, method='POST',
|
||||
headers=HEADERS,
|
||||
)
|
||||
resp = urllib.request.urlopen(req, timeout=120)
|
||||
output = resp.read().decode('utf-8')
|
||||
|
||||
deltas = []
|
||||
hud_events = []
|
||||
controls = []
|
||||
|
||||
for block in output.split('\n\n'):
|
||||
event_type, data_str = '', ''
|
||||
for line in block.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:]
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:]
|
||||
if not event_type or not data_str:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if event_type == 'delta':
|
||||
deltas.append(data.get('content', ''))
|
||||
elif event_type == 'hud':
|
||||
hud_events.append(data)
|
||||
elif event_type == 'controls':
|
||||
controls = data.get('controls', [])
|
||||
|
||||
self.last_response = ''.join(deltas)
|
||||
self.last_trace = hud_events
|
||||
|
||||
if not controls:
|
||||
for h in reversed(hud_events):
|
||||
if h.get('event') == 'controls':
|
||||
controls = h.get('controls', [])
|
||||
break
|
||||
|
||||
self.last_actions = controls
|
||||
self.last_buttons = [c for c in controls if isinstance(c, dict) and c.get('type') == 'button']
|
||||
|
||||
return {'response': self.last_response}
|
||||
|
||||
def get_state(self) -> dict:
|
||||
"""Fetch memorizer state from /api/session for this session."""
|
||||
req = urllib.request.Request(
|
||||
f'{API}/session?session={self.session_id}', headers=HEADERS)
|
||||
resp = urllib.request.urlopen(req, timeout=10)
|
||||
data = json.loads(resp.read().decode('utf-8'))
|
||||
self.last_memo = data.get('memorizer', {})
|
||||
return self.last_memo
|
||||
|
||||
|
||||
# --- Testcase runner that returns (name, callable) pairs for run_tests.py ---
|
||||
|
||||
def _run_testcase(tc: dict):
    """Execute a parsed testcase. Raises AssertionError on first failure."""
    client = ChatClient()
    failures = []

    def record(step: str, msg: str):
        # Collect instead of raising so a testcase reports all failed checks.
        failures.append(f"[{step}] {msg}")

    for step in tc['steps']:
        label = step['name']
        for cmd in step['commands']:
            kind = cmd['type']

            if kind == 'clear':
                # No-op — each testcase has its own session, nothing to clear.
                continue

            if kind == 'send':
                try:
                    client.send(cmd['text'], dashboard=cmd.get('dashboard'))
                except Exception as e:
                    record(label, f"send failed: {e}")
                continue

            if kind == 'action':
                try:
                    client.send_action(cmd['action'])
                except Exception as e:
                    record(label, f"action failed: {e}")
                continue

            if kind == 'action_match':
                # Pick the first button whose action or label contains one
                # of the requested substrings (case-insensitive).
                wanted = cmd['patterns']
                chosen = None
                for pattern in wanted:
                    needle = pattern.lower()
                    for btn in client.last_buttons:
                        action_txt = (btn.get('action') or '').lower()
                        label_txt = (btn.get('label') or '').lower()
                        if needle in action_txt or needle in label_txt:
                            chosen = btn.get('action') or btn.get('label', '')
                            break
                    if chosen:
                        break
                if chosen:
                    try:
                        client.send_action(chosen)
                    except Exception as e:
                        record(label, f"action_match failed: {e}")
                else:
                    record(label, f"no button matching {wanted}")
                continue

            if kind == 'expect_response':
                ok, detail = check_response(client.last_response, cmd['check'])
                if not ok:
                    record(label, f"response: {cmd['check']} — {detail}")
            elif kind == 'expect_actions':
                ok, detail = check_actions(client.last_actions, cmd['check'])
                if not ok:
                    record(label, f"actions: {cmd['check']} — {detail}")
            elif kind == 'expect_state':
                client.get_state()
                ok, detail = check_state(client.last_memo, cmd['check'])
                if not ok:
                    record(label, f"state: {cmd['check']} — {detail}")
            elif kind == 'expect_trace':
                ok, detail = check_trace(client.last_trace, cmd['check'])
                if not ok:
                    record(label, f"trace: {cmd['check']} — {detail}")

    if failures:
        raise AssertionError(f"{len(failures)} check(s) failed:\n" + "\n".join(failures[:5]))
|
||||
|
||||
|
||||
def get_testcase_tests() -> dict:
    """Load all testcases as {name: callable} for run_tests.py."""
    registry = {}
    for path in sorted(TESTCASES_DIR.glob('*.md')):
        case = parse_testcase(path)
        if not case['name'] or not case['steps']:
            continue
        # Filename stem is the test name (e.g., "fast", "reflex_path").
        # Bind `case` as a default argument to dodge late-binding closures.
        registry[path.stem] = lambda case=case: _run_testcase(case)
    return registry
|
||||
140
tests/test_ui.py
140
tests/test_ui.py
@ -1,140 +0,0 @@
|
||||
"""
|
||||
UI tests — toolbar, navigation, scroll preservation.
|
||||
|
||||
Runs against a nyx instance with VITE_AUTH_DISABLED=true (auth skipped).
|
||||
|
||||
Local dev (after restarting Vite with .env.local VITE_AUTH_DISABLED=true):
|
||||
NYX_URL=http://localhost:5173 python tests/run_tests.py ui
|
||||
|
||||
K3s test build (no restart needed — already built with auth disabled):
|
||||
NYX_URL=http://localhost:30802 python tests/run_tests.py ui
|
||||
"""
|
||||
|
||||
import os
|
||||
from playwright.sync_api import sync_playwright, Page, expect
|
||||
|
||||
# Target nyx instance; default is the K3s test-build NodePort.
NYX_URL = os.environ.get('NYX_URL', 'http://localhost:30802')

# Lazily-started Playwright driver and shared headless Chromium instance,
# initialized on first use by _ensure_browser().
_pw = None
_browser = None
|
||||
|
||||
|
||||
def _ensure_browser():
    """Start Playwright and launch a shared headless Chromium on first use."""
    global _pw, _browser
    if _browser is not None:
        return _browser
    _pw = sync_playwright().start()
    _browser = _pw.chromium.launch(headless=True)
    return _browser
|
||||
|
||||
|
||||
def _page(path: str = '/nyx') -> tuple:
    """Open a fresh browser context at NYX_URL + path; return (page, context)."""
    context = _ensure_browser().new_context(viewport={'width': 1280, 'height': 800})
    page = context.new_page()
    page.goto(f'{NYX_URL}{path}')
    page.wait_for_selector('.app-toolbar', timeout=15000)
    # Give Vue a beat to commit reactive toolbar updates before asserting.
    page.wait_for_timeout(500)
    return page, context
|
||||
|
||||
|
||||
def _click_nav(page: Page, text: str):
    """Click the sidebar link containing `text`, then wait for the view swap."""
    link = page.locator('.sidebar-link', has_text=text)
    link.click()
    page.wait_for_timeout(800)
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_toolbar_nyx_has_all_groups():
    """nyx shows 4 toolbar groups: connection, quad-view, themes, panels."""
    page, context = _page('/nyx')
    try:
        groups = page.locator('.toolbar-group')
        expect(groups).to_have_count(4, timeout=5000)
    finally:
        context.close()
|
||||
|
||||
|
||||
def test_toolbar_tests_has_two_groups():
    """tests view shows 2 toolbar groups: connection + themes."""
    page, context = _page('/tests')
    try:
        groups = page.locator('.toolbar-group')
        expect(groups).to_have_count(2, timeout=5000)
    finally:
        context.close()
|
||||
|
||||
|
||||
def test_toolbar_home_has_one_group():
    """home page shows 1 toolbar group: themes only."""
    page, context = _page('/')
    try:
        groups = page.locator('.toolbar-group')
        expect(groups).to_have_count(1, timeout=5000)
    finally:
        context.close()
|
||||
|
||||
|
||||
def test_toolbar_survives_roundtrip():
    """Navigate nyx→tests→home→nyx — toolbar groups correct at each stop."""
    page, context = _page('/nyx')
    try:
        groups = page.locator('.toolbar-group')
        expect(groups).to_have_count(4, timeout=5000)

        # Walk the nav loop, checking the expected group count at each stop.
        for stop, count in (('Tests', 2), ('Home', 1), ('nyx', 4), ('Tests', 2)):
            _click_nav(page, stop)
            expect(groups).to_have_count(count, timeout=3000)
    finally:
        context.close()
|
||||
|
||||
|
||||
def test_scroll_preserved_across_navigation():
    """Scroll down in tests view, navigate away and back — position preserved."""
    page, context = _page('/tests')
    try:
        page.wait_for_selector('.tests-view', timeout=5000)

        read_scroll = '() => document.querySelector(".tests-view")?.scrollTop ?? 0'

        # Scroll the tests container, then record the position.
        page.evaluate('() => { const el = document.querySelector(".tests-view"); if (el) el.scrollTop = 200; }')
        page.wait_for_timeout(200)
        before = page.evaluate(read_scroll)

        # Round-trip through another view and back.
        _click_nav(page, 'Home')
        _click_nav(page, 'Tests')

        after = page.evaluate(read_scroll)
        assert after == before, f'scroll not preserved: was {before}, now {after}'
    finally:
        context.close()
|
||||
|
||||
|
||||
def test_all_views_stay_in_dom():
    """After visiting nyx and tests, both stay in DOM (hidden not removed)."""
    page, context = _page('/nyx')
    try:
        expect(page.locator('.toolbar-group')).to_have_count(4, timeout=5000)

        # Hop to Tests — the nyx AgentsView must remain mounted (hidden).
        _click_nav(page, 'Tests')
        agents_present = page.locator('.agents-view').count() > 0
        assert agents_present, 'AgentsView removed from DOM'

        # Hop back — the TestsView must likewise survive.
        _click_nav(page, 'nyx')
        tests_present = page.locator('.tests-view').count() > 0
        assert tests_present, 'TestsView removed from DOM'
    finally:
        context.close()
|
||||
|
||||
|
||||
# Test registry
# Maps test name (as referenced by tests/run_tests.py) → test callable.
TESTS = {
    'ui_toolbar_nyx_all_groups': test_toolbar_nyx_has_all_groups,
    'ui_toolbar_tests_two_groups': test_toolbar_tests_has_two_groups,
    'ui_toolbar_home_one_group': test_toolbar_home_has_one_group,
    'ui_toolbar_roundtrip': test_toolbar_survives_roundtrip,
    'ui_scroll_preserved': test_scroll_preserved_across_navigation,
    'ui_views_stay_in_dom': test_all_views_stay_in_dom,
}
|
||||
Reference in New Issue
Block a user