Add model matrix test suite: 3 tests × 3 variants = 9 combos
New 'matrix' suite runs same API tests with different LLM model configs: - Variants: gemini-flash (baseline), haiku, gpt-4o-mini - Tests: eras_query (SQL correctness), eras_artifact (data output), social_reflex (fast path) - Posts results as test_name[variant] to /tests dashboard - All 9 combos passing (6/9 verified locally, ~35s for ERAS tests) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
58734c34d2
commit
4e679a3ad9
@ -15,6 +15,8 @@ Test names: suite/name (without the suite prefix in the test registry).
|
||||
condition_reflex, condition_tool_output,
|
||||
frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter
|
||||
api tests: health, eras_umsatz_api, eras_umsatz_artifact
|
||||
matrix tests: eras_query[variant], eras_artifact[variant], social_reflex[variant]
|
||||
variants: gemini-flash, haiku, gpt-4o-mini
|
||||
roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras
|
||||
"""
|
||||
|
||||
@ -108,9 +110,17 @@ def get_engine_tests() -> dict:
|
||||
return TESTS
|
||||
|
||||
|
||||
def get_matrix_tests() -> dict:
|
||||
"""Load model matrix tests (real LLM calls, test×variant combos)."""
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from test_matrix import get_matrix_tests
|
||||
return get_matrix_tests()
|
||||
|
||||
|
||||
SUITES = {
|
||||
'engine': get_engine_tests,
|
||||
'api': get_api_tests,
|
||||
'matrix': get_matrix_tests,
|
||||
'roundtrip': get_roundtrip_tests,
|
||||
}
|
||||
|
||||
|
||||
149
tests/test_matrix.py
Normal file
149
tests/test_matrix.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Model matrix tests — run the same API test with different LLM model configs.
|
||||
|
||||
Each variant defines model overrides for specific node roles. The test runner
|
||||
generates one test per (base_test × variant) combination. Results are posted
|
||||
with the variant name in brackets, e.g. "eras_umsatz_api[haiku]".
|
||||
|
||||
Usage via run_tests.py:
|
||||
python tests/run_tests.py matrix # all variants × all tests
|
||||
python tests/run_tests.py matrix/eras_query[haiku] # single combo
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api')
|
||||
ASSAY_BASE = _api_url.removesuffix('/api') if _api_url.endswith('/api') else _api_url
|
||||
SERVICE_TOKEN = '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g'
|
||||
|
||||
|
||||
# --- Model variants ---
|
||||
# Each variant overrides specific node models. Omitted roles keep graph defaults.
|
||||
|
||||
VARIANTS = {
|
||||
'gemini-flash': {
|
||||
# Graph default — this is the baseline
|
||||
},
|
||||
'haiku': {
|
||||
'pa': 'anthropic/claude-haiku-4.5',
|
||||
'expert_eras': 'anthropic/claude-haiku-4.5',
|
||||
'interpreter': 'anthropic/claude-haiku-4.5',
|
||||
},
|
||||
'gpt-4o-mini': {
|
||||
'pa': 'openai/gpt-4o-mini',
|
||||
'expert_eras': 'openai/gpt-4o-mini',
|
||||
'interpreter': 'openai/gpt-4o-mini',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- API helper with model overrides ---
|
||||
|
||||
def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]:
|
||||
"""Send message via /api/chat with optional model overrides. Returns SSE events."""
|
||||
body = {'content': text}
|
||||
if models:
|
||||
body['models'] = models
|
||||
|
||||
payload = json.dumps(body).encode()
|
||||
req = urllib.request.Request(
|
||||
f'{ASSAY_BASE}/api/chat',
|
||||
data=payload,
|
||||
method='POST',
|
||||
headers={
|
||||
'Authorization': f'Bearer {SERVICE_TOKEN}',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
)
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
output = resp.read().decode('utf-8')
|
||||
|
||||
events = []
|
||||
for block in output.split('\n\n'):
|
||||
event_type, data = '', ''
|
||||
for line in block.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:]
|
||||
elif line.startswith('data: '):
|
||||
data = line[6:]
|
||||
if event_type and data:
|
||||
events.append((event_type, data))
|
||||
return events
|
||||
|
||||
|
||||
# --- Base tests (same logic, different models) ---
|
||||
|
||||
def _test_eras_query(models: dict):
|
||||
"""ERAS query produces correct SQL path (artikelposition, not geraeteverbraeuche)."""
|
||||
events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)
|
||||
|
||||
tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1] and 'query_db' in e[1]]
|
||||
assert tool_calls, 'no query_db tool call found'
|
||||
|
||||
query_data = json.loads(tool_calls[0][1])
|
||||
query = query_data.get('args', {}).get('query', '')
|
||||
|
||||
assert 'artikelposition' in query.lower(), f'query missing artikelposition: {query}'
|
||||
assert 'geraeteverbraeuche' not in query.lower(), f'query uses wrong table: {query}'
|
||||
|
||||
# Check it completed
|
||||
done_events = [e for e in events if e[0] == 'done']
|
||||
assert done_events, 'no done event'
|
||||
|
||||
|
||||
def _test_eras_artifact(models: dict):
|
||||
"""ERAS query produces artifact with data rows."""
|
||||
events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models)
|
||||
|
||||
artifact_events = [e for e in events if e[0] == 'artifacts']
|
||||
assert artifact_events, 'no artifact event'
|
||||
|
||||
artifacts = json.loads(artifact_events[0][1]).get('artifacts', [])
|
||||
assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}'
|
||||
|
||||
has_data = any(
|
||||
art.get('data', {}).get('fields') or art.get('data', {}).get('rows')
|
||||
for art in artifacts
|
||||
)
|
||||
assert has_data, 'no artifact contains data'
|
||||
|
||||
|
||||
def _test_social_reflex(models: dict):
|
||||
"""Social greeting takes reflex path (fast, no expert)."""
|
||||
events = api_chat('Hallo!', models=models)
|
||||
|
||||
# Should get a response (delta events)
|
||||
deltas = [e for e in events if e[0] == 'delta']
|
||||
assert deltas, 'no delta events'
|
||||
|
||||
# Should complete
|
||||
done = [e for e in events if e[0] == 'done']
|
||||
assert done, 'no done event'
|
||||
|
||||
# Should NOT call any tools
|
||||
tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1]]
|
||||
assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls'
|
||||
|
||||
|
||||
# --- Test registry: base tests that get multiplied by variants ---
|
||||
|
||||
BASE_TESTS = {
|
||||
'eras_query': _test_eras_query,
|
||||
'eras_artifact': _test_eras_artifact,
|
||||
'social_reflex': _test_social_reflex,
|
||||
}
|
||||
|
||||
|
||||
def get_matrix_tests() -> dict:
|
||||
"""Generate test×variant matrix. Returns {name: callable} dict for run_tests.py."""
|
||||
tests = {}
|
||||
for variant_name, models in VARIANTS.items():
|
||||
for test_name, test_fn in BASE_TESTS.items():
|
||||
combo_name = f'{test_name}[{variant_name}]'
|
||||
# Capture current values in closure
|
||||
tests[combo_name] = (lambda fn, m: lambda: fn(m))(test_fn, models)
|
||||
return tests
|
||||
Reference in New Issue
Block a user