Remove old /ws chat endpoint (replaced by Streamable HTTP)
Kept the /ws/test and /ws/trace debug WebSockets for dev tooling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e94f1c437c
commit
1e64b0a58c
227
agent/api.py
227
agent/api.py
@ -7,7 +7,7 @@ import logging
|
||||
from asyncio import Queue
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
import httpx
|
||||
@ -110,10 +110,25 @@ def register_routes(app):
|
||||
async def shutdown():
|
||||
await db_sessions.close_pool()
|
||||
|
||||
@app.get("/api/health")
@app.get("/health")  # K8s probes
async def health():
    """Liveness check: always answers ``{"status": "ok"}``.

    The handler is served on both ``/api/health`` and ``/health``; each path
    is registered exactly once.
    """
    return {"status": "ok"}
|
||||
|
||||
@app.get("/api/health-stream")
async def health_stream(user=Depends(require_auth)):
    """SSE heartbeat stream — client uses this for presence detection.

    Emits one ``heartbeat`` event immediately and then one every 15 seconds
    for as long as the response is being consumed. Each event's JSON payload
    carries ``ts`` (the event-loop clock truncated to whole seconds) and
    ``sessions`` (the current size of ``_sessions``).

    Args:
        user: Result of the ``require_auth`` dependency; not read here, it
            only gates access to the stream.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` with caching and proxy
        buffering disabled via headers.
    """
    async def generate():
        # Inside a coroutine the running loop is the one to ask; this avoids
        # the deprecated implicit-loop lookup of get_event_loop().
        loop = asyncio.get_running_loop()
        # NOTE(review): loop.time() is a monotonic clock, not a Unix epoch
        # timestamp — confirm the client only uses `ts` for relative timing.
        # Cancellation (client disconnect) is deliberately left to propagate
        # instead of being swallowed.
        while True:
            payload = json.dumps({'ts': int(loop.time()), 'sessions': len(_sessions)})
            yield f"event: heartbeat\ndata: {payload}\n\n"
            await asyncio.sleep(15)

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
# --- Session CRUD ---
|
||||
|
||||
@app.post("/api/sessions")
|
||||
@ -202,124 +217,6 @@ def register_routes(app):
|
||||
log.info("[api] created persistent runtime (legacy)")
|
||||
return _active_runtime
|
||||
|
||||
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                      access_token: str | None = Query(None),
                      session: str | None = Query(None)):
    """Chat WebSocket: bind the socket to a session runtime and dispatch client messages.

    Args:
        ws: The incoming WebSocket connection.
        token: Validated with ``_validate_token`` when ``AUTH_ENABLED`` is
            set; a rejected token closes the socket with code 4001.
        access_token: Used to fetch ``name`` / ``preferred_username`` /
            ``email`` from the issuer's userinfo endpoint when the validated
            claims carry no ``name``.
        session: Session id handed to ``_get_or_create_session``.

    After accepting, the handler sends ``session_info`` and ``ready`` frames
    and then loops over incoming JSON objects keyed by ``type``:
    ``ping`` / ``auth`` / ``connect`` are ignored; ``action``,
    ``cancel_process``, ``new``, ``stop`` and ``message`` are handled
    explicitly; anything else is treated as the legacy ``{text, dashboard}``
    shape. Malformed frames and handler failures are reported to the client
    as a ``hud`` error frame without dropping the connection. On disconnect
    the socket is detached from the runtime.
    """

    async def _report_error(detail: str) -> None:
        # Best-effort: the socket may already be gone, and a failed error
        # report must not take the receive loop down with it.
        try:
            await ws.send_text(json.dumps({"type": "hud", "node": "runtime", "event": "error", "detail": detail[:200]}))
        except Exception:
            pass

    user_claims = {"sub": "anonymous"}
    # NOTE(review): with AUTH_ENABLED set but no token supplied, the
    # connection proceeds with the anonymous claims above — confirm that
    # this fall-through is intended.
    if AUTH_ENABLED and token:
        try:
            user_claims = await _validate_token(token)
            if not user_claims.get("name") and access_token:
                async with httpx.AsyncClient() as client:
                    resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                            headers={"Authorization": f"Bearer {access_token}"})
                    if resp.status_code == 200:
                        info = resp.json()
                        # Log which fields came back, not their values:
                        # the payload contains name and email.
                        log.info("[auth] userinfo enrichment: fields=%s", sorted(info))
                        user_claims["name"] = info.get("name")
                        user_claims["preferred_username"] = info.get("preferred_username")
                        user_claims["email"] = info.get("email")
        except HTTPException:
            await ws.close(code=4001, reason="Invalid token")
            return

    origin = ws.headers.get("origin", ws.headers.get("host", ""))
    await ws.accept()

    # Get or create session, attach WS
    runtime = await _get_or_create_session(
        session_id=session, user_claims=user_claims, origin=origin)
    runtime.update_identity(user_claims, origin)
    runtime.attach_ws(ws)

    # Tell client which session they're on + ready signal
    try:
        await ws.send_text(json.dumps({
            "type": "session_info", "session_id": runtime.session_id,
            "graph": runtime.graph.get("name", "unknown"),
            "history_len": len(runtime.history)}))
        await ws.send_text(json.dumps({
            "type": "ready", "session_id": runtime.session_id}))
    except Exception:
        # Still best-effort (a dead socket surfaces in the receive loop
        # below), but leave a trace instead of failing silently.
        log.debug("[ws] initial handshake send failed", exc_info=True)

    try:
        while True:
            data = await ws.receive_text()

            # A malformed frame must not escape the loop: that would end the
            # handler without ever reaching the disconnect cleanup below.
            try:
                msg = json.loads(data)
            except json.JSONDecodeError as e:
                await _report_error(f"Invalid JSON: {e}")
                continue
            if not isinstance(msg, dict):
                await _report_error("Message must be a JSON object")
                continue

            msg_type = msg.get("type", "")
            rt = _sessions.get(runtime.session_id, runtime)

            # Ping — keep-alive, no processing
            if msg_type == "ping":
                continue

            # Auth message from nyx — already authed via query param, ignore
            if msg_type in ("auth", "connect"):
                continue

            try:
                if msg_type == "action":
                    action = msg.get("action", "unknown")
                    data_payload = msg.get("data")
                    if hasattr(rt, 'use_frames') and rt.use_frames:
                        action_text = f"ACTION:{action}"
                        if data_payload:
                            action_text += f"|data:{json.dumps(data_payload)}"
                        await rt.handle_message(action_text)
                    else:
                        await rt.handle_action(action, data_payload)

                elif msg_type == "cancel_process":
                    rt.process_manager.cancel(msg.get("pid", 0))

                elif msg_type == "new":
                    # New session requested
                    rt.detach_ws()
                    new_rt = await _get_or_create_session(
                        user_claims=user_claims, origin=origin)
                    new_rt.attach_ws(ws)
                    runtime = new_rt
                    rt = new_rt
                    await ws.send_text(json.dumps({
                        "type": "session_info", "session_id": rt.session_id,
                        "graph": rt.graph.get("name", "unknown"),
                        "history_len": 0}))
                    await ws.send_text(json.dumps({"type": "cleared"}))
                    continue

                elif msg_type == "stop":
                    # Cancel running pipeline
                    if _pipeline_task and not _pipeline_task.done():
                        _pipeline_task.cancel()
                    continue

                elif msg_type == "message":
                    # nyx format: {type: 'message', content: '...'}
                    # `or ""` also covers an explicit JSON null.
                    text = (msg.get("content") or "").strip()
                    if text:
                        await rt.handle_message(text, dashboard=msg.get("dashboard"))

                else:
                    # Legacy/assay format: {text: '...', dashboard: [...]}
                    text = (msg.get("text") or "").strip()
                    if text:
                        await rt.handle_message(text, dashboard=msg.get("dashboard"))

                # Auto-save after each processed message
                asyncio.create_task(_save_session(rt))

            except Exception as e:
                import traceback
                log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
                await _report_error(str(e))
    except WebSocketDisconnect:
        runtime.detach_ws()
        log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
