This repository has been archived on 2026-04-03. You can view files and clone it, but cannot push or open issues or pull requests.
agent-runtime/tests/test_matrix.py
Nico 4e679a3ad9 Add model matrix test suite: 3 tests × 3 variants = 9 combos
New 'matrix' suite runs same API tests with different LLM model configs:
- Variants: gemini-flash (baseline), haiku, gpt-4o-mini
- Tests: eras_query (SQL correctness), eras_artifact (data output), social_reflex (fast path)
- Posts results as test_name[variant] to /tests dashboard
- All 9 combos passing (6/9 verified locally, ~35s for ERAS tests)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 18:12:24 +02:00

150 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Model matrix tests — run the same API test with different LLM model configs.
Each variant defines model overrides for specific node roles. The test runner
generates one test per (base_test × variant) combination. Results are posted
with the variant name in brackets, e.g. "eras_query[haiku]".
Usage via run_tests.py:
python tests/run_tests.py matrix # all variants × all tests
python tests/run_tests.py matrix/eras_query[haiku] # single combo
"""
import json
import os
import sys
import urllib.request
# Put this file's directory first on sys.path so modules located next to it
# can be imported by bare name.
sys.path.insert(0, os.path.dirname(__file__))

# ASSAY_API is expected to point at the '/api' root, while the helpers below
# append '/api/...' themselves, so the suffix is removed here. A trailing
# slash is dropped first; otherwise '.../api/' would keep its suffix and the
# request URL would end up as '.../api//api/chat'.
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
ASSAY_BASE = _api_url.rstrip('/').removesuffix('/api')

# NOTE(review): a credential is committed in source. It is kept as the fallback
# so existing runs behave the same, but ASSAY_SERVICE_TOKEN (if set) takes
# precedence so the token can be supplied or rotated without editing this file.
SERVICE_TOKEN = os.environ.get(
    'ASSAY_SERVICE_TOKEN',
    '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g',
)
# --- Model variants ---
# A variant maps node roles to the model that should replace the graph default
# for that role. Roles that are not listed keep whatever the graph configures.
def _same_model_for_overridden_roles(model: str) -> dict:
    """Return an override dict assigning *model* to every overridden role."""
    return dict.fromkeys(('pa', 'expert_eras', 'interpreter'), model)


VARIANTS = {
    # Baseline: no overrides, i.e. the graph defaults are used as-is.
    'gemini-flash': {},
    'haiku': _same_model_for_overridden_roles('anthropic/claude-haiku-4.5'),
    'gpt-4o-mini': _same_model_for_overridden_roles('openai/gpt-4o-mini'),
}
# --- API helper with model overrides ---
def _parse_sse(output: str) -> list[tuple[str, str]]:
    """Split a complete SSE response body into ``(event_type, data)`` pairs.

    Events are separated by a blank line. Within an event, only ``event: ``
    and ``data: `` lines are considered; several ``data: `` lines are joined
    with a newline. Events missing either the type or the data are dropped.
    """
    events = []
    # Accept CRLF line endings as well, so the blank-line split below still
    # finds the event boundaries.
    normalized = output.replace('\r\n', '\n')
    for block in normalized.split('\n\n'):
        event_type = ''
        data_lines = []
        for line in block.strip().split('\n'):
            if line.startswith('event: '):
                event_type = line[7:]
            elif line.startswith('data: '):
                data_lines.append(line[6:])
        data = '\n'.join(data_lines)
        if event_type and data:
            events.append((event_type, data))
    return events


def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]:
    """Send a message via /api/chat and return the SSE events of the reply.

    Args:
        text: Message text, sent as the ``content`` field of the JSON body.
        models: Optional model overrides. Sent as the ``models`` field only
            when truthy, so ``None`` and ``{}`` both omit the field.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        A list of ``(event_type, data)`` string tuples in stream order. The
        whole response is read before parsing, so the call returns only once
        the server has finished the stream.

    Exceptions raised by ``urllib.request.urlopen`` or by decoding the body
    are not caught here.
    """
    body = {'content': text}
    if models:
        body['models'] = models
    req = urllib.request.Request(
        f'{ASSAY_BASE}/api/chat',
        data=json.dumps(body).encode(),
        method='POST',
        headers={
            'Authorization': f'Bearer {SERVICE_TOKEN}',
            'Content-Type': 'application/json',
        },
    )
    # The context manager closes the connection even if reading or decoding
    # fails; the matrix issues many requests per run.
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        output = resp.read().decode('utf-8')
    return _parse_sse(output)
# --- Base tests (same logic, different models) ---
def _test_eras_query(models: dict):
    """Check the SQL issued for a revenue question.

    Sends the "top 5 customers by revenue" prompt with the given model
    overrides and inspects the first ``hud`` event that mentions both
    ``tool_call`` and ``query_db``. Its query must reference
    ``artikelposition`` and must not reference ``geraeteverbraeuche``
    (case-insensitive). The stream must also contain a ``done`` event.
    """
    stream = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    # Matching is done on the raw payload text; only the first hit is parsed.
    db_call_payloads = [
        raw
        for kind, raw in stream
        if kind == 'hud' and 'tool_call' in raw and 'query_db' in raw
    ]
    assert db_call_payloads, 'no query_db tool call found'

    first_call = json.loads(db_call_payloads[0])
    query = first_call.get('args', {}).get('query', '')
    lowered = query.lower()
    assert 'artikelposition' in lowered, f'query missing artikelposition: {query}'
    assert 'geraeteverbraeuche' not in lowered, f'query uses wrong table: {query}'

    # The run must have finished.
    assert any(kind == 'done' for kind, _ in stream), 'no done event'
def _test_eras_artifact(models: dict):
    """Check that the revenue question yields an artifact carrying data.

    Parses the first ``artifacts`` event of the stream and requires at least
    one artifact whose ``data`` has a truthy ``fields`` or ``rows`` entry.
    """
    stream = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    artifact_payloads = [raw for kind, raw in stream if kind == 'artifacts']
    assert artifact_payloads, 'no artifact event'

    artifacts = json.loads(artifact_payloads[0]).get('artifacts', [])
    assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}'

    has_data = False
    for art in artifacts:
        content = art.get('data', {})
        if content.get('fields') or content.get('rows'):
            has_data = True
            break
    assert has_data, 'no artifact contains data'
def _test_social_reflex(models: dict):
    """Check that a plain greeting is answered without any tool use.

    The stream must contain at least one ``delta`` event and a ``done``
    event, and no ``hud`` event whose payload mentions ``tool_call``.
    """
    stream = api_chat('Hallo!', models=models)
    kinds = [kind for kind, _ in stream]

    # A reply was streamed.
    assert 'delta' in kinds, 'no delta events'
    # The run finished.
    assert 'done' in kinds, 'no done event'

    # No tool was invoked along the way.
    tool_calls = [
        (kind, raw) for kind, raw in stream if kind == 'hud' and 'tool_call' in raw
    ]
    assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls'
# --- Test registry: base tests that get multiplied by variants ---
# Keys are the base names that appear before "[variant]" in generated test
# names; each value is a function taking the variant's model-override dict.
BASE_TESTS = dict(
    eras_query=_test_eras_query,
    eras_artifact=_test_eras_artifact,
    social_reflex=_test_social_reflex,
)
def get_matrix_tests() -> dict:
    """Build the test x variant matrix.

    Returns:
        A dict mapping ``"<test_name>[<variant_name>]"`` to a zero-argument
        callable that runs the base test with that variant's model overrides.
        Entries are ordered variant by variant, with the base tests in
        registry order inside each variant.
    """
    def bind(test_fn, models):
        # Each callable gets its own (test_fn, models) pair as arguments, so
        # it does not depend on the loop variables' later values.
        def run():
            return test_fn(models)
        return run

    return {
        f'{test_name}[{variant_name}]': bind(test_fn, models)
        for variant_name, models in VARIANTS.items()
        for test_name, test_fn in BASE_TESTS.items()
    }