"""Standalone MCP app — proxies tool calls to assay-runtime.

Exposes the MCP server over two transports, both behind bearer-token auth:

* ``/mcp`` — Streamable HTTP (primary).
* ``/mcp/sse`` + ``/mcp/messages/`` — SSE (legacy fallback).

Every MCP tool is a thin wrapper around one or more HTTP calls to the runtime
at ``RUNTIME_URL``; failures are returned to the MCP client as ``ERROR: ...``
text rather than raised.
"""

import asyncio
import hmac
import json
import logging
import os
from pathlib import Path
from typing import Optional

import httpx
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http import StreamableHTTPServerTransport
from mcp.types import TextContent, Tool

# Load the .env file located two directories above this file before any
# configuration is read from the environment below.
load_dotenv(Path(__file__).parent.parent / ".env")

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("mcp-proxy")

# Config
RUNTIME_URL = os.environ.get("RUNTIME_URL", "http://assay-runtime")

# SERVICE_TOKENS is a comma-separated list. Parse it once, dropping empty
# fields and surrounding whitespace, so the inbound allow-list and the
# outbound token are always derived from the same cleaned values.
_TOKEN_LIST = [t.strip() for t in os.environ.get("SERVICE_TOKENS", "").split(",") if t.strip()]
SERVICE_TOKENS = set(_TOKEN_LIST)
# Token this proxy presents to the runtime: the first configured token.
SERVICE_TOKEN = _TOKEN_LIST[0] if _TOKEN_LIST else ""

RUNTIME_TIMEOUT_S = 30   # per-request timeout for calls to the runtime
SEND_TIMEOUT_S = 30      # overall budget for assay_send result polling
POLL_INTERVAL_S = 0.5    # delay between /api/result polls

app = FastAPI(title="assay-mcp")
_security = HTTPBearer()


async def require_auth(creds: HTTPAuthorizationCredentials = Depends(_security)):
    """FastAPI dependency: accept the request only if its bearer token is configured.

    Raises:
        HTTPException: 401 when the presented token is not in ``SERVICE_TOKENS``.

    Returns:
        A small identity dict describing the authenticated service caller.
    """
    presented = creds.credentials.encode("utf-8")
    authorized = False
    # Compare in constant time and visit every token (no early exit) so the
    # response time does not depend on the presented value.
    for token in SERVICE_TOKENS:
        if hmac.compare_digest(presented, token.encode("utf-8")):
            authorized = True
    if not authorized:
        raise HTTPException(status_code=401, detail="Invalid token")
    return {"sub": "service", "source": "service_token"}


@app.get("/health")
async def health():
    """Liveness probe; does not contact the runtime."""
    return {"status": "ok", "service": "mcp-proxy"}


# --- MCP Server ---

mcp_server = Server("assay")
_mcp_transport = SseServerTransport("/mcp/messages/")


def _text(message: str) -> list[TextContent]:
    """Wrap *message* as a single-item MCP text result."""
    return [TextContent(type="text", text=message)]


def _error(message) -> list[TextContent]:
    """Wrap *message* as an MCP text result with the ``ERROR: `` prefix."""
    return _text(f"ERROR: {message}")


async def _proxy_request(method: str, path: str, *, params: Optional[dict] = None,
                         body: Optional[dict] = None) -> dict:
    """Send one authenticated request to the runtime and normalise the outcome.

    Never raises: any failure is reported as ``{"error": <message>}`` so tool
    handlers can forward it to the MCP client.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``{"error": ...}``.
    """
    try:
        async with httpx.AsyncClient(timeout=RUNTIME_TIMEOUT_S) as client:
            resp = await client.request(
                method,
                f"{RUNTIME_URL}{path}",
                params=params,
                json=body,
                headers={"Authorization": f"Bearer {SERVICE_TOKEN}"},
            )
    except Exception as e:  # boundary: tool calls must not crash on transport errors
        log.warning("Runtime request %s %s failed: %s", method, path, e)
        return {"error": f"Runtime unreachable: {e}"}

    if resp.status_code == 200:
        try:
            return resp.json()
        except ValueError:
            # The runtime answered, so this is not a connectivity problem.
            return {"error": f"Runtime returned a non-JSON response: {resp.text}"}

    # Non-200: prefer the JSON "detail" field when the body is a JSON object,
    # otherwise fall back to the raw body text.
    try:
        return {"error": resp.json().get("detail", resp.text)}
    except (ValueError, AttributeError):
        return {"error": resp.text}


async def _proxy_get(path: str, params: Optional[dict] = None) -> dict:
    """GET request to runtime."""
    return await _proxy_request("GET", path, params=params)


async def _proxy_post(path: str, body: Optional[dict] = None) -> dict:
    """POST request to runtime (an empty JSON object is sent when *body* is omitted)."""
    return await _proxy_request("POST", path, body=body or {})


