New 'matrix' suite runs same API tests with different LLM model configs: - Variants: gemini-flash (baseline), haiku, gpt-4o-mini - Tests: eras_query (SQL correctness), eras_artifact (data output), social_reflex (fast path) - Posts results as test_name[variant] to /tests dashboard - All 9 combos passing (6/9 verified locally, ~35s for ERAS tests) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
150 lines
5.0 KiB
Python
150 lines
5.0 KiB
Python
"""Model matrix tests — run the same API test with different LLM model configs.
|
||
|
||
Each variant defines model overrides for specific node roles. The test runner
|
||
generates one test per (base_test × variant) combination. Results are posted
|
||
with the variant name in brackets, e.g. "eras_query[haiku]".
|
||
|
||
Usage via run_tests.py:
|
||
python tests/run_tests.py matrix # all variants × all tests
|
||
python tests/run_tests.py matrix/eras_query[haiku] # single combo
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import urllib.request
|
||
|
||
# Put this file's directory first on sys.path (presumably so sibling test
# helpers can be imported when the module is loaded by path — confirm).
sys.path.insert(0, os.path.dirname(__file__))

# ASSAY_API may be supplied with or without a trailing '/api' and/or slash.
# ASSAY_BASE is always the bare origin; api_chat() appends '/api/chat' itself.
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
ASSAY_BASE = _api_url.rstrip('/').removesuffix('/api')

# NOTE(review): a credential is committed here in clear text. The literal is
# kept only as a fallback so existing runs keep working; prefer supplying
# ASSAY_SERVICE_TOKEN via the environment, then rotate and remove this default.
SERVICE_TOKEN = os.environ.get(
    'ASSAY_SERVICE_TOKEN',
    '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g',
)
|
||
|
||
|
||
# --- Model variants ---
|
||
# Each variant overrides specific node models. Omitted roles keep graph defaults.
|
||
|
||
# Node roles that every non-baseline variant re-points at a single model.
_OVERRIDDEN_ROLES = ('pa', 'expert_eras', 'interpreter')

VARIANTS = {
    # Baseline: no overrides, so every role keeps the graph default.
    'gemini-flash': {},
    'haiku': dict.fromkeys(_OVERRIDDEN_ROLES, 'anthropic/claude-haiku-4.5'),
    'gpt-4o-mini': dict.fromkeys(_OVERRIDDEN_ROLES, 'openai/gpt-4o-mini'),
}
|
||
|
||
|
||
# --- API helper with model overrides ---
|
||
|
||
def _parse_sse_events(stream_text: str) -> list[tuple[str, str]]:
    """Split a Server-Sent-Events body into ``(event, data)`` pairs.

    Blocks that lack either an ``event`` or a ``data`` field are dropped.
    Several ``data`` lines in one block are joined with a newline; if a block
    repeats ``event``, the last value wins.
    """
    # SSE allows CRLF, LF or CR line endings; normalise so that the blank-line
    # split below works and no stray '\r' ends up in event names or payloads.
    normalised = stream_text.replace('\r\n', '\n').replace('\r', '\n')

    events = []
    for block in normalised.split('\n\n'):
        event_type = ''
        data_lines = []
        for line in block.strip().split('\n'):
            field, colon, value = line.partition(':')
            if not colon:
                continue
            # A single space after the colon is optional and not part of the value.
            value = value.removeprefix(' ')
            if field == 'event':
                event_type = value
            elif field == 'data':
                data_lines.append(value)
        data = '\n'.join(data_lines)
        if event_type and data:
            events.append((event_type, data))
    return events


def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]:
    """POST a chat message to ``{ASSAY_BASE}/api/chat`` and return its SSE events.

    Args:
        text: Message sent as the ``content`` field of the JSON body.
        models: Optional model overrides; added to the body as ``models`` only
            when non-empty, so ``None`` and ``{}`` both send no overrides.
        timeout: Timeout in seconds handed to ``urllib.request.urlopen``.

    Returns:
        A list of ``(event_type, data)`` tuples in stream order. The response
        is read to completion before it is parsed.

    Raises:
        urllib.error.URLError: Propagated from ``urlopen`` (includes
            ``HTTPError`` for non-success status codes).
    """
    body = {'content': text}
    if models:
        body['models'] = models

    req = urllib.request.Request(
        f'{ASSAY_BASE}/api/chat',
        data=json.dumps(body).encode(),
        method='POST',
        headers={
            'Authorization': f'Bearer {SERVICE_TOKEN}',
            'Content-Type': 'application/json',
        },
    )
    # Context manager guarantees the connection is released even if read() fails.
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        output = resp.read().decode('utf-8')

    return _parse_sse_events(output)
|
||
|
||
|
||
# --- Base tests (same logic, different models) ---
|
||
|
||
def _test_eras_query(models: dict):
    """Check that the revenue question is answered via the expected table.

    Sends the top-5-customers prompt, takes the first ``hud`` event that
    mentions both ``tool_call`` and ``query_db``, and asserts that its query
    text references ``artikelposition`` and not ``geraeteverbraeuche``.
    Finally asserts that the stream contained a ``done`` event.
    """
    events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    # Only the first matching tool call is inspected.
    first_query_call = next(
        (
            payload
            for kind, payload in events
            if kind == 'hud' and 'tool_call' in payload and 'query_db' in payload
        ),
        None,
    )
    assert first_query_call is not None, 'no query_db tool call found'

    query = json.loads(first_query_call).get('args', {}).get('query', '')
    lowered = query.lower()
    assert 'artikelposition' in lowered, f'query missing artikelposition: {query}'
    assert 'geraeteverbraeuche' not in lowered, f'query uses wrong table: {query}'

    # The run must also have finished.
    assert any(kind == 'done' for kind, _ in events), 'no done event'
|
||
|
||
|
||
def _test_eras_artifact(models: dict):
    """Check that the revenue question yields at least one artifact with data.

    Looks at the first ``artifacts`` event only and passes when any artifact
    in it has a truthy ``data.fields`` or ``data.rows``.
    """
    events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)

    artifact_payloads = [payload for kind, payload in events if kind == 'artifacts']
    assert artifact_payloads, 'no artifact event'

    artifacts = json.loads(artifact_payloads[0]).get('artifacts', [])
    count = len(artifacts)
    assert count >= 1, f'expected >=1 artifact, got {count}'

    def _carries_data(artifact) -> bool:
        # Either a field list or a row list counts as "has data".
        data = artifact.get('data', {})
        return bool(data.get('fields') or data.get('rows'))

    assert any(map(_carries_data, artifacts)), 'no artifact contains data'
|
||
|
||
|
||
def _test_social_reflex(models: dict):
    """Check that a plain greeting is answered and finished without tool use.

    Asserts, in order: at least one ``delta`` event, at least one ``done``
    event, and no ``hud`` event whose payload mentions ``tool_call``.
    """
    events = api_chat('Hallo!', models=models)

    kinds_seen = {kind for kind, _ in events}
    assert 'delta' in kinds_seen, 'no delta events'
    assert 'done' in kinds_seen, 'no done event'

    # Any tool call means the request did not stay on the tool-free path.
    tool_call_count = sum(
        1 for kind, payload in events if kind == 'hud' and 'tool_call' in payload
    )
    assert tool_call_count == 0, f'reflex path should not use tools, got {tool_call_count} calls'
|
||
|
||
|
||
# --- Test registry: base tests that get multiplied by variants ---
|
||
|
||
# Short test name -> base test callable taking the variant's model overrides.
BASE_TESTS = dict(
    eras_query=_test_eras_query,
    eras_artifact=_test_eras_artifact,
    social_reflex=_test_social_reflex,
)
|
||
|
||
|
||
def get_matrix_tests() -> dict:
    """Build the variant x base-test matrix.

    Returns:
        A dict mapping ``'<test_name>[<variant_name>]'`` to a zero-argument
        callable that runs the base test with that variant's model overrides.
        Entries are ordered by variant first, then by base test.
    """
    def _bind(test_fn, models):
        # Separate scope per combo so each callable keeps its own pair rather
        # than whatever the loop variables hold once iteration has finished.
        return lambda: test_fn(models)

    return {
        f'{test_name}[{variant_name}]': _bind(test_fn, models)
        for variant_name, models in VARIANTS.items()
        for test_name, test_fn in BASE_TESTS.items()
    }
|