Skip to content

Commit 64e6c35

Browse files
committed
Refactor budget clipping logic in pipeline service
- Updated `_clip_for_budget` function to return both selected test cases and remaining unselected cases for better budget management. - Enhanced documentation for the function to clarify its purpose and return values. - Adjusted the `run_agent_pipeline` function to accommodate the new return structure, ensuring proper handling of remaining cases.
1 parent 2be88b4 commit 64e6c35

2 files changed

Lines changed: 117 additions & 5 deletions

File tree

src/secnodeapi/services/pipeline.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,18 +246,28 @@ def _clip_for_budget(
246246
test_cases: List[TestCase],
247247
remaining_budget: int,
248248
per_endpoint_budget: int,
249-
) -> List[TestCase]:
249+
) -> Tuple[List[TestCase], List[TestCase]]:
250+
"""
251+
Select a batch constrained by global and per-endpoint budgets.
252+
253+
Returns both:
254+
- selected cases to execute now
255+
- remaining queue preserving unselected cases for later iterations
256+
"""
250257
clipped: List[TestCase] = []
258+
remaining: List[TestCase] = []
251259
endpoint_counts: Dict[str, int] = {}
252260
for case in test_cases:
253261
if len(clipped) >= remaining_budget:
254-
break
262+
remaining.append(case)
263+
continue
255264
endpoint_count = endpoint_counts.get(case.endpoint, 0)
256265
if endpoint_count >= per_endpoint_budget:
266+
remaining.append(case)
257267
continue
258268
endpoint_counts[case.endpoint] = endpoint_count + 1
259269
clipped.append(case)
260-
return clipped
270+
return clipped, remaining
261271

262272

263273
def _merge_unique_findings(existing: List[Finding], new_findings: List[Finding]) -> List[Finding]:
@@ -294,10 +304,11 @@ async def run_agent_pipeline(
294304

295305
while queue and remaining_budget > 0 and iteration < pipeline_input.max_iterations:
296306
iteration += 1
297-
batch = _clip_for_budget(queue, remaining_budget, pipeline_input.per_endpoint_budget)
307+
batch, queue = _clip_for_budget(
308+
queue, remaining_budget, pipeline_input.per_endpoint_budget
309+
)
298310
if not batch:
299311
break
300-
queue = queue[len(batch) :]
301312

302313
results, stats = await execute_proactive_tests_detailed(
303314
test_cases=batch,

tests/test_pipeline_service.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,104 @@ async def fake_classify(results):
167167
assert metrics["iterations"] == 2
168168
assert len(confirmed) >= 1
169169
assert any("CHAIN-BOLA" in test_id for batch in observed_batches[1:] for test_id in batch)
170+
171+
172+
@pytest.mark.asyncio
173+
async def test_run_agent_pipeline_preserves_skipped_cases_without_reexecution(monkeypatch) -> None:
174+
structure = SchemaStructure(
175+
title="Queue API",
176+
version="1.0",
177+
base_url="https://api.example.com",
178+
endpoints=[APIEndpoint(path="/a", method="GET"), APIEndpoint(path="/b", method="GET")],
179+
auth_schemes={},
180+
)
181+
seed_cases = [
182+
TestCase(
183+
id="A-1",
184+
name="a1",
185+
description="",
186+
owasp_category="API9",
187+
endpoint="/a",
188+
method="GET",
189+
params={"variant": 1},
190+
),
191+
TestCase(
192+
id="A-2",
193+
name="a2",
194+
description="",
195+
owasp_category="API9",
196+
endpoint="/a",
197+
method="GET",
198+
params={"variant": 2},
199+
),
200+
TestCase(
201+
id="A-3",
202+
name="a3",
203+
description="",
204+
owasp_category="API9",
205+
endpoint="/a",
206+
method="GET",
207+
params={"variant": 3},
208+
),
209+
TestCase(
210+
id="B-1",
211+
name="b1",
212+
description="",
213+
owasp_category="API9",
214+
endpoint="/b",
215+
method="GET",
216+
),
217+
]
218+
219+
executed_ids = []
220+
221+
async def fake_artifacts(_):
222+
return structure, seed_cases
223+
224+
async def fake_execute(**kwargs):
225+
cases = kwargs["test_cases"]
226+
executed_ids.extend(case.id for case in cases)
227+
results = [
228+
TestResult(
229+
test_case=case,
230+
status_code=200,
231+
response_body='{"ok": true}',
232+
response_headers={},
233+
request_url=f"https://api.example.com{case.endpoint}",
234+
request_headers={},
235+
request_body=None,
236+
response_time_ms=1.0,
237+
)
238+
for case in cases
239+
]
240+
stats = ExecutionStats(
241+
attempted=len(cases),
242+
successful_requests=len(cases),
243+
failed_requests=0,
244+
)
245+
return results, stats
246+
247+
async def fake_classify(_):
248+
return [], []
249+
250+
monkeypatch.setattr("secnodeapi.services.pipeline.build_pipeline_artifacts", fake_artifacts)
251+
monkeypatch.setattr("secnodeapi.services.pipeline.execute_proactive_tests_detailed", fake_execute)
252+
monkeypatch.setattr("secnodeapi.services.pipeline.classify_findings", fake_classify)
253+
monkeypatch.setattr("secnodeapi.services.pipeline._build_discovery_tests", lambda _: [])
254+
monkeypatch.setattr("secnodeapi.services.pipeline._build_chain_tests", lambda *_: [])
255+
256+
pipeline_input = pipeline.PipelineInput(
257+
target="https://api.example.com/swagger.json",
258+
concurrency=2,
259+
auth_headers={},
260+
proxy=None,
261+
verify_ssl=True,
262+
request_budget=6,
263+
per_endpoint_budget=2,
264+
max_iterations=2,
265+
)
266+
267+
_, _, _, metrics = await pipeline.run_agent_pipeline(pipeline_input)
268+
assert metrics["iterations"] == 2
269+
assert executed_ids.count("B-1::default") == 1
270+
assert "A-3::default" in executed_ids

0 commit comments

Comments
 (0)