diff --git a/engine/argument_risk_engine/explanation/evidence.py b/engine/argument_risk_engine/explanation/evidence.py index d9e7868..d477924 100644 --- a/engine/argument_risk_engine/explanation/evidence.py +++ b/engine/argument_risk_engine/explanation/evidence.py @@ -1,6 +1,65 @@ from __future__ import annotations +from dataclasses import asdict, dataclass + + +@dataclass(frozen=True) +class EvidenceSpan: + text: str + start_char: int + end_char: int + source: str = "input_text" + match_type: str = "exact" + confidence: float = 1.0 + + @property + def quote(self) -> str: + return self.text + + @property + def start(self) -> int: + return self.start_char + + @property + def end(self) -> int: + return self.end_char + + def __getitem__(self, key: str) -> object: + aliases = {"quote": "text", "start": "start_char", "end": "end_char"} + return getattr(self, aliases.get(key, key)) + + def get(self, key: str, default: object = None) -> object: + try: + return self[key] + except AttributeError: + return default + + def to_dict(self) -> dict[str, object]: + data = asdict(self) + data.update({"quote": self.text, "start": self.start_char, "end": self.end_char}) + return data + + +def find_evidence_spans(text: str, evidence_text: str, *, source: str = "input_text", match_type: str = "exact") -> list[EvidenceSpan]: + """Return exact evidence spans from text; never fabricate missing evidence.""" + + if not text or not evidence_text: + return [] + start = text.find(evidence_text) + if start < 0: + normalized = evidence_text.strip() + start = text.find(normalized) if normalized else -1 + evidence_text = normalized + if start < 0: + return [] + end = start + len(evidence_text) + if text[start:end] != evidence_text: + return [] + return [EvidenceSpan(evidence_text, start, end, source=source, match_type=match_type, confidence=1.0)] + def evidence_span(text: str, claim: str) -> dict[str, object]: - start = text.find(claim) - return {"quote": claim, "start": max(start, 0), "end": max(start, 0) + len(claim)} + spans = find_evidence_spans(text, str(claim)) + if not spans: + return {} + return spans[0].to_dict() diff --git a/engine/argument_risk_engine/extraction/claim_extractor.py b/engine/argument_risk_engine/extraction/claim_extractor.py index 4acac95..32a3031 100644 --- a/engine/argument_risk_engine/extraction/claim_extractor.py +++ b/engine/argument_risk_engine/extraction/claim_extractor.py @@ -1,9 +1,180 @@ from __future__ import annotations import re +from collections.abc import Iterable +from dataclasses import dataclass, field +CLAIM_TYPES = { + "causal_claim", + "comparative_claim", + "normative_claim", + "generalization", + "prediction", + "recommendation", + "evidential_claim", + "statistical_claim", + "analogy_claim", + "question_claim", + "descriptive_claim", + "unclear", +} -def extract_claims(text: str) -> list[str]: - pieces = re.split(r"(?<=[.!?])\s+|\n+", text.strip()) - claims = [piece.strip() for piece in pieces if len(piece.strip()) >= 8] - return claims or ([text.strip()] if text.strip() else []) +_MARKERS: dict[str, tuple[str, ...]] = { + "causal_claim": ( + "because", + "therefore", + "leads to", + "causes", + "results in", + "due to", + "explains", + "responsible for", + ), + "comparative_claim": ( + "better than", + "worse than", + "more than", + "less than", + "superior", + "inferior", + "compared with", + ), + "normative_claim": ("should", "must", "ought", "need to", "have to"), + "recommendation": ("recommend", "recommendation", "advise", "suggest", "best to"), + "prediction": ("will", "likely", "expected to", "forecast", "probably"), + "generalization": ("always", "never", "everyone", "no one", "all", "none", "most people"), + "statistical_claim": ("percent", "average", "rate", "sample", "survey", "study", "data", "statistically"), + "analogy_claim": ("like", "similar to", "just as", "equivalent to", "same as"), + "evidential_claim": ("evidence", "according to", "shows", "found", "study", "data", "research"), +} + +_STRONG_MARKER_TYPES = set(_MARKERS) | {"question_claim"} +_MEANINGFUL_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9'-]*") +_SENTENCE_END_RE = re.compile(r"(?<=[.!?])(?=\s|$)") + + +@dataclass(frozen=True) +class Claim(str): + """A sentence-level claim that behaves like a string for legacy callers.""" + + text: str = field(default="") + start_char: int = 0 + end_char: int = 0 + claim_type: str = "unclear" + markers: tuple[str, ...] = field(default_factory=tuple) + + def __new__(cls, text: str, start_char: int = 0, end_char: int | None = None, claim_type: str = "unclear", markers: Iterable[str] = ()): + obj = str.__new__(cls, text) + return obj + + def __init__(self, text: str, start_char: int = 0, end_char: int | None = None, claim_type: str = "unclear", markers: Iterable[str] = ()): + object.__setattr__(self, "text", text) + object.__setattr__(self, "start_char", start_char) + object.__setattr__(self, "end_char", len(text) + start_char if end_char is None else end_char) + object.__setattr__(self, "claim_type", claim_type if claim_type in CLAIM_TYPES else "unclear") + object.__setattr__(self, "markers", tuple(markers)) + + def model_dump(self) -> dict[str, object]: + return { + "text": self.text, + "start_char": self.start_char, + "end_char": self.end_char, + "claim_type": self.claim_type, + "markers": list(self.markers), + } + + +def _meaningful_tokens(text: str) -> list[str]: + return _MEANINGFUL_RE.findall(text) + + +def _contains_marker(text: str, marker: str) -> bool: + pattern = r"(? tuple[str, tuple[str, ...]]: + text = sentence.strip() + if not text: + return "unclear", () + if text.endswith("?"): + return "question_claim", () + + matched: dict[str, list[str]] = {} + for claim_type, markers in _MARKERS.items(): + hits = [marker for marker in markers if _contains_marker(text, marker)] + if hits: + matched[claim_type] = hits + + if not matched: + return ("descriptive_claim", ()) if len(_meaningful_tokens(text)) >= 5 else ("unclear", ()) + + precedence = [ + "statistical_claim", + "causal_claim", + "comparative_claim", + "normative_claim", + "recommendation", + "prediction", + "generalization", + "analogy_claim", + "evidential_claim", + ] + for claim_type in precedence: + if claim_type in matched: + return claim_type, tuple(matched[claim_type]) + first_type = next(iter(matched)) + return first_type, tuple(matched[first_type]) + + +def _sentence_spans(text: str) -> list[tuple[str, int, int]]: + spans: list[tuple[str, int, int]] = [] + cursor = 0 + for match in _SENTENCE_END_RE.finditer(text): + end = match.end() + raw = text[cursor:end] + stripped = raw.strip() + if stripped: + start = cursor + len(raw) - len(raw.lstrip()) + finish = cursor + len(raw.rstrip()) + spans.append((text[start:finish], start, finish)) + cursor = end + while cursor < len(text) and text[cursor].isspace(): + cursor += 1 + if cursor < len(text): + raw = text[cursor:] + for part in re.finditer(r"[^\n]+", raw): + segment = part.group(0).strip() + if segment: + start = cursor + part.start() + len(part.group(0)) - len(part.group(0).lstrip()) + finish = cursor + part.end() - (len(part.group(0)) - len(part.group(0).rstrip())) + spans.append((text[start:finish], start, finish)) + return spans + + +def extract_claims(text: str) -> list[Claim]: + """Extract sentence-level claims with stable character offsets. + + Fragments with fewer than five meaningful tokens are ignored unless they contain + one of the configured strong claim markers (or are questions). + """ + + if not text or not text.strip(): + return [] + + claims: list[Claim] = [] + for sentence, start, end in _sentence_spans(text): + claim_type, markers = detect_claim_type(sentence) + strong = claim_type in _STRONG_MARKER_TYPES and (bool(markers) or claim_type == "question_claim") + legacy_claim_label = bool(re.search(r"(? bool: + return entry.activation_status == "deprecated" or entry.academic_status == "deprecated" + + +def is_healthy_suppressor(entry: TaxonomyEntry) -> bool: + return bool(entry.healthy_suppressor or entry.canonical_category == "healthy_reasoning_pattern") + + +def is_candidate_only(entry: TaxonomyEntry) -> bool: + return bool(entry.enabled_for_retrieval and not entry.enabled_for_classification) + + +def final_classification_candidates(candidates: list[object]) -> list[object]: + """Drop deprecated, healthy-suppressor, and candidate-only retrieval matches.""" + + filtered: list[object] = [] + for candidate in candidates: + entry = getattr(candidate, "entry", candidate) + if is_deprecated(entry) or is_healthy_suppressor(entry) or is_candidate_only(entry): + continue + filtered.append(candidate) + return filtered diff --git a/engine/argument_risk_engine/retrieval/inverted_index.py b/engine/argument_risk_engine/retrieval/inverted_index.py index 1437530..0f6d1dd 100644 --- a/engine/argument_risk_engine/retrieval/inverted_index.py +++ b/engine/argument_risk_engine/retrieval/inverted_index.py @@ -2,24 +2,130 @@ import re from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass, field -TOKEN_RE = re.compile(r"[a-z0-9_']+") +from argument_risk_engine.taxonomy.models import ActivationStatus, TaxonomyEntry, TaxonomyPack +TOKEN_RE = re.compile(r"[a-z0-9][a-z0-9_'-]*") -def tokenize(text: str) -> list[str]: - return TOKEN_RE.findall(text.lower()) +STOPWORDS = { + "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "have", "in", "into", "is", "it", "its", "of", "on", "or", "that", "the", "their", "them", "then", "there", "these", "this", "those", "to", "was", "were", "with", "without", "we", "you", "they", "i", +} + +GENERIC_TERMS = { + "argument", "claim", "claims", "evidence", "reason", "reasoning", "risk", "risks", "statement", "statements", "language", "pattern", "patterns", "people", "person", "group", "thing", "things", "good", "bad", "better", "worse", "issue", "case", "example", "examples", "support", "supports", "may", "might", "could", "would", "should", +} + +FIELD_WEIGHTS = { + "name": 3.0, + "synonyms": 3.0, + "signals": 4.0, + "trigger_patterns": 5.0, + "definitions": 1.25, +} + + +def tokenize(text: str, *, keep_generic: bool = True) -> list[str]: + tokens = TOKEN_RE.findall(str(text).lower()) + if keep_generic: + return tokens + return [token for token in tokens if token not in STOPWORDS and token not in GENERIC_TERMS and len(token) > 1] + + +def normalize_phrase(text: str) -> str: + return " ".join(tokenize(text, keep_generic=True)) + + +def significant_terms(text: str) -> list[str]: + return tokenize(text, keep_generic=False) + + +@dataclass(frozen=True) +class IndexedField: + field: str + value: str + terms: tuple[str, ...] + phrase: str + + +@dataclass +class IndexedEntry: + entry: TaxonomyEntry + fields: list[IndexedField] = field(default_factory=list) + is_healthy_suppressor: bool = False + is_candidate_only: bool = False class InvertedIndex: - def __init__(self) -> None: + def __init__(self, pack: TaxonomyPack | None = None) -> None: self.index: dict[str, set[str]] = defaultdict(set) + self.entries: dict[str, IndexedEntry] = {} + self.ignored_terms: set[str] = set(STOPWORDS | GENERIC_TERMS) + if pack is not None: + self.build(pack) def add(self, doc_id: str, text: str) -> None: - for token in tokenize(text): + for token in significant_terms(text): self.index[token].add(doc_id) + def build(self, pack: TaxonomyPack) -> None: + for entry in pack.entries: + if not _retrievable(entry): + continue + indexed = IndexedEntry( + entry=entry, + is_healthy_suppressor=_is_healthy_suppressor(entry), + is_candidate_only=bool(entry.enabled_for_retrieval and not entry.enabled_for_classification), + ) + for field_name, value in _entry_field_values(entry): + terms = tuple(significant_terms(value)) + phrase = normalize_phrase(value) + if not terms and not phrase: + continue + indexed.fields.append(IndexedField(field_name, value, terms, phrase)) + for term in terms: + self.index[term].add(entry.id) + if indexed.fields: + self.entries[entry.id] = indexed + def search(self, text: str) -> set[str]: matches: set[str] = set() - for token in tokenize(text): + for token in significant_terms(text): matches.update(self.index.get(token, set())) return matches + + def get(self, doc_id: str) -> IndexedEntry | None: + return self.entries.get(doc_id) + + +def _list_from_metadata(entry: TaxonomyEntry, key: str) -> list[str]: + value = entry.metadata.get(key) if entry.metadata else None + if value is None: + return [] + if isinstance(value, list): + return [str(item) for item in value if str(item).strip()] + return [str(value)] if str(value).strip() else [] + + +def _entry_field_values(entry: TaxonomyEntry) -> Iterable[tuple[str, str]]: + yield "name", entry.name + for synonym in [*entry.synonym_ids, *_list_from_metadata(entry, "synonyms")]: + yield "synonyms", synonym + for signal in entry.signals: + yield "signals", signal + for trigger in entry.trigger_patterns: + yield "trigger_patterns", trigger + for definition in [entry.short_definition, entry.long_definition]: + if definition: + yield "definitions", definition + + +def _retrievable(entry: TaxonomyEntry) -> bool: + if entry.activation_status == ActivationStatus.deprecated.value or entry.academic_status == "deprecated": + return False + return bool(entry.enabled_for_retrieval or entry.enabled_for_classification or entry.activation_status == ActivationStatus.active.value) + + +def _is_healthy_suppressor(entry: TaxonomyEntry) -> bool: + return bool(entry.healthy_suppressor or entry.canonical_category == "healthy_reasoning_pattern") diff --git a/engine/argument_risk_engine/retrieval/lexical_retriever.py b/engine/argument_risk_engine/retrieval/lexical_retriever.py index 74710e3..49ee5d9 100644 --- a/engine/argument_risk_engine/retrieval/lexical_retriever.py +++ b/engine/argument_risk_engine/retrieval/lexical_retriever.py @@ -1,23 +1,210 @@ from __future__ import annotations -from argument_risk_engine.retrieval.inverted_index import tokenize +from dataclasses import asdict, dataclass, field +from typing import Any + +from argument_risk_engine.retrieval.inverted_index import ( + FIELD_WEIGHTS, + InvertedIndex, + normalize_phrase, + significant_terms, +) +from argument_risk_engine.retrieval.retrieval_diagnostics import RetrievalDiagnostics from argument_risk_engine.taxonomy.models import TaxonomyEntry, TaxonomyPack -def retrieve_candidates(claim: str, pack: TaxonomyPack, limit: int = 5) -> list[TaxonomyEntry]: - claim_text = claim.lower() - claim_tokens = set(tokenize(claim)) - scored: list[tuple[int, TaxonomyEntry]] = [] - for entry in pack.entries: - if not entry.active: +@dataclass(frozen=True) +class RetrievedTaxonomyEntry: + entry: TaxonomyEntry + retrieval_score: float + matched_terms: list[str] + matched_fields: list[str] + retrieval_reason: str + false_positive_risk: str = "medium" + healthy_pattern_matches: list[str] = field(default_factory=list) + diagnostics: dict[str, Any] = field(default_factory=dict) + + def __getattr__(self, name: str) -> Any: + return getattr(self.entry, name) + + def to_dict(self) -> dict[str, Any]: + data = asdict(self) + data["entry"] = self.entry.model_dump() if hasattr(self.entry, "model_dump") else self.entry.dict() + return data + + +def retrieve_candidates(claim: str, pack: TaxonomyPack, limit: int = 5) -> list[RetrievedTaxonomyEntry]: + """Retrieve deterministic taxonomy candidates for a claim. + + Retrieval is lexical and conservative: stopwords/generic taxonomy wording do not + activate entries, deprecated rows are ignored, and matched healthy-reasoning + patterns lower the scores of risky candidates. + """ + + index = _build_index(pack) + query = str(claim) + query_terms = significant_terms(query) + query_term_set = set(query_terms) + query_phrase = normalize_phrase(query) + + raw_ids = index.search(query) + raw_candidates: list[RetrievedTaxonomyEntry] = [] + healthy_matches: list[str] = [] + healthy_terms: set[str] = set() + + for doc_id in sorted(raw_ids): + indexed = index.get(doc_id) + if indexed is None: continue - score = 0 - for keyword in entry.keywords: - keyword_lower = keyword.lower() - if keyword_lower in claim_text: - score += 3 - score += len(set(tokenize(keyword_lower)) & claim_tokens) - if score: - scored.append((score, entry)) - scored.sort(key=lambda item: (-item[0], item[1].id)) - return [entry for _, entry in scored[:limit]] + score, matched_terms, matched_fields = _score_indexed_entry(indexed, query_phrase, query_term_set) + if score <= 0: + continue + if indexed.is_healthy_suppressor: + healthy_matches.append(indexed.entry.id) + healthy_terms.update(matched_terms) + continue + raw_candidates.append( + RetrievedTaxonomyEntry( + entry=indexed.entry, + retrieval_score=score, + matched_terms=matched_terms, + matched_fields=matched_fields, + retrieval_reason=_reason(matched_terms, matched_fields), + false_positive_risk=_base_false_positive_risk(indexed.entry), + diagnostics={}, + ) + ) + + adjusted: list[RetrievedTaxonomyEntry] = [] + suppressed_count = 0 + for candidate in raw_candidates: + overlap = sorted(set(candidate.matched_terms) & healthy_terms) + penalty = 0.0 + if healthy_matches: + penalty += 1.0 + if overlap: + penalty += 1.5 + score = max(0.0, candidate.retrieval_score - penalty) + if score <= 0: + suppressed_count += 1 + continue + risk = candidate.false_positive_risk + if penalty >= 1.5: + risk = "high" + elif penalty > 0 and risk == "low": + risk = "medium" + adjusted.append( + RetrievedTaxonomyEntry( + entry=candidate.entry, + retrieval_score=round(score, 4), + matched_terms=candidate.matched_terms, + matched_fields=candidate.matched_fields, + retrieval_reason=(candidate.retrieval_reason + "; reduced by healthy reasoning pattern" if penalty else candidate.retrieval_reason), + false_positive_risk=risk, + healthy_pattern_matches=healthy_matches, + diagnostics={}, + ) + ) + + adjusted.sort(key=lambda item: (-item.retrieval_score, item.entry.id)) + returned = adjusted[: max(limit, 0)] + diag = RetrievalDiagnostics( + query_terms=query_terms, + considered_entry_count=len(index.entries), + raw_candidate_count=len(raw_candidates), + returned_candidate_count=len(returned), + suppressed_candidate_count=suppressed_count, + healthy_suppressor_count=len(healthy_matches), + ignored_terms=sorted(set(normalize_phrase(query).split()) - set(query_terms)), + ).to_dict() + + return [ + RetrievedTaxonomyEntry( + entry=item.entry, + retrieval_score=item.retrieval_score, + matched_terms=item.matched_terms, + matched_fields=item.matched_fields, + retrieval_reason=item.retrieval_reason, + false_positive_risk=item.false_positive_risk, + healthy_pattern_matches=item.healthy_pattern_matches, + diagnostics=diag, + ) + for item in returned + ] + + +def _score_indexed_entry(indexed: Any, query_phrase: str, query_terms: set[str]) -> tuple[float, list[str], list[str]]: + score = 0.0 + matched_terms: set[str] = set() + matched_fields: set[str] = set() + + for indexed_field in indexed.fields: + field_weight = FIELD_WEIGHTS.get(indexed_field.field, 1.0) + field_terms = set(indexed_field.terms) + term_hits = field_terms & query_terms + if term_hits: + matched_terms.update(term_hits) + matched_fields.add(indexed_field.field) + score += field_weight * len(term_hits) + if indexed_field.phrase and " " in indexed_field.phrase and _phrase_in_query(indexed_field.phrase, query_phrase): + matched_fields.add(indexed_field.field) + matched_terms.update(indexed_field.terms) + score += field_weight * 2.5 + elif indexed_field.phrase and indexed_field.phrase in query_terms: + matched_fields.add(indexed_field.field) + matched_terms.update(indexed_field.terms) + score += field_weight + + # Require either a phrase match, a trigger/signal match, or two meaningful + # definition/name overlaps. This keeps neutral prose from retrieving many rows. + strong_field = bool({"signals", "trigger_patterns", "synonyms"} & matched_fields) + if not strong_field and len(matched_terms) < 2: + return 0.0, [], [] + + if indexed.is_candidate_only: + score *= 0.9 + + return score, sorted(matched_terms), sorted(matched_fields) + + +def _phrase_in_query(phrase: str, query_phrase: str) -> bool: + return f" {phrase} " in f" {query_phrase} " + + +def _reason(matched_terms: list[str], matched_fields: list[str]) -> str: + fields = ", ".join(matched_fields) if matched_fields else "taxonomy fields" + terms = ", ".join(matched_terms[:6]) if matched_terms else "phrase" + return f"Matched {terms} in {fields}." + + +def _base_false_positive_risk(entry: TaxonomyEntry) -> str: + sensitivity = str(entry.false_positive_sensitivity or "medium") + if entry.requires_context or entry.requires_human_judgment or sensitivity == "high": + return "high" + if sensitivity == "low": + return "low" + return "medium" + + +def _pack_cache_key(pack: TaxonomyPack) -> tuple[Any, ...]: + return ( + pack.name, + pack.version, + len(pack.entries), + tuple((entry.id, entry.activation_status, entry.enabled_for_retrieval, entry.enabled_for_classification) for entry in pack.entries), + ) + + +_INDEX_CACHE: dict[tuple[Any, ...], InvertedIndex] = {} + + +def _build_index(pack: TaxonomyPack) -> InvertedIndex: + key = _pack_cache_key(pack) + cached = _INDEX_CACHE.get(key) + if cached is not None: + return cached + index = InvertedIndex(pack) + if len(_INDEX_CACHE) > 16: + _INDEX_CACHE.clear() + _INDEX_CACHE[key] = index + return index diff --git a/engine/argument_risk_engine/retrieval/retrieval_diagnostics.py b/engine/argument_risk_engine/retrieval/retrieval_diagnostics.py index eb3325f..2c8d113 100644 --- a/engine/argument_risk_engine/retrieval/retrieval_diagnostics.py +++ b/engine/argument_risk_engine/retrieval/retrieval_diagnostics.py @@ -1,2 +1,23 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class RetrievalDiagnostics: + query_terms: list[str] = field(default_factory=list) + considered_entry_count: int = 0 + raw_candidate_count: int = 0 + returned_candidate_count: int = 0 + suppressed_candidate_count: int = 0 + healthy_suppressor_count: int = 0 + ignored_terms: list[str] = field(default_factory=list) + notes: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def diagnostics(matches: list[object]) -> dict[str, int]: return {"candidate_count": len(matches)} diff --git a/tests/test_claim_extractor.py b/tests/test_claim_extractor.py index 693b0c2..5aef62c 100644 --- a/tests/test_claim_extractor.py +++ b/tests/test_claim_extractor.py @@ -1,5 +1,37 @@ +from argument_risk_engine.explanation.evidence import find_evidence_spans from argument_risk_engine.extraction.claim_extractor import extract_claims def test_extract_claims_splits_sentences(): assert len(extract_claims("First claim. Second claim.")) == 2 + + +def test_extract_claims_preserves_offsets_and_types(): + text = "Intro. Prices will likely rise because supply is low. Should we wait?" + claims = extract_claims(text) + + assert [claim.text for claim in claims] == [ + "Prices will likely rise because supply is low.", + "Should we wait?", + ] + assert claims[0].claim_type == "causal_claim" + assert claims[1].claim_type == "question_claim" + for claim in claims: + assert text[claim.start_char : claim.end_char] == claim.text + + +def test_extract_claims_ignores_short_fragments_without_markers(): + claims = extract_claims("Wow. Too vague. This report should be reviewed.") + + assert [claim.text for claim in claims] == ["This report should be reviewed."] + assert claims[0].claim_type == "normative_claim" + + +def test_evidence_spans_are_exact_substrings_only(): + text = "The survey found that 62 percent agreed." + + spans = find_evidence_spans(text, "survey found") + assert len(spans) == 1 + assert spans[0].text == "survey found" + assert text[spans[0].start_char : spans[0].end_char] == spans[0].text + assert find_evidence_spans(text, "survey invented") == [] diff --git a/tests/test_retriever.py b/tests/test_retriever.py index 372ddab..3a4d081 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -1,7 +1,69 @@ +from argument_risk_engine.retrieval.candidate_filter import final_classification_candidates from argument_risk_engine.retrieval.lexical_retriever import retrieve_candidates -from argument_risk_engine.taxonomy.models import default_taxonomy_pack +from argument_risk_engine.taxonomy.models import TaxonomyEntry, TaxonomyPack, default_taxonomy_pack + + +def _entry(**overrides): + data = { + "id": "risk", + "name": "Risk", + "signals": ["always"], + "enabled_for_retrieval": True, + "enabled_for_classification": True, + "activation_status": "active", + } + data.update(overrides) + return TaxonomyEntry(**data) def test_retriever_finds_keyword_candidate(): matches = retrieve_candidates("Everyone always does this.", default_taxonomy_pack()) assert matches[0].id == "overgeneralization" + + +def test_retriever_ignores_deprecated_and_limits_neutral_text(): + pack = TaxonomyPack( + entries=[ + _entry(id=f"neutral_{idx}", name=f"Neutral {idx}", signals=[f"rareterm{idx}"]) + for idx in range(1000) + ] + + [_entry(id="deprecated", name="Deprecated", signals=["uniquehit"], activation_status="deprecated")] + ) + + assert retrieve_candidates("This is a neutral project update with ordinary wording.", pack, limit=20) == [] + assert retrieve_candidates("uniquehit appears here.", pack) == [] + + +def test_candidate_only_entries_are_retrieved_but_not_final_candidates(): + pack = TaxonomyPack( + entries=[ + _entry(id="candidate_only", name="Candidate only", signals=["specialmarker"], enabled_for_classification=False), + _entry(id="final", name="Final", signals=["specialmarker"]), + ] + ) + + matches = retrieve_candidates("The text uses specialmarker.", pack, limit=10) + ids = {match.id for match in matches} + assert {"candidate_only", "final"} <= ids + assert [match.id for match in final_classification_candidates(matches)] == ["final"] + + +def test_healthy_reasoning_patterns_reduce_risky_matches(): + risky = _entry(id="overgeneralization", name="Overgeneralization", signals=["always"]) + healthy = _entry( + id="qualified_reasoning", + name="Qualified reasoning", + canonical_category="healthy_reasoning_pattern", + signals=["usually", "always"], + healthy_suppressor=True, + ) + pack_without_suppressor = TaxonomyPack(entries=[risky]) + pack_with_suppressor = TaxonomyPack(entries=[risky, healthy]) + + baseline = retrieve_candidates("People always do this.", pack_without_suppressor)[0] + reduced = retrieve_candidates("People usually always do this.", pack_with_suppressor)[0] + + assert reduced.id == "overgeneralization" + assert reduced.retrieval_score < baseline.retrieval_score + assert reduced.healthy_pattern_matches == ["qualified_reasoning"] + assert reduced.diagnostics["healthy_suppressor_count"] == 1