Skip to content

Commit 2be88b4

Browse files
committed
Enhance test executor and pipeline service functionality
- Introduced ExecutionStats dataclass to track test execution metrics.
- Refactored execute_proactive_tests to return execution statistics alongside results.
- Added detailed execution logging for proactive tests.
- Implemented new test cases for identity context handling and classification of findings.
- Enhanced pipeline service to enforce request budgets and manage test case branching.
1 parent b24be62 commit 2be88b4

4 files changed

Lines changed: 314 additions & 26 deletions

File tree

src/secnodeapi/test_executor.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,43 @@
22
Test Executor engine for SecNode API Pentester.
33
Asynchronously fires HTTP requests based on generated test cases.
44
"""
5-
import time
5+
66
import asyncio
7+
import time
8+
from dataclasses import dataclass
9+
from typing import Dict, List, Optional, Tuple
10+
711
import httpx
8-
from typing import List, Dict, Optional
912
import structlog
1013

1114
from .vulnerability_models import TestCase, TestResult
1215

1316
logger = structlog.get_logger(__name__)
1417

1518

19+
@dataclass(frozen=True)
20+
class ExecutionStats:
21+
attempted: int
22+
successful_requests: int
23+
failed_requests: int
24+
25+
1626
async def _execute_single_test(
1727
client: httpx.AsyncClient,
1828
test_case: TestCase,
1929
base_url: str,
2030
auth_headers: Dict[str, str],
21-
semaphore: asyncio.Semaphore
31+
semaphore: asyncio.Semaphore,
2232
) -> Optional[TestResult]:
2333
"""Execute a single test case with retry logic and concurrency control."""
2434
max_retries = 3
2535
method = test_case.method.upper()
2636
url = f"{base_url.rstrip('/')}/{test_case.endpoint.lstrip('/')}"
27-
37+
2838
headers = {**test_case.headers, **auth_headers}
2939
params = test_case.params
3040
json_body = test_case.body if isinstance(test_case.body, dict) else None
31-
41+
3242
async with semaphore:
3343
for attempt in range(max_retries):
3444
start_time = time.time()
@@ -39,32 +49,34 @@ async def _execute_single_test(
3949
headers=headers,
4050
params=params,
4151
json=json_body,
42-
timeout=10.0
52+
timeout=10.0,
4353
)
44-
54+
4555
duration_ms = (time.time() - start_time) * 1000
46-
56+
4757
return TestResult(
4858
test_case=test_case,
4959
status_code=response.status_code,
5060
response_body=response.text[:1000] if response.text else "EMPTY",
5161
response_headers=dict(response.headers),
5262
request_url=str(response.request.url),
5363
request_headers=dict(response.request.headers),
54-
request_body=response.request.content.decode('utf-8')[:500] if response.request.content else None,
55-
response_time_ms=duration_ms
64+
request_body=response.request.content.decode("utf-8")[:500]
65+
if response.request.content
66+
else None,
67+
response_time_ms=duration_ms,
5668
)
5769
except (httpx.RequestError, httpx.TimeoutException) as e:
5870
logger.warning(
5971
"Request failed, retrying",
6072
test_id=test_case.id,
6173
attempt=attempt + 1,
62-
error=str(e)
74+
error=str(e),
6375
)
6476
if attempt == max_retries - 1:
6577
logger.error("Max retries reached", test_id=test_case.id)
6678
return None
67-
await asyncio.sleep(2 ** attempt) # Exponential backoff
79+
await asyncio.sleep(2**attempt) # Exponential backoff
6880
return None
6981

7082

@@ -75,29 +87,57 @@ async def execute_proactive_tests(
7587
auth_headers: Optional[Dict[str, str]] = None,
7688
proxy: Optional[str] = None,
7789
verify_ssl: bool = True,
90+
max_requests: Optional[int] = None,
7891
) -> List[TestResult]:
7992
"""Execute generated test cases concurrently."""
80-
logger.info("Executing proactive tests", total=len(test_cases), concurrency=concurrency)
81-
82-
results = []
93+
results, _ = await execute_proactive_tests_detailed(
94+
test_cases=test_cases,
95+
base_url=base_url,
96+
concurrency=concurrency,
97+
auth_headers=auth_headers,
98+
proxy=proxy,
99+
verify_ssl=verify_ssl,
100+
max_requests=max_requests,
101+
)
102+
return results
103+
104+
105+
async def execute_proactive_tests_detailed(
106+
test_cases: List[TestCase],
107+
base_url: str,
108+
concurrency: int = 5,
109+
auth_headers: Optional[Dict[str, str]] = None,
110+
proxy: Optional[str] = None,
111+
verify_ssl: bool = True,
112+
max_requests: Optional[int] = None,
113+
) -> Tuple[List[TestResult], ExecutionStats]:
114+
"""Execute generated tests and return both results and execution stats."""
115+
selected_cases = test_cases[:max_requests] if max_requests is not None else test_cases
116+
logger.info("Executing proactive tests", total=len(selected_cases), concurrency=concurrency)
117+
118+
results: List[TestResult] = []
83119
auth_headers = auth_headers or {}
84120
semaphore = asyncio.Semaphore(concurrency)
85-
86-
# We use a single AsyncClient for connection pooling optimizations
121+
122+
# We use a single AsyncClient for connection pooling optimizations.
87123
async with httpx.AsyncClient(verify=verify_ssl, proxy=proxy) as client:
88124
tasks = [
89125
_execute_single_test(client, tc, base_url, auth_headers, semaphore)
90-
for tc in test_cases
126+
for tc in selected_cases
91127
]
92-
93-
# gather and filter out any None from failures
128+
94129
outcomes = await asyncio.gather(*tasks, return_exceptions=True)
95-
96-
for idx, outcome in enumerate(outcomes):
130+
131+
for outcome in outcomes:
97132
if isinstance(outcome, Exception):
98133
logger.error("Uncaught exception in test executor worker", exc_info=outcome)
99134
elif outcome is not None:
100135
results.append(outcome)
101136

102137
logger.info("Finished executing tests", successful_requests=len(results))
103-
return results
138+
stats = ExecutionStats(
139+
attempted=len(selected_cases),
140+
successful_requests=len(results),
141+
failed_requests=max(0, len(selected_cases) - len(results)),
142+
)
143+
return results, stats

tests/test_ai_modules.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from secnodeapi.ai.generate import generate_test_cases
44
from secnodeapi.ai.understand import understand_api_with_ai
5-
from secnodeapi.ai.validate import validate_findings_with_ai
5+
from secnodeapi.ai.validate import classify_findings, validate_findings_with_ai
66
from secnodeapi.vulnerability_models import (
77
APIUnderstanding,
88
SchemaStructure,
@@ -86,3 +86,69 @@ async def fake_call_llm(*args, **kwargs):
8686
findings = await validate_findings_with_ai([result])
8787
assert len(findings) == 1
8888
assert findings[0].test_case_id == "T-1"
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_classify_findings_deterministic_validator_no_llm(monkeypatch) -> None:
93+
test_case = TestCase(
94+
id="T-DET-1",
95+
name="BOLA object access",
96+
description="idor",
97+
owasp_category="API1: Broken Object Level Authorization",
98+
endpoint="/users/2",
99+
method="GET",
100+
)
101+
result = TestResult(
102+
test_case=test_case,
103+
status_code=200,
104+
response_body='{"id":2}',
105+
response_headers={},
106+
request_url="https://api.example.com/users/2",
107+
request_headers={},
108+
request_body=None,
109+
response_time_ms=8.0,
110+
)
111+
112+
async def fail_if_called(*args, **kwargs):
113+
raise AssertionError("LLM should not be called for deterministic finding")
114+
115+
monkeypatch.setattr("secnodeapi.ai.validate.call_llm", fail_if_called)
116+
confirmed, suspected = await classify_findings([result])
117+
assert len(confirmed) == 1
118+
assert confirmed[0].validation_source == "deterministic"
119+
assert suspected == []
120+
121+
122+
@pytest.mark.asyncio
123+
async def test_classify_findings_low_confidence_goes_to_suspected(monkeypatch) -> None:
124+
test_case = TestCase(
125+
id="T-SUS-1",
126+
name="custom logic probe",
127+
description="custom category",
128+
owasp_category="BIZ-LOGIC",
129+
endpoint="/orders",
130+
method="GET",
131+
)
132+
result = TestResult(
133+
test_case=test_case,
134+
status_code=200,
135+
response_body="ok",
136+
response_headers={},
137+
request_url="https://api.example.com/orders",
138+
request_headers={},
139+
request_body=None,
140+
response_time_ms=10.0,
141+
)
142+
143+
async def fake_call_llm(*args, **kwargs):
144+
return (
145+
'{"analysis":"weak signal","is_vulnerable":true,"cvss_score":5.0,'
146+
'"cvss_vector":"CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:L/I:L/A:N",'
147+
'"description":"possible issue","remediation":"review","confidence":0.62}'
148+
)
149+
150+
monkeypatch.setattr("secnodeapi.ai.validate.call_llm", fake_call_llm)
151+
confirmed, suspected = await classify_findings([result])
152+
assert confirmed == []
153+
assert len(suspected) == 1
154+
assert suspected[0].validation_source == "ai-suspected"

tests/test_cli_and_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@ def test_parse_auth_file(tmp_path: Path) -> None:
3434
assert headers == {"X-Key": "abc"}
3535

3636

37+
def test_parse_identities_file(tmp_path: Path) -> None:
38+
identities_file = tmp_path / "identities.json"
39+
identities_file.write_text(
40+
json.dumps(
41+
{
42+
"identities": [
43+
{"name": "anon", "headers": {}},
44+
{"name": "user", "headers": {"Authorization": "Bearer x"}},
45+
]
46+
}
47+
),
48+
encoding="utf-8",
49+
)
50+
identities = cli.parse_identities(str(identities_file))
51+
assert len(identities) == 2
52+
assert identities[1].name == "user"
53+
54+
3755
def test_require_provider_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
3856
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
3957
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
@@ -59,6 +77,7 @@ def test_parse_args_with_dry_run_output(monkeypatch: pytest.MonkeyPatch) -> None
5977
assert args.target == "https://api.example.com/swagger.json"
6078
assert args.dry_run is True
6179
assert args.dry_run_output == "out.json"
80+
assert args.mode == "agent"
6281

6382

6483
def test_parse_args_requires_dry_run_for_output(monkeypatch: pytest.MonkeyPatch) -> None:

0 commit comments

Comments (0)