@mcp_server.list_tools()
async def list_tools():
    """Return the MCP tool catalogue served by this proxy."""
    no_args = {"type": "object", "properties": {}}
    return [
        # NOTE(review): "database" is advertised here but the assay_send handler
        # does not forward it to the runtime — confirm the intended behaviour.
        Tool(name="assay_send", description="Send a message to the cognitive agent and get a response.",
             inputSchema={"type": "object", "properties": {
                 "text": {"type": "string", "description": "Message text to send"},
                 "database": {"type": "string", "description": "Optional: database name for query_db context"},
             }, "required": ["text"]}),
        Tool(name="assay_trace", description="Get recent trace events from the pipeline (HUD events, tool calls, audit).",
             inputSchema={"type": "object", "properties": {
                 "last": {"type": "integer", "description": "Number of recent events (default 20)", "default": 20},
                 "filter": {"type": "string", "description": "Comma-separated event types to filter (e.g. 'tool_call,controls')"},
             }}),
        Tool(name="assay_history", description="Get recent chat messages from the active session.",
             inputSchema={"type": "object", "properties": {
                 "last": {"type": "integer", "description": "Number of recent messages (default 20)", "default": 20},
             }}),
        Tool(name="assay_state", description="Get the current memorizer state (mood, topic, language, facts).",
             inputSchema=no_args),
        Tool(name="assay_clear", description="Clear the active session (history, state, controls).",
             inputSchema=no_args),
        Tool(name="assay_graph", description="Get the active graph definition (nodes, edges, description).",
             inputSchema=no_args),
        Tool(name="assay_graph_list", description="List all available graph definitions.",
             inputSchema=no_args),
        Tool(name="assay_graph_switch", description="Switch the active graph for new sessions.",
             inputSchema={"type": "object", "properties": {
                 "name": {"type": "string", "description": "Graph name to switch to"},
             }, "required": ["name"]}),
    ]


async def _tool_send(arguments: dict) -> list[TextContent]:
    """assay_send: readiness check, queue the message, then poll for the result."""
    text = arguments.get("text", "")
    if not text:
        return _error("Missing 'text' argument.")

    # Step 1: check runtime is ready
    check = await _proxy_post("/api/send/check")
    if "error" in check:
        return _error(check["error"])
    if not check.get("ready"):
        return _error(f"{check.get('reason', 'unknown')}: {check.get('detail', '')}")

    # Step 2: queue message
    send = await _proxy_post("/api/send", {"text": text})
    if "error" in send:
        return _error(send["error"])
    # NOTE(review): the send response carries an "id", but /api/result is polled
    # without it — confirm the runtime cannot hand back an earlier message's result.

    # Step 3: poll for result. A monotonic deadline keeps the overall wait
    # bounded by SEND_TIMEOUT_S even when individual polls are slow.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + SEND_TIMEOUT_S
    while loop.time() < deadline:
        await asyncio.sleep(POLL_INTERVAL_S)
        result = await _proxy_get("/api/result")
        if "error" in result:
            return _error(result["error"])
        status = result.get("status", "")
        if status == "done":
            return _text(result.get("response", "[no response]"))
        if status == "error":
            return _error(result.get("detail", "pipeline failed"))
    return _error(f"Pipeline timeout ({SEND_TIMEOUT_S}s)")


def _format_trace_event(event: dict) -> str:
    """Render one trace event as a fixed-width ``node event detail`` line."""
    # str() guards the width format specs against None / non-string values.
    node = str(event.get("node", "?"))
    kind = str(event.get("event", "?"))
    detail = event.get("detail", "")
    return f"{node:12s} {kind:20s} {detail}".rstrip()


async def _tool_trace(arguments: dict) -> list[TextContent]:
    """assay_trace: fetch recent trace events and format them compactly."""
    params = {"last": arguments.get("last", 20)}
    event_filter = arguments.get("filter", "")
    if event_filter:
        params["filter"] = event_filter
    result = await _proxy_get("/api/trace", params)
    if "error" in result:
        return _error(result["error"])
    lines = [_format_trace_event(e) for e in result.get("lines", [])]
    return _text("\n".join(lines) if lines else "(no events)")


async def _tool_history(arguments: dict) -> list[TextContent]:
    """assay_history: return the recent chat messages as pretty-printed JSON."""
    result = await _proxy_get("/api/history", {"last": arguments.get("last", 20)})
    if "error" in result:
        return _error(result["error"])
    return _text(json.dumps(result.get("messages", []), indent=2))


async def _tool_state(_arguments: dict) -> list[TextContent]:
    """assay_state: return the runtime's /api/state payload as JSON."""
    result = await _proxy_get("/api/state")
    if "error" in result:
        return _error(result["error"])
    return _text(json.dumps(result, indent=2))


async def _tool_clear(_arguments: dict) -> list[TextContent]:
    """assay_clear: ask the runtime to clear the active session."""
    result = await _proxy_post("/api/clear")
    if "error" in result:
        return _error(result["error"])
    return _text("Session cleared.")


