|
1 | 1 | """ |
2 | 2 | AI validation stage for filtering true-positive findings. |
3 | 3 | """ |
| 4 | + |
4 | 5 | import asyncio |
5 | 6 | import json |
6 | | -from typing import List, Optional |
| 7 | +from typing import List, Optional, Tuple |
7 | 8 |
|
8 | 9 | import structlog |
| 10 | +from pydantic import BaseModel, ValidationError |
9 | 11 |
|
10 | 12 | from ..vulnerability_models import Finding, TestResult |
11 | | - |
12 | 13 | from .llm_client import call_llm |
13 | 14 |
|
14 | 15 | logger = structlog.get_logger(__name__) |
15 | 16 |
|
16 | 17 |
|
17 | | -async def _evaluate_single_result(result: TestResult) -> Optional[Finding]: |
18 | | - """Evaluate one test result and return a confirmed finding when applicable.""" |
19 | | - sys_prompt = ( |
20 | | - "You are a CISO and Elite AppSec Triager reviewing penetration test results. Your job is to definitively determine " |
21 | | - "if an executed test case reveals a true positive vulnerability, or if it is a false positive / expected behavior. " |
22 | | - "Apply strict heuristics:\n" |
23 | | - "- 401 Unauthorized / 403 Forbidden is usually expected security behavior (NOT a vulnerability).\n" |
24 | | - "- 500 Internal Server Error reveals a lack of robustness, potentially a DoS or injection vuln, but requires context.\n" |
25 | | - "- 200/201 OK on an endpoint that shouldn't grant access (e.g. a BOLA or mass assignment test) is a highly probable vulnerability.\n" |
26 | | - "- If a Rate Limit test returns 200 OK after 100 requests, RATE LIMITING IS BROKEN.\n\n" |
27 | | - "You MUST perform a chain-of-thought analysis before concluding.\n" |
28 | | - "Output ONLY JSON with this precise schema:\n" |
29 | | - '{"analysis": "str (your thought process)", "is_vulnerable": bool, "cvss_score": float, "cvss_vector": "str", ' |
30 | | - '"description": "str", "remediation": "str", "confidence": float}' |
class AIValidationPayload(BaseModel):
    """Schema the LLM's JSON verdict must satisfy before it is used.

    The field names are the JSON keys requested in the validation system
    prompt; a reply with a missing or mistyped key fails pydantic validation.
    No value ranges are enforced by this model.
    """

    # Free-form reasoning text returned by the model.
    analysis: str
    # The model's verdict on whether the test result is a vulnerability.
    is_vulnerable: bool
    # CVSS score and vector string as reported by the model.
    cvss_score: float
    cvss_vector: str
    # Human-readable finding description and suggested fix.
    description: str
    remediation: str
    # Model-reported confidence in the verdict.
    confidence: float
| 26 | + |
| 27 | + |
| 28 | +def _class_keywords(result: TestResult) -> str: |
| 29 | + return " ".join( |
| 30 | + [ |
| 31 | + result.test_case.owasp_category.lower(), |
| 32 | + result.test_case.name.lower(), |
| 33 | + result.test_case.description.lower(), |
| 34 | + ] |
31 | 35 | ) |
32 | 36 |
|
33 | | - user_prompt = f"Test Result Context:\n{result.model_dump_json(indent=2)}" |
34 | 37 |
|
35 | | - try: |
36 | | - llm_resp = await call_llm(sys_prompt, user_prompt, temperature=0.1) |
37 | | - data = json.loads(llm_resp) |
def _deterministic_validate_result(result: TestResult) -> Optional[Finding]:
    """Confirm well-understood vulnerability classes without calling the LLM.

    The lowercase text produced by ``_class_keywords`` is scanned for class
    keywords. A 2xx response on a matching test is reported as a confirmed
    finding with fixed CVSS and confidence values. Rules are tried in order:
    authorization (BOLA/BFLA), mass assignment, then rate limiting.

    Args:
        result: Executed test result to inspect.

    Returns:
        A ``Finding`` with ``validation_source="deterministic"`` when a rule
        fires, otherwise ``None``.
    """
    import re  # function-scope import so the block needs no file-level change

    category_text = _class_keywords(result)

    def mentions(*keywords: str) -> bool:
        """Return True when any keyword occurs in the test's descriptive text."""
        for keyword in keywords:
            if keyword[-1].isdigit():
                # "api1" must not match inside "api10": reject a trailing digit.
                if re.search(re.escape(keyword) + r"(?!\d)", category_text):
                    return True
            elif keyword in category_text:
                return True
        return False

    status = result.status_code
    # Every rule below requires a 2xx response, so anything else (including a
    # result with no recorded status) can never produce a finding.
    if status is None or not 200 <= status < 300:
        return None

    test_case = result.test_case

    def confirmed(
        *,
        cvss_score: float,
        cvss_vector: str,
        description: str,
        remediation: str,
        confidence: float,
    ) -> Finding:
        """Build a deterministic finding carrying the shared evidence fields."""
        body = result.response_body
        return Finding(
            test_case_id=test_case.id,
            endpoint=test_case.endpoint,
            method=test_case.method,
            vulnerability_class=test_case.owasp_category,
            cvss_score=cvss_score,
            cvss_vector=cvss_vector,
            description=description,
            remediation=remediation,
            confidence=confidence,
            evidence_request=f"{test_case.method} {result.request_url}",
            # Keep at most 500 characters of evidence; tolerate a missing body.
            evidence_response=body[:500] if body is not None else "",
            validation_source="deterministic",
            identity=test_case.identity,
        )

    if mentions("bola", "idor", "api1", "bfla", "api5", "broken object", "broken function"):
        return confirmed(
            cvss_score=8.2,
            cvss_vector="CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:N",
            description=(
                "Deterministic validator observed successful access on an authorization-focused test."
            ),
            remediation="Enforce object and function authorization checks server-side.",
            confidence=0.9,
        )

    if mentions("mass assignment", "bopla", "api3"):
        request_body = test_case.body
        protected_fields = {"is_admin", "role", "credit_balance", "permissions"}
        # Only a dict payload that actually carries a privileged key counts;
        # otherwise fall through to the remaining rules.
        if isinstance(request_body, dict) and not protected_fields.isdisjoint(request_body):
            return confirmed(
                cvss_score=8.0,
                cvss_vector="CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:N",
                description=(
                    "Deterministic validator observed accepted protected-field mutation payload."
                ),
                remediation="Allowlist writable fields and reject privileged attributes at API boundary.",
                confidence=0.88,
            )

    if mentions("rate limit", "ratelimit", "api4"):
        return confirmed(
            cvss_score=6.8,
            cvss_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H",
            description=(
                "Rate-limit-focused test continued receiving successful responses without clear throttling."
            ),
            remediation="Introduce per-principal and per-IP rate limits with enforced backoff.",
            confidence=0.8,
        )

    return None
| 117 | + |
| 118 | + |
def _build_finding_from_ai(result: TestResult, payload: AIValidationPayload, source: str) -> Finding:
    """Turn a validated LLM verdict into a ``Finding`` for the given test result.

    Identification fields come from the test case, scoring and narrative
    fields come from ``payload``, and ``source`` is stored as the finding's
    ``validation_source``. The response evidence is the first 500 characters
    of the response body.
    """
    case = result.test_case
    return Finding(
        # Where the issue was observed.
        test_case_id=case.id,
        endpoint=case.endpoint,
        method=case.method,
        vulnerability_class=case.owasp_category,
        identity=case.identity,
        # What the model reported about it.
        cvss_score=payload.cvss_score,
        cvss_vector=payload.cvss_vector,
        description=payload.description,
        remediation=payload.remediation,
        confidence=payload.confidence,
        # Evidence and provenance.
        evidence_request=f"{case.method} {result.request_url}",
        evidence_response=result.response_body[:500],
        validation_source=source,
    )
| 137 | + |
| 138 | + |
async def _evaluate_single_result(result: TestResult) -> Tuple[Optional[Finding], Optional[Finding]]:
    """Classify one test result as a ``(confirmed, suspected)`` pair.

    A deterministic hit is returned as confirmed without consulting the LLM.
    Otherwise the LLM's JSON reply is validated against
    ``AIValidationPayload``: a vulnerable verdict with confidence of at least
    0.75 is confirmed, a vulnerable verdict below that is suspected, and a
    non-vulnerable verdict yields ``(None, None)``. Any failure is logged as a
    warning and also yields ``(None, None)``.
    """
    rule_based = _deterministic_validate_result(result)
    if rule_based is not None:
        return rule_based, None

    system_prompt = (
        "You are a senior AppSec triager validating API pentest results. "
        "Return ONLY valid JSON with exact schema:\n"
        '{"analysis":"str","is_vulnerable":bool,"cvss_score":float,"cvss_vector":"str",'
        '"description":"str","remediation":"str","confidence":float}\n'
        "Never return markdown or extra text."
    )
    context_prompt = f"Test Result Context:\n{result.model_dump_json(indent=2)}"

    try:
        raw_reply = await call_llm(system_prompt, context_prompt, temperature=0.1)
        verdict = AIValidationPayload.model_validate(json.loads(raw_reply))
        if verdict.is_vulnerable:
            # Findings are built inside the try so construction errors are
            # logged by the handlers below instead of propagating.
            if verdict.confidence >= 0.75:
                return _build_finding_from_ai(result, verdict, "ai"), None
            return None, _build_finding_from_ai(result, verdict, "ai-suspected")
        return None, None
    except (json.JSONDecodeError, ValidationError) as exc:
        logger.warning(
            "AI output schema validation failed",
            error=str(exc),
            test_id=result.test_case.id,
        )
    except Exception as exc:
        logger.warning("Failed validation evaluation", error=str(exc), test_id=result.test_case.id)

    return None, None
59 | 171 |
|
60 | 172 |
|
61 | | -async def validate_findings_with_ai(results: List[TestResult]) -> List[Finding]: |
62 | | - """Replay AI analysis over execution results to confirm findings.""" |
async def classify_findings(
    results: List[TestResult], *, batch_size: int = 5
) -> Tuple[List[Finding], List[Finding]]:
    """Classify test results into confirmed and suspected findings.

    Results are evaluated concurrently in batches of ``batch_size``; each
    batch is awaited before the next one starts.

    Args:
        results: Executed test results to evaluate.
        batch_size: Number of results evaluated concurrently per batch.

    Returns:
        A ``(confirmed, suspected)`` pair of finding lists.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    logger.info("Validating execution results with AI", count=len(results))
    confirmed: List[Finding] = []
    suspected: List[Finding] = []

    for start in range(0, len(results), batch_size):
        batch = results[start : start + batch_size]
        outcomes = await asyncio.gather(
            *(_evaluate_single_result(item) for item in batch),
            return_exceptions=True,
        )

        for outcome in outcomes:
            # With return_exceptions=True the result list may contain
            # BaseException subclasses such as CancelledError, which are not
            # Exception instances and would break the tuple unpacking below.
            if isinstance(outcome, BaseException):
                logger.error("Error evaluating result", exc_info=outcome)
                continue
            confirmed_finding, suspected_finding = outcome
            if confirmed_finding is not None:
                confirmed.append(confirmed_finding)
            if suspected_finding is not None:
                suspected.append(suspected_finding)

    return confirmed, suspected
77 | 196 |
|
78 | | - return findings |
| 197 | + |
async def validate_findings_with_ai(results: List[TestResult]) -> List[Finding]:
    """Return only the confirmed findings for ``results``.

    Kept for callers that predate ``classify_findings``; the suspected
    bucket it produces is discarded here.
    """
    buckets = await classify_findings(results)
    return buckets[0]
0 commit comments