Skip to content

Commit fa5d505

Browse files
Merge pull request aden-hive#447 from pradyten/fix/hallucination-detection-full-string-check
fix(graph): check entire string for code indicators in hallucination detection
2 parents 854a867 + df7b950 commit fa5d505

3 files changed

Lines changed: 325 additions & 8 deletions

File tree

core/framework/graph/node.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ def write(self, key: str, value: Any, validate: bool = True) -> None:
196196
# Check for obviously hallucinated content
197197
if len(value) > 5000:
198198
# Long strings that look like code are suspicious
199-
code_indicators = ["```python", "def ", "class ", "import ", "async def "]
200-
if any(indicator in value[:500] for indicator in code_indicators):
199+
if self._contains_code_indicators(value):
201200
logger.warning(
202201
f"⚠ Suspicious write to key '{key}': appears to be code "
203202
f"({len(value)} chars). Consider using validate=False if intended."
@@ -210,6 +209,51 @@ def write(self, key: str, value: Any, validate: bool = True) -> None:
210209

211210
self._data[key] = value
212211

212+
def _contains_code_indicators(self, value: str) -> bool:
213+
"""
214+
Check for code patterns in a string using sampling for efficiency.
215+
216+
For strings under 10KB, checks the entire content.
217+
For longer strings, samples at strategic positions to balance
218+
performance with detection accuracy.
219+
220+
Args:
221+
value: The string to check for code indicators
222+
223+
Returns:
224+
True if code indicators are found, False otherwise
225+
"""
226+
code_indicators = [
227+
# Python
228+
"```python", "def ", "class ", "import ", "async def ", "from ",
229+
# JavaScript/TypeScript
230+
"function ", "const ", "let ", "=> {", "require(", "export ",
231+
# SQL
232+
"SELECT ", "INSERT ", "UPDATE ", "DELETE ", "DROP ",
233+
# HTML/Script injection
234+
"<script", "<?php", "<%",
235+
]
236+
237+
# For strings under 10KB, check the entire content
238+
if len(value) < 10000:
239+
return any(indicator in value for indicator in code_indicators)
240+
241+
# For longer strings, sample at strategic positions
242+
sample_positions = [
243+
0, # Start
244+
len(value) // 4, # 25%
245+
len(value) // 2, # 50%
246+
3 * len(value) // 4, # 75%
247+
max(0, len(value) - 2000), # Near end
248+
]
249+
250+
for pos in sample_positions:
251+
chunk = value[pos:pos + 2000]
252+
if any(indicator in chunk for indicator in code_indicators):
253+
return True
254+
255+
return False
256+
213257
def read_all(self) -> dict[str, Any]:
214258
"""Read all accessible data."""
215259
if self._allowed_read:

core/framework/graph/validator.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,52 @@ class OutputValidator:
3030
Used by the executor to catch bad outputs before they pollute memory.
3131
"""
3232

33+
def _contains_code_indicators(self, value: str) -> bool:
34+
"""
35+
Check for code patterns in a string using sampling for efficiency.
36+
37+
For strings under 10KB, checks the entire content.
38+
For longer strings, samples at strategic positions to balance
39+
performance with detection accuracy.
40+
41+
Args:
42+
value: The string to check for code indicators
43+
44+
Returns:
45+
True if code indicators are found, False otherwise
46+
"""
47+
code_indicators = [
48+
# Python
49+
"def ", "class ", "import ", "from ", "if __name__",
50+
"async def ", "await ", "try:", "except:",
51+
# JavaScript/TypeScript
52+
"function ", "const ", "let ", "=> {", "require(", "export ",
53+
# SQL
54+
"SELECT ", "INSERT ", "UPDATE ", "DELETE ", "DROP ",
55+
# HTML/Script injection
56+
"<script", "<?php", "<%",
57+
]
58+
59+
# For strings under 10KB, check the entire content
60+
if len(value) < 10000:
61+
return any(indicator in value for indicator in code_indicators)
62+
63+
# For longer strings, sample at strategic positions
64+
sample_positions = [
65+
0, # Start
66+
len(value) // 4, # 25%
67+
len(value) // 2, # 50%
68+
3 * len(value) // 4, # 75%
69+
max(0, len(value) - 2000), # Near end
70+
]
71+
72+
for pos in sample_positions:
73+
chunk = value[pos:pos + 2000]
74+
if any(indicator in chunk for indicator in code_indicators):
75+
return True
76+
77+
return False
78+
3379
def validate_output_keys(
3480
self,
3581
output: dict[str, Any],
@@ -93,12 +139,8 @@ def validate_no_hallucination(
93139
if not isinstance(value, str):
94140
continue
95141

96-
# Check for Python-like code
97-
code_indicators = [
98-
"def ", "class ", "import ", "from ", "if __name__",
99-
"async def ", "await ", "try:", "except:"
100-
]
101-
if any(indicator in value[:500] for indicator in code_indicators):
142+
# Check for code patterns in the entire string, not just first 500 chars
143+
if self._contains_code_indicators(value):
102144
# Could be legitimate, but warn
103145
logger.warning(
104146
f"Output key '{key}' may contain code - verify this is expected"
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
"""
2+
Test hallucination detection in SharedMemory and OutputValidator.
3+
4+
These tests verify that code detection works correctly across the entire
5+
string content, not just the first 500 characters.
6+
"""
7+
8+
import pytest
9+
from framework.graph.node import SharedMemory, MemoryWriteError
10+
from framework.graph.validator import OutputValidator, ValidationResult
11+
12+
13+
class TestSharedMemoryHallucinationDetection:
14+
"""Test the SharedMemory hallucination detection."""
15+
16+
def test_detects_code_at_start(self):
17+
"""Code at the start of the string should be detected."""
18+
memory = SharedMemory()
19+
code_content = "```python\nimport os\ndef hack(): pass\n```" + "A" * 6000
20+
21+
with pytest.raises(MemoryWriteError) as exc_info:
22+
memory.write("output", code_content)
23+
24+
assert "hallucinated code" in str(exc_info.value)
25+
26+
def test_detects_code_in_middle(self):
27+
"""Code in the middle of the string should be detected (was previously missed)."""
28+
memory = SharedMemory()
29+
# 600 chars of padding, then code, then more padding to exceed 5000 chars
30+
padding_start = "A" * 600
31+
code = "\n```python\nimport os\ndef malicious(): pass\n```\n"
32+
padding_end = "B" * 5000
33+
content = padding_start + code + padding_end
34+
35+
with pytest.raises(MemoryWriteError) as exc_info:
36+
memory.write("output", content)
37+
38+
assert "hallucinated code" in str(exc_info.value)
39+
40+
def test_detects_code_at_end(self):
41+
"""Code at the end of the string should be detected (was previously missed)."""
42+
memory = SharedMemory()
43+
padding = "A" * 5500
44+
code = "\n```python\nclass Exploit:\n pass\n```"
45+
content = padding + code
46+
47+
with pytest.raises(MemoryWriteError) as exc_info:
48+
memory.write("output", content)
49+
50+
assert "hallucinated code" in str(exc_info.value)
51+
52+
def test_detects_javascript_code(self):
53+
"""JavaScript code patterns should be detected."""
54+
memory = SharedMemory()
55+
padding = "A" * 600
56+
code = "\nfunction malicious() { require('child_process'); }\n"
57+
padding_end = "B" * 5000
58+
content = padding + code + padding_end
59+
60+
with pytest.raises(MemoryWriteError) as exc_info:
61+
memory.write("output", content)
62+
63+
assert "hallucinated code" in str(exc_info.value)
64+
65+
def test_detects_sql_injection(self):
66+
"""SQL patterns should be detected."""
67+
memory = SharedMemory()
68+
padding = "A" * 600
69+
code = "\nDROP TABLE users; SELECT * FROM passwords;\n"
70+
padding_end = "B" * 5000
71+
content = padding + code + padding_end
72+
73+
with pytest.raises(MemoryWriteError) as exc_info:
74+
memory.write("output", content)
75+
76+
assert "hallucinated code" in str(exc_info.value)
77+
78+
def test_detects_script_injection(self):
79+
"""HTML script injection should be detected."""
80+
memory = SharedMemory()
81+
padding = "A" * 600
82+
code = "\n<script>alert('xss')</script>\n"
83+
padding_end = "B" * 5000
84+
content = padding + code + padding_end
85+
86+
with pytest.raises(MemoryWriteError) as exc_info:
87+
memory.write("output", content)
88+
89+
assert "hallucinated code" in str(exc_info.value)
90+
91+
def test_allows_short_strings_without_validation(self):
92+
"""Strings under 5000 chars should not trigger validation."""
93+
memory = SharedMemory()
94+
content = "def hello(): pass" # Contains code indicator but short
95+
96+
# Should not raise - too short to validate
97+
memory.write("output", content)
98+
assert memory.read("output") == content
99+
100+
def test_allows_long_strings_without_code(self):
101+
"""Long strings without code indicators should be allowed."""
102+
memory = SharedMemory()
103+
content = "This is a long text document. " * 500 # ~15000 chars, no code
104+
105+
memory.write("output", content)
106+
assert memory.read("output") == content
107+
108+
def test_validate_false_bypasses_check(self):
109+
"""Using validate=False should bypass the check."""
110+
memory = SharedMemory()
111+
code_content = "```python\nimport os\n```" + "A" * 6000
112+
113+
# Should not raise when validate=False
114+
memory.write("output", code_content, validate=False)
115+
assert memory.read("output") == code_content
116+
117+
def test_sampling_for_very_long_strings(self):
118+
"""Very long strings (>10KB) should be sampled at multiple positions."""
119+
memory = SharedMemory()
120+
# Create a 50KB string with code at the 75% mark
121+
size = 50000
122+
code_position = int(size * 0.75)
123+
content = "A" * code_position + "def hidden_code(): pass" + "B" * (size - code_position - 25)
124+
125+
with pytest.raises(MemoryWriteError) as exc_info:
126+
memory.write("output", content)
127+
128+
assert "hallucinated code" in str(exc_info.value)
129+
130+
131+
class TestOutputValidatorHallucinationDetection:
132+
"""Test the OutputValidator hallucination detection."""
133+
134+
def test_detects_code_anywhere_in_output(self):
135+
"""Code anywhere in the output value should trigger a warning."""
136+
validator = OutputValidator()
137+
padding = "Normal text content. " * 50
138+
code = "\ndef suspicious_function():\n pass\n"
139+
output = {"result": padding + code}
140+
141+
# The method logs a warning but doesn't fail
142+
result = validator.validate_no_hallucination(output)
143+
# The warning is logged - we can't easily test logging, but the method should work
144+
assert isinstance(result, ValidationResult)
145+
146+
def test_contains_code_indicators_full_check(self):
147+
"""_contains_code_indicators should check the entire string."""
148+
validator = OutputValidator()
149+
150+
# Code at position 600 (was previously missed with [:500] check)
151+
padding = "A" * 600
152+
code = "import os"
153+
content = padding + code
154+
155+
assert validator._contains_code_indicators(content) is True
156+
157+
def test_contains_code_indicators_sampling(self):
158+
"""_contains_code_indicators should sample for very long strings."""
159+
validator = OutputValidator()
160+
161+
# 50KB string with code at 75% position
162+
size = 50000
163+
code_position = int(size * 0.75)
164+
content = "A" * code_position + "class HiddenClass:" + "B" * (size - code_position - 18)
165+
166+
assert validator._contains_code_indicators(content) is True
167+
168+
def test_no_false_positive_for_clean_text(self):
169+
"""Clean text without code should not trigger false positives."""
170+
validator = OutputValidator()
171+
172+
# Long text without any code indicators
173+
content = "This is a perfectly normal document. " * 300
174+
175+
assert validator._contains_code_indicators(content) is False
176+
177+
def test_detects_multiple_languages(self):
178+
"""Should detect code patterns from multiple programming languages."""
179+
validator = OutputValidator()
180+
181+
test_cases = [
182+
"function test() {}", # JavaScript
183+
"const x = 5;", # JavaScript
184+
"SELECT * FROM users", # SQL
185+
"DROP TABLE data", # SQL
186+
"<script>", # HTML
187+
"<?php", # PHP
188+
]
189+
190+
for code in test_cases:
191+
assert validator._contains_code_indicators(code) is True, f"Failed to detect: {code}"
192+
193+
194+
class TestEdgeCases:
195+
"""Test edge cases for hallucination detection."""
196+
197+
def test_empty_string(self):
198+
"""Empty strings should not cause errors."""
199+
memory = SharedMemory()
200+
memory.write("output", "")
201+
assert memory.read("output") == ""
202+
203+
def test_non_string_values(self):
204+
"""Non-string values should not be validated for code."""
205+
memory = SharedMemory()
206+
207+
# These should all work without validation
208+
memory.write("number", 12345)
209+
memory.write("list", [1, 2, 3])
210+
memory.write("dict", {"key": "value"})
211+
memory.write("bool", True)
212+
213+
assert memory.read("number") == 12345
214+
assert memory.read("list") == [1, 2, 3]
215+
216+
def test_exactly_5000_chars(self):
217+
"""String of exactly 5000 chars should not trigger validation."""
218+
memory = SharedMemory()
219+
content = "def code(): pass" + "A" * (5000 - 16) # Exactly 5000 chars
220+
221+
# Should not raise - exactly at threshold, not over
222+
memory.write("output", content)
223+
assert len(memory.read("output")) == 5000
224+
225+
def test_5001_chars_triggers_validation(self):
226+
"""String of 5001 chars with code should trigger validation."""
227+
memory = SharedMemory()
228+
content = "def code(): pass" + "A" * (5001 - 16) # 5001 chars
229+
230+
with pytest.raises(MemoryWriteError):
231+
memory.write("output", content)

0 commit comments

Comments
 (0)