async def _tool_graph(_arguments: dict) -> list[TextContent]:
    """assay_graph: return the active graph definition as JSON."""
    result = await _proxy_get("/api/graph/active")
    if "error" in result:
        return _error(result["error"])
    return _text(json.dumps(result, indent=2))


async def _tool_graph_list(_arguments: dict) -> list[TextContent]:
    """assay_graph_list: return the available graph definitions as JSON."""
    result = await _proxy_get("/api/graph/list")
    if "error" in result:
        return _error(result["error"])
    return _text(json.dumps(result.get("graphs", []), indent=2))


async def _tool_graph_switch(arguments: dict) -> list[TextContent]:
    """assay_graph_switch: switch the graph used for new sessions."""
    gname = arguments.get("name", "")
    if not gname:
        return _error("Missing 'name' argument.")
    result = await _proxy_post("/api/graph/switch", {"name": gname})
    if "error" in result:
        return _error(result["error"])
    return _text(f"Switched to graph '{result.get('name', gname)}'. New sessions will use this graph.")


# Tool name -> handler. Keep in sync with the catalogue in list_tools().
_TOOL_HANDLERS = {
    "assay_send": _tool_send,
    "assay_trace": _tool_trace,
    "assay_history": _tool_history,
    "assay_state": _tool_state,
    "assay_clear": _tool_clear,
    "assay_graph": _tool_graph,
    "assay_graph_list": _tool_graph_list,
    "assay_graph_switch": _tool_graph_switch,
}


@mcp_server.call_tool()
async def call_tool(name: str, arguments: dict):
    """Dispatch an MCP tool call to its handler.

    Args:
        name: Tool name as advertised by ``list_tools``.
        arguments: Tool arguments; ``None`` is treated as no arguments.

    Returns:
        A list with one ``TextContent``: the tool output, an ``ERROR: ...``
        message, or ``Unknown tool: <name>``.
    """
    handler = _TOOL_HANDLERS.get(name)
    if handler is None:
        return _text(f"Unknown tool: {name}")
    return await handler(arguments or {})


# Mount MCP Streamable HTTP endpoint (primary).
# Sessions live in this process's memory, keyed by the "mcp-session-id" request
# header; requests without that header share the "default" session. A session id
# this process has not seen before simply gets a fresh transport.
_http_transports: dict[str, StreamableHTTPServerTransport] = {}
_http_tasks: dict[str, asyncio.Task] = {}
_http_ready: dict[str, asyncio.Event] = {}


async def _serve_http_session(session_id: str, transport: StreamableHTTPServerTransport,
                              ready: asyncio.Event) -> None:
    """Run the MCP server loop for one Streamable HTTP session until it ends."""
    try:
        async with transport.connect() as streams:
            ready.set()
            await mcp_server.run(streams[0], streams[1], mcp_server.create_initialization_options())
    except Exception:
        log.exception("MCP HTTP session %s failed", session_id)
    finally:
        # Release any request still waiting for readiness, then forget the
        # session so a dead transport is never handed out again.
        ready.set()
        if _http_transports.get(session_id) is transport:
            del _http_transports[session_id]
            _http_tasks.pop(session_id, None)
            _http_ready.pop(session_id, None)


@app.api_route("/mcp", methods=["GET", "POST", "DELETE"])
async def mcp_http(request: Request, user=Depends(require_auth)):
    """Handle one Streamable HTTP request on the session-scoped transport."""
    # Get or create session-scoped transport
    session_id = request.headers.get("mcp-session-id", "default")
    transport = _http_transports.get(session_id)
    if transport is None:
        transport = StreamableHTTPServerTransport(mcp_session_id=session_id)
        ready = asyncio.Event()
        _http_transports[session_id] = transport
        _http_ready[session_id] = ready
        _http_tasks[session_id] = asyncio.create_task(_serve_http_session(session_id, transport, ready))
    else:
        ready = _http_ready.get(session_id)

    # The background task must have entered transport.connect() before the
    # transport is asked to handle a request.
    if ready is not None:
        await ready.wait()
    await transport.handle_request(request.scope, request.receive, request._send)


# Mount MCP SSE endpoints (legacy fallback)
@app.get("/mcp/sse")
async def mcp_sse(request: Request, user=Depends(require_auth)):
    """Open the SSE stream and run the MCP server on it until the stream closes."""
    async with _mcp_transport.connect_sse(request.scope, request.receive, request._send) as streams:
        await mcp_server.run(streams[0], streams[1], mcp_server.create_initialization_options())


@app.post("/mcp/messages/")
async def mcp_messages(request: Request, user=Depends(require_auth)):
    """Receive a client message POSTed for an open SSE session."""
    await _mcp_transport.handle_post_message(request.scope, request.receive, request._send)