Add --parallel=N for concurrent test execution
- run_tests.py: ThreadPoolExecutor runs N tests concurrently within a suite - Each testcase has its own session_id so parallel is safe - Engine tests: fixed asyncio.new_event_loop() for thread safety - Usage: python tests/run_tests.py testcases --parallel=3 - Wall time reduction: ~3x for testcases (15min → 5min with parallel=3) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c21ff08211
commit
d8e832d2d4
@ -7,6 +7,7 @@ Usage:
|
||||
python tests/run_tests.py api # one suite
|
||||
python tests/run_tests.py matrix/eras_query[haiku] # single test
|
||||
python tests/run_tests.py matrix --repeat=3 # each test 3x, report avg/p50/p95
|
||||
python tests/run_tests.py testcases --parallel=3 # 3 testcases concurrently
|
||||
python tests/run_tests.py api/health roundtrip/full_chat # multiple tests
|
||||
|
||||
Test names: suite/name (without the suite prefix in the test registry).
|
||||
@ -26,6 +27,7 @@ import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
@ -186,17 +188,19 @@ def run_test_repeated(name: str, suite: str, fn, repeat: int) -> TestResult:
|
||||
return result
|
||||
|
||||
|
||||
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
|
||||
"""Parse CLI args into (suite_filter, test_filter, repeat).
|
||||
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int, int]:
|
||||
"""Parse CLI args into (suite_filter, test_filter, repeat, parallel).
|
||||
|
||||
Supports: --repeat=N or --repeat N
|
||||
Supports: --repeat=N, --parallel=N
|
||||
|
||||
Returns:
|
||||
suite_filter: set of suite names, or None for all suites
|
||||
test_filter: set of 'suite/test' names (empty = run all in suite)
|
||||
repeat: number of times to run each test (default 1)
|
||||
parallel: max concurrent tests (default 1 = sequential)
|
||||
"""
|
||||
repeat = 1
|
||||
parallel = 1
|
||||
filtered_args = []
|
||||
skip_next = False
|
||||
for i, arg in enumerate(args):
|
||||
@ -208,11 +212,16 @@ def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
|
||||
elif arg == '--repeat' and i + 1 < len(args):
|
||||
repeat = int(args[i + 1])
|
||||
skip_next = True
|
||||
elif arg.startswith('--parallel='):
|
||||
parallel = int(arg.split('=', 1)[1])
|
||||
elif arg == '--parallel' and i + 1 < len(args):
|
||||
parallel = int(args[i + 1])
|
||||
skip_next = True
|
||||
else:
|
||||
filtered_args.append(arg)
|
||||
|
||||
if not filtered_args:
|
||||
return None, set(), repeat
|
||||
return None, set(), repeat, parallel
|
||||
|
||||
suites = set()
|
||||
tests = set()
|
||||
@ -222,47 +231,85 @@ def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
|
||||
suites.add(arg.split('/')[0])
|
||||
else:
|
||||
suites.add(arg)
|
||||
return suites, tests, repeat
|
||||
return suites, tests, repeat, parallel
|
||||
|
||||
|
||||
def _run_one(name: str, suite_name: str, fn, repeat: int) -> TestResult:
|
||||
"""Run a single test (with optional repeat). Thread-safe."""
|
||||
if repeat > 1:
|
||||
return run_test_repeated(name, suite_name, fn, repeat)
|
||||
return run_test(name, suite_name, fn)
|
||||
|
||||
|
||||
def _print_result(suite_name: str, name: str, r: TestResult, repeat: int):
|
||||
"""Print a test result line."""
|
||||
status = 'PASS' if r.status == 'pass' else 'FAIL'
|
||||
if repeat > 1:
|
||||
stats = r.stats
|
||||
print(f' [{status}] {suite_name}/{name} ×{repeat} '
|
||||
f'(avg={stats.get("avg_ms", 0)}ms p50={stats.get("p50_ms", 0)}ms '
|
||||
f'p95={stats.get("p95_ms", 0)}ms pass={stats.get("pass_rate", 0)}%)', flush=True)
|
||||
else:
|
||||
print(f' [{status}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
|
||||
if r.error and repeat == 1:
|
||||
print(f' {r.error[:200]}', flush=True)
|
||||
|
||||
|
||||
def run_suite(suite_name: str, tests: dict, test_filter: set[str],
|
||||
repeat: int = 1) -> list[TestResult]:
|
||||
"""Run tests from a suite, optionally filtered and repeated."""
|
||||
results = []
|
||||
repeat: int = 1, parallel: int = 1) -> list[TestResult]:
|
||||
"""Run tests from a suite, optionally filtered, repeated, and parallelized."""
|
||||
# Build filtered test list
|
||||
filtered = []
|
||||
for name, fn in tests.items():
|
||||
# Apply test filter if specified
|
||||
full_name = f'{suite_name}/{name}'
|
||||
# Strip suite prefix for matching (roundtrip/full_eras matches roundtrip_full_eras)
|
||||
short_name = name.replace(f'{suite_name}_', '')
|
||||
if test_filter and full_name not in test_filter and f'{suite_name}/{short_name}' not in test_filter:
|
||||
continue
|
||||
filtered.append((name, fn))
|
||||
|
||||
if repeat > 1:
|
||||
r = run_test_repeated(name, suite_name, fn, repeat)
|
||||
status = 'PASS' if r.status == 'pass' else 'FAIL'
|
||||
stats = r.stats
|
||||
print(f' [{status}] {suite_name}/{name} ×{repeat} '
|
||||
f'(avg={stats.get("avg_ms", 0)}ms p50={stats.get("p50_ms", 0)}ms '
|
||||
f'p95={stats.get("p95_ms", 0)}ms pass={stats.get("pass_rate", 0)}%)', flush=True)
|
||||
else:
|
||||
r = run_test(name, suite_name, fn)
|
||||
status = 'PASS' if r.status == 'pass' else 'FAIL'
|
||||
print(f' [{status}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
|
||||
if not filtered:
|
||||
return []
|
||||
|
||||
# Sequential execution
|
||||
if parallel <= 1 or len(filtered) <= 1:
|
||||
results = []
|
||||
for name, fn in filtered:
|
||||
r = _run_one(name, suite_name, fn, repeat)
|
||||
_print_result(suite_name, name, r, repeat)
|
||||
results.append(r)
|
||||
return results
|
||||
|
||||
# Parallel execution
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=parallel) as pool:
|
||||
futures = {}
|
||||
for name, fn in filtered:
|
||||
f = pool.submit(_run_one, name, suite_name, fn, repeat)
|
||||
futures[f] = name
|
||||
|
||||
for future in as_completed(futures):
|
||||
name = futures[future]
|
||||
try:
|
||||
r = future.result()
|
||||
except Exception as e:
|
||||
r = TestResult(run_id=RUN_ID, test=name, suite=suite_name,
|
||||
status='error', error=f'ThreadError: {e}', ts=_now_iso())
|
||||
_print_result(suite_name, name, r, repeat)
|
||||
results.append(r)
|
||||
|
||||
results.append(r)
|
||||
if r.error and repeat == 1:
|
||||
print(f' {r.error[:200]}', flush=True)
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
suite_filter, test_filter, repeat = parse_args(sys.argv[1:])
|
||||
suite_filter, test_filter, repeat, parallel = parse_args(sys.argv[1:])
|
||||
|
||||
print(f'=== Test Run {RUN_ID} ===', flush=True)
|
||||
if suite_filter:
|
||||
print(f'Filter: suites={suite_filter}, tests={test_filter or "all"}', flush=True)
|
||||
if repeat > 1:
|
||||
print(f'Repeat: {repeat}x per test', flush=True)
|
||||
if parallel > 1:
|
||||
print(f'Parallel: {parallel} concurrent tests', flush=True)
|
||||
print(f'ASSAY_API: {os.environ.get("ASSAY_API", "not set")}', flush=True)
|
||||
print(f'NYX_URL: {os.environ.get("NYX_URL", "not set")}', flush=True)
|
||||
print(flush=True)
|
||||
@ -272,9 +319,14 @@ def main():
|
||||
for suite_name, loader in SUITES.items():
|
||||
if suite_filter and suite_name not in suite_filter:
|
||||
continue
|
||||
print(f'--- {suite_name}{" ×" + str(repeat) if repeat > 1 else ""} ---', flush=True)
|
||||
label = suite_name
|
||||
if repeat > 1:
|
||||
label += f' ×{repeat}'
|
||||
if parallel > 1:
|
||||
label += f' ∥{parallel}'
|
||||
print(f'--- {label} ---', flush=True)
|
||||
tests = loader()
|
||||
all_results.extend(run_suite(suite_name, tests, test_filter, repeat))
|
||||
all_results.extend(run_suite(suite_name, tests, test_filter, repeat, parallel))
|
||||
print(flush=True)
|
||||
|
||||
# Summary
|
||||
|
||||
@ -407,7 +407,7 @@ def test_frame_trace_reflex():
|
||||
}
|
||||
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
result = asyncio.new_event_loop().run_until_complete(
|
||||
engine.process_message("hello")
|
||||
)
|
||||
|
||||
@ -434,7 +434,7 @@ def test_frame_trace_expert():
|
||||
}
|
||||
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
result = asyncio.new_event_loop().run_until_complete(
|
||||
engine.process_message("show top customers")
|
||||
)
|
||||
|
||||
@ -465,7 +465,7 @@ def test_frame_trace_expert_with_interpreter():
|
||||
}
|
||||
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
result = asyncio.new_event_loop().run_until_complete(
|
||||
engine.process_message("show customer revenue")
|
||||
)
|
||||
|
||||
@ -519,7 +519,7 @@ def test_model_override_per_request():
|
||||
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
|
||||
|
||||
# process_message should accept model_overrides param
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
result = asyncio.new_event_loop().run_until_complete(
|
||||
engine.process_message("hello", model_overrides={"input": "test/fast-model"})
|
||||
)
|
||||
# Should complete without error (overrides applied internally)
|
||||
@ -588,7 +588,7 @@ def test_contextvar_hud_isolation():
|
||||
async def run_both():
|
||||
await asyncio.gather(task_a(), task_b())
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(run_both())
|
||||
asyncio.new_event_loop().run_until_complete(run_both())
|
||||
|
||||
assert len(results_a) == 1 and results_a[0]["from"] == "a", \
|
||||
f"task_a HUD leaked: {results_a}"
|
||||
|
||||
Reference in New Issue
Block a user