|
||||
|
||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||
"""Validate token for debug WS. Returns True if auth OK."""
|
||||
if not AUTH_ENABLED:
|
||||
@ -370,6 +267,96 @@ def register_routes(app):
|
||||
_trace_ws_clients.remove(ws)
|
||||
log.info(f"[api] /ws/trace disconnected ({len(_trace_ws_clients)} clients)")
|
||||
|
||||
# --- Streamable HTTP chat endpoint ---
|
||||
|
||||
@app.post("/api/chat")
async def api_chat(request: Request, user=Depends(require_auth)):
    """Send a message and receive streaming SSE response.

    The JSON request body is an object with these optional keys:
    ``session_id``, ``content`` (chat text), ``action`` with ``action_data``,
    and ``dashboard``. At least one of ``content`` or ``action`` must be
    present; when both are, the action takes precedence.

    Args:
        request: Incoming request; its JSON body and ``origin`` / ``host``
            headers are read.
        user: Claims returned by the ``require_auth`` dependency, passed on
            to the session runtime.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``. Every dict taken from
        the per-request queue becomes one SSE event named after its ``type``
        key (default ``message``); the stream ends when the worker finishes.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``content`` is not a string, or when neither ``content`` nor
            ``action`` is provided.
    """
    try:
        body = await request.json()
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; a bad body is the client's fault, not a 500.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    session_id = body.get("session_id")
    content = body.get("content") or ""  # also covers an explicit JSON null
    if not isinstance(content, str):
        raise HTTPException(status_code=400, detail="'content' must be a string")
    text = content.strip()
    action = body.get("action")
    action_data = body.get("action_data")
    dashboard = body.get("dashboard")

    if not text and not action:
        raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")

    origin = request.headers.get("origin", request.headers.get("host", ""))
    rt = await _get_or_create_session(
        session_id=session_id, user_claims=user, origin=origin)
    rt.update_identity(user, origin)

    # Attach SSE queue to sink for this request
    # NOTE(review): presumably the sink forwards runtime events into this
    # queue until detach_queue() is called — confirm against the sink.
    q: Queue = Queue(maxsize=500)
    rt.sink.attach_queue(q)

    async def run_and_close():
        cancelled = False
        try:
            if action:
                if hasattr(rt, 'use_frames') and rt.use_frames:
                    action_text = f"ACTION:{action}"
                    if action_data:
                        action_text += f"|data:{json.dumps(action_data)}"
                    await rt.handle_message(action_text)
                else:
                    await rt.handle_action(action, action_data)
            else:
                await rt.handle_message(text, dashboard=dashboard)
            # Auto-save
            await _save_session(rt)
        except asyncio.CancelledError:
            cancelled = True
            raise
        except Exception as e:
            import traceback
            log.error(f"[chat] handler error: {e}\n{traceback.format_exc()}")
            try:
                q.put_nowait({"type": "error", "detail": str(e)[:200]})
            except asyncio.QueueFull:
                pass
        finally:
            # Signal end-of-stream. With a full queue put_nowait() would
            # raise QueueFull and the reader would wait forever for a
            # sentinel that never arrives, so block until there is room.
            # After cancellation the reader may be gone, in which case a
            # blocking put could never complete: fall back to best-effort.
            if cancelled:
                try:
                    q.put_nowait(None)
                except asyncio.QueueFull:
                    pass
            else:
                await q.put(None)

    # Run pipeline in background task
    task = asyncio.create_task(run_and_close())

    async def generate():
        try:
            while True:
                event = await q.get()
                if event is None:
                    break
                event_type = event.get("type", "message")
                yield f"event: {event_type}\ndata: {json.dumps(event)}\n\n"
        finally:
            # Runs on normal completion, cancellation and generator close
            # alike: stop a still-running worker and release the sink.
            if not task.done():
                task.cancel()
            rt.sink.detach_queue()

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
@app.get("/api/session")
async def api_session(token: str | None = Query(None), session: str | None = Query(None),
                      user=Depends(require_auth)):
    """Get or create session — replaces WS connect handshake.

    Args:
        token: Accepted as a query parameter but not read by this handler;
            authentication is done by the ``require_auth`` dependency.
        session: Session id passed to ``_get_or_create_session``; may be
            omitted.
        user: Claims returned by ``require_auth``, used as the session's
            ``user_claims``.

    Returns:
        A dict with ``session_id``, ``graph`` (the graph's ``name`` or
        ``"unknown"``), ``history_len``, a constant ``status`` of
        ``"ready"``, and ``memorizer`` (the runtime memorizer's ``state``).
    """
    # No request object is available here, so the origin is passed empty.
    rt = await _get_or_create_session(
        session_id=session, user_claims=user, origin="")
    return {
        "session_id": rt.session_id,
        "graph": rt.graph.get("name", "unknown"),
        "history_len": len(rt.history),
        "status": "ready",
        "memorizer": rt.memorizer.state,
    }
|
||||
|
||||
@app.post("/api/stop")
async def api_stop(user=Depends(require_auth)):
    """Cancel the pipeline task if one is currently running.

    Returns ``{"stopped": True}`` after requesting cancellation, or
    ``{"stopped": False, "detail": "No pipeline running"}`` when there is no
    task or it has already finished.
    """
    pipeline = _pipeline_task
    # Guard clause: nothing to cancel when no task exists or it is done.
    if not pipeline or pipeline.done():
        return {"stopped": False, "detail": "No pipeline running"}

    pipeline.cancel()
    return {"stopped": True}
|
||||
|
||||
@app.get("/api/events")
|
||||
async def sse_events(user=Depends(require_auth)):
|
||||
q: Queue = Queue(maxsize=100)
|
||||
|
||||
Reference in New Issue
Block a user