Add model matrix test suite: 3 tests × 3 variants = 9 combos
New 'matrix' suite runs same API tests with different LLM model configs: - Variants: gemini-flash (baseline), haiku, gpt-4o-mini - Tests: eras_query (SQL correctness), eras_artifact (data output), social_reflex (fast path) - Posts results as test_name[variant] to /tests dashboard - All 9 combos passing (6/9 verified locally, ~35s for ERAS tests) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
58734c34d2
commit
4e679a3ad9
@ -15,6 +15,8 @@ Test names: suite/name (without the suite prefix in the test registry).
|
|||||||
condition_reflex, condition_tool_output,
|
condition_reflex, condition_tool_output,
|
||||||
frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter
|
frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter
|
||||||
api tests: health, eras_umsatz_api, eras_umsatz_artifact
|
api tests: health, eras_umsatz_api, eras_umsatz_artifact
|
||||||
|
matrix tests: eras_query[variant], eras_artifact[variant], social_reflex[variant]
|
||||||
|
variants: gemini-flash, haiku, gpt-4o-mini
|
||||||
roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras
|
roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -108,9 +110,17 @@ def get_engine_tests() -> dict:
|
|||||||
return TESTS
|
return TESTS
|
||||||
|
|
||||||
|
|
||||||
|
def get_matrix_tests() -> dict:
|
||||||
|
"""Load model matrix tests (real LLM calls, test×variant combos)."""
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
from test_matrix import get_matrix_tests
|
||||||
|
return get_matrix_tests()
|
||||||
|
|
||||||
|
|
||||||
SUITES = {
|
SUITES = {
|
||||||
'engine': get_engine_tests,
|
'engine': get_engine_tests,
|
||||||
'api': get_api_tests,
|
'api': get_api_tests,
|
||||||
|
'matrix': get_matrix_tests,
|
||||||
'roundtrip': get_roundtrip_tests,
|
'roundtrip': get_roundtrip_tests,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
149
tests/test_matrix.py
Normal file
149
tests/test_matrix.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
"""Model matrix tests — run the same API test with different LLM model configs.
|
||||||
|
|
||||||
|
Each variant defines model overrides for specific node roles. The test runner
|
||||||
|
generates one test per (base_test × variant) combination. Results are posted
|
||||||
|
with the variant name in brackets, e.g. "eras_umsatz_api[haiku]".
|
||||||
|
|
||||||
|
Usage via run_tests.py:
|
||||||
|
python tests/run_tests.py matrix # all variants × all tests
|
||||||
|
python tests/run_tests.py matrix/eras_query[haiku] # single combo
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
|
||||||
|
ASSAY_BASE = _api_url.removesuffix('/api') if _api_url.endswith('/api') else _api_url
|
||||||
|
SERVICE_TOKEN = '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g'
|
||||||
|
|
||||||
|
|
||||||
|
# --- Model variants ---
|
||||||
|
# Each variant overrides specific node models. Omitted roles keep graph defaults.
|
||||||
|
|
||||||
|
VARIANTS = {
|
||||||
|
'gemini-flash': {
|
||||||
|
# Graph default — this is the baseline
|
||||||
|
},
|
||||||
|
'haiku': {
|
||||||
|
'pa': 'anthropic/claude-haiku-4.5',
|
||||||
|
'expert_eras': 'anthropic/claude-haiku-4.5',
|
||||||
|
'interpreter': 'anthropic/claude-haiku-4.5',
|
||||||
|
},
|
||||||
|
'gpt-4o-mini': {
|
||||||
|
'pa': 'openai/gpt-4o-mini',
|
||||||
|
'expert_eras': 'openai/gpt-4o-mini',
|
||||||
|
'interpreter': 'openai/gpt-4o-mini',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# --- API helper with model overrides ---
|
||||||
|
|
||||||
|
def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]:
|
||||||
|
"""Send message via /api/chat with optional model overrides. Returns SSE events."""
|
||||||
|
body = {'content': text}
|
||||||
|
if models:
|
||||||
|
body['models'] = models
|
||||||
|
|
||||||
|
payload = json.dumps(body).encode()
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f'{ASSAY_BASE}/api/chat',
|
||||||
|
data=payload,
|
||||||
|
method='POST',
|
||||||
|
headers={
|
||||||
|
'Authorization': f'Bearer {SERVICE_TOKEN}',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||||
|
output = resp.read().decode('utf-8')
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for block in output.split('\n\n'):
|
||||||
|
event_type, data = '', ''
|
||||||
|
for line in block.strip().split('\n'):
|
||||||
|
if line.startswith('event: '):
|
||||||
|
event_type = line[7:]
|
||||||
|
elif line.startswith('data: '):
|
||||||
|
data = line[6:]
|
||||||
|
if event_type and data:
|
||||||
|
events.append((event_type, data))
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
# --- Base tests (same logic, different models) ---
|
||||||
|
|
||||||
|
def _test_eras_query(models: dict):
|
||||||
|
"""ERAS query produces correct SQL path (artikelposition, not geraeteverbraeuche)."""
|
||||||
|
events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)
|
||||||
|
|
||||||
|
tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1] and 'query_db' in e[1]]
|
||||||
|
assert tool_calls, 'no query_db tool call found'
|
||||||
|
|
||||||
|
query_data = json.loads(tool_calls[0][1])
|
||||||
|
query = query_data.get('args', {}).get('query', '')
|
||||||
|
|
||||||
|
assert 'artikelposition' in query.lower(), f'query missing artikelposition: {query}'
|
||||||
|
assert 'geraeteverbraeuche' not in query.lower(), f'query uses wrong table: {query}'
|
||||||
|
|
||||||
|
# Check it completed
|
||||||
|
done_events = [e for e in events if e[0] == 'done']
|
||||||
|
assert done_events, 'no done event'
|
||||||
|
|
||||||
|
|
||||||
|
def _test_eras_artifact(models: dict):
|
||||||
|
"""ERAS query produces artifact with data rows."""
|
||||||
|
events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)
|
||||||
|
|
||||||
|
artifact_events = [e for e in events if e[0] == 'artifacts']
|
||||||
|
assert artifact_events, 'no artifact event'
|
||||||
|
|
||||||
|
artifacts = json.loads(artifact_events[0][1]).get('artifacts', [])
|
||||||
|
assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}'
|
||||||
|
|
||||||
|
has_data = any(
|
||||||
|
art.get('data', {}).get('fields') or art.get('data', {}).get('rows')
|
||||||
|
for art in artifacts
|
||||||
|
)
|
||||||
|
assert has_data, 'no artifact contains data'
|
||||||
|
|
||||||
|
|
||||||
|
def _test_social_reflex(models: dict):
|
||||||
|
"""Social greeting takes reflex path (fast, no expert)."""
|
||||||
|
events = api_chat('Hallo!', models=models)
|
||||||
|
|
||||||
|
# Should get a response (delta events)
|
||||||
|
deltas = [e for e in events if e[0] == 'delta']
|
||||||
|
assert deltas, 'no delta events'
|
||||||
|
|
||||||
|
# Should complete
|
||||||
|
done = [e for e in events if e[0] == 'done']
|
||||||
|
assert done, 'no done event'
|
||||||
|
|
||||||
|
# Should NOT call any tools
|
||||||
|
tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1]]
|
||||||
|
assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls'
|
||||||
|
|
||||||
|
|
||||||
|
# --- Test registry: base tests that get multiplied by variants ---
|
||||||
|
|
||||||
|
BASE_TESTS = {
|
||||||
|
'eras_query': _test_eras_query,
|
||||||
|
'eras_artifact': _test_eras_artifact,
|
||||||
|
'social_reflex': _test_social_reflex,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_matrix_tests() -> dict:
|
||||||
|
"""Generate test×variant matrix. Returns {name: callable} dict for run_tests.py."""
|
||||||
|
tests = {}
|
||||||
|
for variant_name, models in VARIANTS.items():
|
||||||
|
for test_name, test_fn in BASE_TESTS.items():
|
||||||
|
combo_name = f'{test_name}[{variant_name}]'
|
||||||
|
# Capture current values in closure
|
||||||
|
tests[combo_name] = (lambda fn, m: lambda: fn(m))(test_fn, models)
|
||||||
|
return tests
|
||||||
Reference in New Issue
Block a user