From 4382be8fdc18b4432b5be6d8306d3d2f61807575 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 3 Jun 2026 21:17:14 -0700 Subject: [PATCH 1/8] Phase 10: back response DTOs with pyrit.models, soft-deprecate old wire fields Make backend response DTOs (ScoreView/MessagePieceView/MessageView/ AttackSummary) inherit their pyrit.models counterparts so they expose every canonical field, adding presentation data via computed fields / mappers. Drop the RetryEventResponse DTO and the old deprecation shims. Soft-deprecate the renamed wire fields instead of hard-breaking external API clients: re-expose score_id/scored_at/piece_id/pieces as deprecated read-only computed-field aliases mirroring id/timestamp/message_pieces. Pydantic flags them deprecated in the OpenAPI schema; slated for removal in 0.17.0. Frontend + CLI updated/verified; add test_response_contracts.py wire-shape and deprecated-alias guards. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/components/Chat/ChatWindow.test.tsx | 26 +- frontend/src/services/api.test.ts | 4 +- frontend/src/types/index.ts | 8 +- frontend/src/utils/messageMapper.test.ts | 140 +++---- frontend/src/utils/messageMapper.ts | 6 +- pyrit/backend/mappers/__init__.py | 2 - pyrit/backend/mappers/attack_mappers.py | 301 ++++----------- pyrit/backend/models/__init__.py | 12 +- pyrit/backend/models/_media.py | 84 ++++ pyrit/backend/models/attacks.py | 359 +++++++++++++----- tests/unit/backend/test_api_routes.py | 66 ++-- tests/unit/backend/test_attack_service.py | 49 ++- tests/unit/backend/test_mappers.py | 315 +++++++-------- tests/unit/backend/test_response_contracts.py | 227 +++++++++++ 14 files changed, 954 insertions(+), 645 deletions(-) create mode 100644 pyrit/backend/models/_media.py create mode 100644 tests/unit/backend/test_response_contracts.py diff --git a/frontend/src/components/Chat/ChatWindow.test.tsx b/frontend/src/components/Chat/ChatWindow.test.tsx index 357b15e832..4608f1cd16 100644 --- 
a/frontend/src/components/Chat/ChatWindow.test.tsx +++ b/frontend/src/components/Chat/ChatWindow.test.tsx @@ -76,9 +76,9 @@ function makeTextResponse(text: string) { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-resp", + id: "p-resp", original_value_data_type: "text", converted_value_data_type: "text", original_value: text, @@ -101,9 +101,9 @@ function makeImageResponse() { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-img", + id: "p-img", original_value_data_type: "text", converted_value_data_type: "image_path", original_value: "generated image", @@ -127,9 +127,9 @@ function makeAudioResponse() { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-aud", + id: "p-aud", original_value_data_type: "text", converted_value_data_type: "audio_path", original_value: "spoken text", @@ -153,9 +153,9 @@ function makeVideoResponse() { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-vid", + id: "p-vid", original_value_data_type: "text", converted_value_data_type: "video_path", original_value: "generated video", @@ -179,9 +179,9 @@ function makeMultiModalResponse() { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-text", + id: "p-text", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Here is the result:", @@ -190,7 +190,7 @@ function makeMultiModalResponse() { response_error: "none", }, { - piece_id: "p-img2", + id: "p-img2", original_value_data_type: "text", converted_value_data_type: "image_path", original_value: "image content", @@ -214,9 +214,9 @@ function makeErrorResponse(errorType: string, description: string) { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-err", + id: "p-err", original_value_data_type: "text", converted_value_data_type: "text", original_value: "", diff --git 
a/frontend/src/services/api.test.ts b/frontend/src/services/api.test.ts index 7b1e6f645a..a2297c25a3 100644 --- a/frontend/src/services/api.test.ts +++ b/frontend/src/services/api.test.ts @@ -274,9 +274,9 @@ describe("api service", () => { { turn_number: 1, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", converted_value: "Hello", converted_value_data_type: "text", }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index f3834dca67..9b9d30b414 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -172,17 +172,17 @@ export interface CreateAttackResponse { // --- Messages --- export interface BackendScore { - score_id: string + id: string scorer_type: string score_type: string score_value: string score_category?: string[] | null score_rationale?: string | null - scored_at: string + timestamp: string } export interface BackendMessagePiece { - piece_id: string + id: string original_value_data_type: string converted_value_data_type: string original_value?: string | null @@ -200,7 +200,7 @@ export interface BackendMessagePiece { export interface BackendMessage { turn_number: number role: string - pieces: BackendMessagePiece[] + message_pieces: BackendMessagePiece[] created_at: string } diff --git a/frontend/src/utils/messageMapper.test.ts b/frontend/src/utils/messageMapper.test.ts index 25ece07f1d..6e6121b675 100644 --- a/frontend/src/utils/messageMapper.test.ts +++ b/frontend/src/utils/messageMapper.test.ts @@ -99,9 +99,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Hello there", @@ -125,9 +125,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", 
converted_value_data_type: "image_path", original_value: "generate an image", @@ -153,9 +153,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "audio_path", original_value: "speak this", @@ -179,9 +179,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "video_path", original_value: "generate video", @@ -205,9 +205,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "vid-1", + id: "vid-1", original_value_data_type: "text", converted_value_data_type: "video_path", original_value: "generate video", @@ -231,9 +231,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "binary_path", original_value: "convert this", @@ -260,9 +260,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "", @@ -285,9 +285,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Here is the image:", @@ -295,7 +295,7 @@ describe("messageMapper", () => { response_error: "none", }, { - piece_id: "p2", + id: "p2", original_value_data_type: "text", converted_value_data_type: "image_path", 
converted_value: "aW1hZ2U=", @@ -318,9 +318,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 0, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "test", @@ -338,9 +338,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 0, role: "system", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "You are helpful", @@ -358,9 +358,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 0, role: "simulated_assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "injected", @@ -378,9 +378,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "image_path", converted_value: "/api/media?path=output%2Fimg.png", @@ -407,9 +407,9 @@ describe("messageMapper", () => { return { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1abcdef", + id: "p1abcdef", original_value_data_type: "text", converted_value_data_type, original_value: "prompt", @@ -485,9 +485,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1abcdef", + id: "p1abcdef", original_value_data_type: "image_path", converted_value_data_type: "image_path", original_value: url, @@ -519,9 +519,9 @@ describe("messageMapper", () => { { turn_number: 0, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Hello", @@ 
-534,9 +534,9 @@ describe("messageMapper", () => { { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p2", + id: "p2", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Hi there!", @@ -783,9 +783,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: JSON.stringify({ @@ -796,7 +796,7 @@ describe("messageMapper", () => { response_error: "none", }, { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Here is the answer.", @@ -817,9 +817,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: JSON.stringify({ @@ -833,7 +833,7 @@ describe("messageMapper", () => { response_error: "none", }, { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Answer.", @@ -853,9 +853,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: JSON.stringify({ @@ -866,7 +866,7 @@ describe("messageMapper", () => { response_error: "none", }, { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", converted_value: "Just text.", @@ -889,9 +889,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: 
"reasoning", converted_value_data_type: "reasoning", converted_value: JSON.stringify({ @@ -916,9 +916,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: "plain text reasoning", @@ -941,9 +941,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Tell me a joke", @@ -964,9 +964,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Hello", @@ -987,9 +987,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Hello", @@ -1009,9 +1009,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "user", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Hello World", @@ -1032,9 +1032,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", converted_value_data_type: "text", original_value: null, @@ -1055,9 +1055,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "text", 
converted_value_data_type: "text", converted_value: "Some response", @@ -1077,9 +1077,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "image_path", converted_value_data_type: "image_path", original_value: "originalImageData", @@ -1102,9 +1102,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p1", + id: "p1", original_value_data_type: "image_path", converted_value_data_type: "image_path", original_value: "sameData", @@ -1129,9 +1129,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-audio", + id: "p-audio", original_value_data_type: "text", converted_value_data_type: "audio_path", converted_value: "audioBase64Data", @@ -1151,9 +1151,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "p-video", + id: "p-video", original_value_data_type: "text", converted_value_data_type: "video_path", converted_value: "videoBase64Data", @@ -1175,9 +1175,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: JSON.stringify({ @@ -1200,9 +1200,9 @@ describe("messageMapper", () => { const msg: BackendMessage = { turn_number: 1, role: "assistant", - pieces: [ + message_pieces: [ { - piece_id: "r1", + id: "r1", original_value_data_type: "reasoning", converted_value_data_type: "reasoning", converted_value: "", diff --git a/frontend/src/utils/messageMapper.ts b/frontend/src/utils/messageMapper.ts index 703aca0b4a..804d0525b1 100644 --- 
a/frontend/src/utils/messageMapper.ts +++ b/frontend/src/utils/messageMapper.ts @@ -142,7 +142,7 @@ function pieceToAttachment( const url = isBase64 ? buildDataUri(value, mime) : value const prefix = isOriginal ? 'original_' : '' const filename = isOriginal ? piece.original_filename : piece.converted_filename - const fallbackName = `${prefix}${dataType}_${piece.piece_id.slice(0, 8)}` + const fallbackName = `${prefix}${dataType}_${piece.id.slice(0, 8)}` // For base64-inlined media, derive the decoded byte count. For path / URL // values the string length is meaningless (e.g. /api/media?path=... is a @@ -155,7 +155,7 @@ function pieceToAttachment( url, mimeType: mime, size, - pieceId: piece.piece_id, + pieceId: piece.id, metadata: piece.prompt_metadata || undefined, } } @@ -184,7 +184,7 @@ export function backendMessageToFrontend(msg: BackendMessage): Message { const reasoningSummaries: string[] = [] let error: MessageError | undefined - for (const piece of msg.pieces) { + for (const piece of msg.message_pieces) { // Check for errors const pieceError = pieceToError(piece) if (pieceError && !error) { diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py index 310b04e916..8b1892e4ad 100644 --- a/pyrit/backend/mappers/__init__.py +++ b/pyrit/backend/mappers/__init__.py @@ -12,7 +12,6 @@ from pyrit.backend.mappers.attack_mappers import ( attack_result_to_summary, pyrit_messages_to_dto_async, - pyrit_scores_to_dto, request_piece_to_pyrit_message_piece, request_to_pyrit_message, ) @@ -28,7 +27,6 @@ "converter_object_to_instance", "format_last_message_preview", "pyrit_messages_to_dto_async", - "pyrit_scores_to_dto", "request_piece_to_pyrit_message_piece", "request_to_pyrit_message", "target_object_to_instance", diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 5807c27bef..dd1a4d9f87 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -12,7 +12,6 
@@ from __future__ import annotations import logging -import mimetypes import time import uuid from datetime import datetime, timedelta, timezone @@ -28,24 +27,20 @@ from pyrit.backend.models.attacks import ( AddMessageRequest, AttackSummary, - Message, - MessagePiece, MessagePieceRequest, - RetryEventResponse, - Score, - TargetInfo, + MessagePieceView, + MessageView, + ScoreView, ) from pyrit.common.deprecation import print_deprecation_message from pyrit.models import MEDIA_PATH_DATA_TYPES, AttackResult, ChatMessageRole, PromptDataType from pyrit.models import Message as PyritMessage from pyrit.models import MessagePiece as PyritMessagePiece -from pyrit.models import Score as PyritScore logger = logging.getLogger(__name__) if TYPE_CHECKING: from pyrit.models.conversation_stats import ConversationStats - from pyrit.models.retry_event import RetryEvent # ============================================================================ # Domain → DTO (for API responses) @@ -181,273 +176,131 @@ def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str] return value -def retry_events_to_response(retry_events: list[RetryEvent] | None) -> list[RetryEventResponse] | None: - """ - Convert a list of RetryEvent domain objects to RetryEventResponse DTOs. - - Args: - retry_events: Domain retry events, or None. - - Returns: - List of RetryEventResponse DTOs, or None if the input is None or empty. 
- """ - if not retry_events: - return None - return [ - RetryEventResponse( - timestamp=evt.timestamp, - attempt_number=evt.attempt_number, - function_name=evt.function_name, - exception_type=evt.exception_type, - exception_message=evt.exception_message, - component_role=evt.component_role, - component_name=evt.component_name, - endpoint=evt.endpoint, - elapsed_seconds=evt.elapsed_seconds, - ) - for evt in retry_events - ] - - def attack_result_to_summary( ar: AttackResult, *, stats: ConversationStats, ) -> AttackSummary: """ - Build an AttackSummary DTO from an AttackResult. + Build an AttackSummary view from an AttackResult. + + Conversation-level stats (message count, preview, labels, timestamps) are + injected here; every other field is inherited from the AttackResult. The + summary's ``last_response`` media is resolved to a ``/api/media`` URL but not + SAS-signed — Azure Blob signing only happens on the async ``/messages`` path. Args: ar: The domain AttackResult. stats: Pre-aggregated conversation stats (from ``get_conversation_stats``). Returns: - AttackSummary DTO ready for the API response. + AttackSummary view ready for the API response. """ - message_count = stats.message_count - last_preview = format_last_message_preview( - value=stats.last_message_preview, - data_type=stats.last_message_data_type, - ) - - # Merge attack-result labels with conversation-level labels. - # Conversation labels take precedence on key collision. + # Merge attack-result labels with conversation-level labels; conversation + # labels take precedence on key collision. labels = dict(ar.labels) if ar.labels else {} labels.update(stats.labels or {}) - # Resolution order for created_at: explicit metadata override, then the - # persisted AttackResult.timestamp, and finally datetime.now() as a - # last-resort fallback for never-persisted results. 
- created_str = ar.metadata.get("created_at") - updated_str = ar.metadata.get("updated_at") - if created_str: - created_at = datetime.fromisoformat(created_str) - elif ar.timestamp is not None: - created_at = ar.timestamp - else: - created_at = datetime.now(timezone.utc) - updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at - - aid = ar.get_attack_strategy_identifier() - - # Extract only frontend-relevant fields from ComponentIdentifier - target_id = aid.get_child("objective_target") if aid else None - converter_ids = aid.get_child_list("request_converters") if aid else [] - - target_info = ( - TargetInfo( - target_type=target_id.class_name, - endpoint=target_id.params.get("endpoint") or None, - model_name=target_id.params.get("model_name") or None, - ) - if target_id - else None - ) - - # Build retry event responses if available - retry_event_responses = retry_events_to_response(ar.retry_events) - - return AttackSummary( - attack_result_id=ar.attack_result_id, - conversation_id=ar.conversation_id, - attack_type=aid.class_name if aid else "Unknown", - attack_specific_params=(aid.params or None) if aid else None, - target=target_info, - converters=[c.class_name for c in converter_ids] if converter_ids else [], - outcome=ar.outcome.value, - last_message_preview=last_preview, - message_count=message_count, - related_conversation_ids=[ref.conversation_id for ref in ar.related_conversations], + created_at, updated_at = _resolve_summary_timestamps(ar) + return AttackSummary.from_domain( + ar, + last_response=_summary_last_response(ar.last_response), + last_score=ScoreView.from_domain(ar.last_score) if ar.last_score else None, + message_count=stats.message_count, + last_message_preview=format_last_message_preview( + value=stats.last_message_preview, + data_type=stats.last_message_data_type, + ), labels=labels, created_at=created_at, updated_at=updated_at, - error_message=ar.error_message, - error_type=ar.error_type, - 
error_traceback=ar.error_traceback, - total_retries=ar.total_retries, - retry_events=retry_event_responses, ) -def pyrit_scores_to_dto(scores: list[PyritScore]) -> list[Score]: +def _resolve_summary_timestamps(ar: AttackResult) -> tuple[datetime, datetime]: """ - Translate PyRIT score objects to backend Score DTOs. + Resolve ``created_at`` / ``updated_at`` for a summary. + + Resolution order for ``created_at``: explicit metadata override, then the + persisted ``AttackResult.timestamp``, and finally ``datetime.now`` as a + last-resort fallback for never-persisted results. Returns: - List of Score DTOs for the API. + A ``(created_at, updated_at)`` tuple. """ - return [ - Score( - score_id=str(score.id), - scorer_type=( - score.scorer_class_identifier.class_name or "Unknown" if score.scorer_class_identifier else "Unknown" - ), - score_type=score.score_type, - score_value=score.score_value, - score_category=score.score_category, - score_rationale=score.score_rationale, - scored_at=score.timestamp, - ) - for score in scores - ] + created_str = ar.metadata.get("created_at") + updated_str = ar.metadata.get("updated_at") + if created_str: + created_at = datetime.fromisoformat(created_str) + elif ar.timestamp is not None: + created_at = ar.timestamp + else: + created_at = datetime.now(timezone.utc) + updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at + return created_at, updated_at -def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Optional[str]: +def _summary_last_response(piece: Optional[PyritMessagePiece]) -> Optional[MessagePieceView]: """ - Infer MIME type from a value and its data type. - - For non-text data types, attempts to guess the MIME type from the value - treated as a file path (using the file extension). Returns ``None`` for - text content or when the type cannot be determined. - - Args: - value: The value (typically a file path for media content). 
- data_type: The prompt data type (e.g., 'text', 'image', 'audio'). + Build a ``MessagePieceView`` for a summary's last response (sync media resolution, no SAS). Returns: - MIME type string (e.g., 'image/png') or None. + A ``MessagePieceView`` for the piece, or ``None`` when no piece is given. """ - if not value or data_type == "text": + if piece is None: return None - mime_type, _ = mimetypes.guess_type(value) - return mime_type + return MessagePieceView.from_domain( + piece, + original_value=_resolve_media_url( + value=piece.original_value, data_type=piece.original_value_data_type or "text" + ), + converted_value=_resolve_media_url( + value=piece.converted_value or "", data_type=piece.converted_value_data_type or "text" + ) + or "", + ) -def _build_filename( - *, - data_type: str, - sha256: Optional[str], - value: Optional[str], -) -> Optional[str]: +async def _resolve_and_sign_media_async(*, value: Optional[str], data_type: str) -> Optional[str]: """ - Build a human-readable download filename from the data type and hash. - - Produces names like ``image_a1b2c3d4e5f6.png`` or ``audio_e5f6g7h8i9j0.wav``. - The hash is truncated to 12 characters for readability. - - Falls back to the file extension from *value* (path or URL) when the - MIME type cannot be determined from the data type alone. - - Returns ``None`` for text-like types that don't need a download filename. - - Args: - data_type: The prompt data type (e.g. ``image_path``, ``audio_path``). - sha256: The SHA256 hash of the content, if available. - value: The original value (path or URL) used to infer file extension. + Resolve a media value to a fetchable URL, signing Azure Blob URLs when present. Returns: - Optional[str]: A filename like ``image_a1b2c3d4e5f6.png``, or ``None`` for text-like types. + The resolved (and signed, if a blob) URL, or ``None`` when *value* is empty. 
""" - # Map data types to friendly prefixes - prefix_map = { - "image_path": "image", - "audio_path": "audio", - "video_path": "video", - "binary_path": "file", - } - prefix = prefix_map.get(data_type) - if not prefix: - return None - - short_hash = sha256[:12] if sha256 else uuid.uuid4().hex[:12] + resolved = _resolve_media_url(value=value, data_type=data_type) + if resolved and _is_azure_blob_url(resolved): + return await _sign_blob_url_async(blob_url=resolved) + return resolved - # Derive extension from the value (file path or URL) - ext = "" - if value and not value.startswith("data:"): - source = value - if source.startswith("http"): - source = urlparse(source).path - ext = Path(source).suffix # e.g. ".png" - if not ext: - # Fallback: guess from mime type based on data type prefix - default_ext = {"image": ".png", "audio": ".wav", "video": ".mp4", "file": ".bin"} - ext = default_ext.get(prefix, ".bin") - - return f"{prefix}_{short_hash}{ext}" - - -async def pyrit_messages_to_dto_async(pyrit_messages: list[PyritMessage]) -> list[Message]: +async def pyrit_messages_to_dto_async(pyrit_messages: list[PyritMessage]) -> list[MessageView]: """ - Translate PyRIT messages to backend Message DTOs. + Translate PyRIT messages to backend MessageView responses. Media file paths are converted to URLs the frontend can fetch directly: - - Local files → ``/api/media?path=...`` (served by the media endpoint) - - Azure Blob Storage files → signed URLs with SAS tokens + - Local files -> ``/api/media?path=...`` (served by the media endpoint) + - Azure Blob Storage files -> signed URLs with SAS tokens Returns: - List of Message DTOs for the API. + List of MessageView responses for the API. 
""" - messages = [] + messages: list[MessageView] = [] for msg in pyrit_messages: - pieces = [] + pieces: list[MessagePieceView] = [] for p in msg.message_pieces: - orig_dtype = p.original_value_data_type or "text" - conv_dtype = p.converted_value_data_type or "text" - - orig_val = _resolve_media_url(value=p.original_value, data_type=orig_dtype) - conv_val = _resolve_media_url(value=p.converted_value or "", data_type=conv_dtype) or "" - - # Sign Azure Blob Storage URLs so the frontend can fetch them directly - if orig_val and _is_azure_blob_url(orig_val): - orig_val = await _sign_blob_url_async(blob_url=orig_val) - if conv_val and _is_azure_blob_url(conv_val): - conv_val = await _sign_blob_url_async(blob_url=conv_val) - - pieces.append( - MessagePiece( - piece_id=str(p.id), - original_value_data_type=orig_dtype, - converted_value_data_type=conv_dtype, - original_value=orig_val, - original_value_mime_type=_infer_mime_type(value=p.original_value, data_type=orig_dtype), - converted_value=conv_val, - converted_value_mime_type=_infer_mime_type(value=p.converted_value, data_type=conv_dtype), - prompt_metadata=dict(p.prompt_metadata) if p.prompt_metadata else None, - scores=pyrit_scores_to_dto(p.scores) if p.scores else [], - response_error=p.response_error or "none", - original_filename=_build_filename( - data_type=orig_dtype, - sha256=p.original_value_sha256, - value=p.original_value, - ), - converted_filename=_build_filename( - data_type=conv_dtype, - sha256=p.converted_value_sha256, - value=p.converted_value, - ), + original_value = await _resolve_and_sign_media_async( + value=p.original_value, data_type=p.original_value_data_type or "text" + ) + converted_value = ( + await _resolve_and_sign_media_async( + value=p.converted_value or "", data_type=p.converted_value_data_type or "text" ) + or "" ) - - first = msg.message_pieces[0] if msg.message_pieces else None - messages.append( - Message( - turn_number=first.sequence if first else 0, - role=first.role if first else 
"user", - pieces=pieces, - created_at=first.timestamp if first else datetime.now(timezone.utc), + pieces.append( + MessagePieceView.from_domain(p, original_value=original_value, converted_value=converted_value) ) - ) - + messages.append(MessageView.from_domain(pieces=pieces)) return messages diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 388076fcd5..3ab8571505 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -21,11 +21,11 @@ CreateAttackResponse, CreateConversationRequest, CreateConversationResponse, - Message, - MessagePiece, MessagePieceRequest, + MessagePieceView, + MessageView, PrependedMessageRequest, - Score, + ScoreView, TargetInfo, UpdateAttackRequest, UpdateMainConversationRequest, @@ -82,11 +82,11 @@ "CreateAttackResponse", "CreateConversationRequest", "CreateConversationResponse", - "Message", - "MessagePiece", "MessagePieceRequest", + "MessagePieceView", + "MessageView", "PrependedMessageRequest", - "Score", + "ScoreView", "TargetInfo", "UpdateAttackRequest", # Common diff --git a/pyrit/backend/models/_media.py b/pyrit/backend/models/_media.py new file mode 100644 index 0000000000..0dd5a5eeb6 --- /dev/null +++ b/pyrit/backend/models/_media.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Pure media-presentation helpers shared by the attack response models. + +These derive download filenames and MIME types from a message piece's stored +value. They live here (rather than in the mapper) so the response models can +import them without pulling in the mapper's Azure / I/O dependencies, avoiding a +``models`` ↔ ``mappers`` import cycle. 
+""" + +from __future__ import annotations + +import mimetypes +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Optional +from urllib.parse import urlparse + +if TYPE_CHECKING: + from pyrit.models import PromptDataType + +# Friendly download-filename prefixes per media data type. +_FILENAME_PREFIXES = { + "image_path": "image", + "audio_path": "audio", + "video_path": "video", + "binary_path": "file", +} + +# Fallback extension per prefix when the value carries no usable suffix. +_DEFAULT_EXTENSIONS = {"image": ".png", "audio": ".wav", "video": ".mp4", "file": ".bin"} + + +def infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Optional[str]: + """ + Infer a MIME type from a value and its data type. + + Args: + value: The value (typically a file path for media content). + data_type: The prompt data type (e.g., ``text``, ``image_path``). + + Returns: + A MIME type string (e.g., ``image/png``), or ``None`` for text content or + when the type cannot be determined. + """ + if not value or data_type == "text": + return None + mime_type, _ = mimetypes.guess_type(value) + return mime_type + + +def build_filename(*, data_type: str, sha256: Optional[str], value: Optional[str]) -> Optional[str]: + """ + Build a human-readable download filename from the data type and hash. + + Produces names like ``image_a1b2c3d4e5f6.png``. The hash is truncated to 12 + characters for readability and falls back to the file extension from *value* + when one is available. + + Args: + data_type: The prompt data type (e.g. ``image_path``). + sha256: The SHA256 hash of the content, if available. + value: The original value (path or URL) used to infer the file extension. + + Returns: + A filename like ``image_a1b2c3d4e5f6.png``, or ``None`` for text-like types. 
+ """ + prefix = _FILENAME_PREFIXES.get(data_type) + if not prefix: + return None + + short_hash = sha256[:12] if sha256 else uuid.uuid4().hex[:12] + + ext = "" + if value and not value.startswith("data:"): + source = urlparse(value).path if value.startswith("http") else value + ext = Path(source).suffix + + if not ext: + ext = _DEFAULT_EXTENSIONS.get(prefix, ".bin") + + return f"{prefix}_{short_hash}{ext}" diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 2f98f78b7e..6dea5b57e4 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -8,143 +8,294 @@ This is the attack-centric API design where every user interaction targets a model. """ -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field, field_serializer +from pyrit.backend.models._media import build_filename, infer_mime_type from pyrit.backend.models.common import PaginationInfo -from pyrit.models import ChatMessageRole, PromptResponseError +from pyrit.models import ( + AttackResult, + ChatMessageRole, + ConversationReference, + Message, + MessagePiece, + Score, +) -class Score(BaseModel): - """A score associated with a message piece.""" +class TargetInfo(BaseModel): + """Target information extracted from the stored attack-strategy identifier.""" - score_id: str = Field(..., description="Unique score identifier") - scorer_type: str = Field(..., description="Type of scorer (e.g., 'bias', 'toxicity')") - score_type: str = Field(..., description="Score type: 'true_false', 'float_scale', or 'unknown'") - score_value: str = Field( - ..., description="Score value ('true'/'false' for true_false, '0.0'-'1.0' for float_scale)" - ) - score_category: Optional[list[str]] = Field(None, description="Harm categories (e.g., ['hate', 'violence'])") - score_rationale: Optional[str] = Field(None, 
description="Explanation for the score") - scored_at: datetime = Field(..., description="When the score was generated") + target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") + endpoint: Optional[str] = Field(None, description="Target endpoint URL") + model_name: Optional[str] = Field(None, description="Model or deployment name") -class MessagePiece(BaseModel): +class ScoreView(Score): """ - A piece of a message (text, image, audio, etc.). + API view of a ``pyrit.models.Score``. - Supports multimodal content with original/converted values and embedded scores. - Media content is base64-encoded since frontend can't access server file paths. + Exposes every canonical score field and adds a flattened ``scorer_type`` so + clients don't have to dig into ``scorer_class_identifier``. """ - piece_id: str = Field(..., description="Unique piece identifier") - original_value_data_type: str = Field( - default="text", description="Data type of the original value: 'text', 'image', 'audio', etc." - ) - converted_value_data_type: str = Field( - default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc." 
- ) - original_value: Optional[str] = Field(default=None, description="Original value before conversion") - original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") - converted_value: str = Field(..., description="Converted value (text or base64 for media)") - converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") - scores: list[Score] = Field(default_factory=list, description="Scores embedded in this piece") - response_error: PromptResponseError = Field( - default="none", description="Error status: none, processing, blocked, empty, unknown" + @computed_field # type: ignore[prop-decorator] + @property + def scorer_type(self) -> str: + """Return the scorer class name, or ``"Unknown"`` when unavailable.""" + identifier = self.scorer_class_identifier + if identifier and identifier.class_name: + return identifier.class_name + return "Unknown" + + @computed_field(deprecated="Use 'id' instead; 'score_id' is removed in 0.17.0.") # type: ignore[prop-decorator] + @property + def score_id(self) -> str: + """Deprecated alias for ``id``.""" + return str(self.id) + + @computed_field( # type: ignore[prop-decorator] + deprecated="Use 'timestamp' instead; 'scored_at' is removed in 0.17.0." ) + @property + def scored_at(self) -> Optional[datetime]: + """Deprecated alias for ``timestamp``.""" + return self.timestamp + + @classmethod + def from_domain(cls, score: Score) -> "ScoreView": + """ + Build a ``ScoreView`` from a domain ``Score`` without re-validating. + + Uses ``model_construct`` to bypass the domain validators (the score is + already valid) and copies fields by reference to preserve UUIDs, + datetimes, and identifier objects. + + Returns: + A ``ScoreView`` mirroring the domain score's fields. 
+ """ + return cls.model_construct(**{name: getattr(score, name) for name in Score.model_fields}) + + +class MessagePieceView(MessagePiece): + """ + API view of a ``pyrit.models.MessagePiece``. + + ``original_value`` / ``converted_value`` carry frontend-fetchable URLs for + media pieces (the raw on-disk path is never exposed); text pieces keep their + literal values. MIME types and download filenames are derived from the raw + values at map time. + """ + + scores: list[ScoreView] = Field(default_factory=list) + original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of the original value") + converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of the converted value") + original_filename: Optional[str] = Field(default=None, description="Download filename for the original value") + converted_filename: Optional[str] = Field(default=None, description="Download filename for the converted value") response_error_description: Optional[str] = Field( default=None, description="Description of the error if response_error is not 'none'" ) - original_filename: Optional[str] = Field( - default=None, description="Original filename extracted from file path or blob URL" - ) - converted_filename: Optional[str] = Field( - default=None, description="Converted filename extracted from file path or blob URL" - ) - prompt_metadata: Optional[dict[str, Any]] = Field( - default=None, description="Metadata associated with the piece (e.g., video_id for remix mode)" - ) + @computed_field(deprecated="Use 'id' instead; 'piece_id' is removed in 0.17.0.") # type: ignore[prop-decorator] + @property + def piece_id(self) -> str: + """Deprecated alias for ``id``.""" + return str(self.id) + + @classmethod + def from_domain( + cls, + piece: MessagePiece, + *, + original_value: Optional[str], + converted_value: str, + ) -> "MessagePieceView": + """ + Build a ``MessagePieceView`` from a domain piece without re-validating. 
+ + Args: + piece: The domain message piece. + original_value: Resolved/fetchable original value (a URL for media). + converted_value: Resolved/fetchable converted value (a URL for media). + + Returns: + A ``MessagePieceView`` with derived MIME types, filenames, and views. + """ + data = {name: getattr(piece, name) for name in MessagePiece.model_fields} + orig_dtype = piece.original_value_data_type or "text" + conv_dtype = piece.converted_value_data_type or "text" + data.update( + original_value=original_value, + converted_value=converted_value, + scores=[ScoreView.from_domain(score) for score in piece.scores], + original_value_mime_type=infer_mime_type(value=piece.original_value, data_type=orig_dtype), + converted_value_mime_type=infer_mime_type(value=piece.converted_value, data_type=conv_dtype), + original_filename=build_filename( + data_type=orig_dtype, sha256=piece.original_value_sha256, value=piece.original_value + ), + converted_filename=build_filename( + data_type=conv_dtype, sha256=piece.converted_value_sha256, value=piece.converted_value + ), + ) + return cls.model_construct(**data) + + +class MessageView(Message): + """ + API view of a ``pyrit.models.Message``. -class Message(BaseModel): - """A message within a conversation.""" + Adds turn-level metadata (``turn_number``, ``role``, ``created_at``) derived + from the first piece, and narrows ``message_pieces`` to ``MessagePieceView``. 
+ """ - turn_number: int = Field(..., description="Turn number in the conversation (1-indexed)") - role: ChatMessageRole = Field(..., description="Message role") - pieces: list[MessagePiece] = Field(..., description="Message pieces (multimodal support)") - created_at: datetime = Field(..., description="Message creation timestamp") + message_pieces: list[MessagePieceView] = Field(default_factory=list) + @computed_field # type: ignore[prop-decorator] + @property + def turn_number(self) -> int: + """Return the sequence of the first piece (the conversation turn).""" + return self.message_pieces[0].sequence if self.message_pieces else 0 -# ============================================================================ -# Attack Summary (List View) -# ============================================================================ + @computed_field # type: ignore[prop-decorator] + @property + def role(self) -> ChatMessageRole: + """Return the role of the first piece.""" + return self.message_pieces[0].role if self.message_pieces else "user" + @computed_field # type: ignore[prop-decorator] + @property + def created_at(self) -> datetime: + """Return the timestamp of the first piece.""" + return self.message_pieces[0].timestamp if self.message_pieces else datetime.now(timezone.utc) -class TargetInfo(BaseModel): - """Target information extracted from the stored TargetIdentifier.""" + @computed_field( # type: ignore[prop-decorator] + deprecated="Use 'message_pieces' instead; 'pieces' is removed in 0.17.0." 
+ ) + @property + def pieces(self) -> list[MessagePieceView]: + """Deprecated alias for ``message_pieces``.""" + return self.message_pieces - target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = Field(None, description="Model or deployment name") + @classmethod + def from_domain(cls, *, pieces: list[MessagePieceView]) -> "MessageView": + """ + Build a ``MessageView`` from already-mapped piece views. + Returns: + A ``MessageView`` wrapping the provided piece views. + """ + return cls.model_construct(message_pieces=pieces) -class RetryEventResponse(BaseModel): - """A single retry attempt captured during execution.""" - timestamp: datetime = Field(..., description="When the retry occurred") - attempt_number: int = Field(..., ge=1, description="Tenacity attempt number (1-based)") - function_name: str = Field(..., description="The retried function name") - exception_type: str = Field("", description="Exception class name") - exception_message: str = Field("", description="Exception message") - component_role: str = Field("", description="Component role from ExecutionContext") - component_name: str | None = Field(None, description="Component class name") - endpoint: str | None = Field(None, description="Target endpoint URL") - elapsed_seconds: float = Field(0.0, ge=0, description="Time since first attempt in seconds") +class AttackSummary(AttackResult): + """ + API view of a ``pyrit.models.AttackResult``. + Inherits every canonical attack-result field (including ``last_response``, + ``last_score`` and ``retry_events``) and adds presentation data: computed + projections of the strategy identifier plus mapper-populated conversation + stats. ``last_response`` / ``last_score`` are narrowed to their view types so + their presentation fields serialize. 
+ """ -class AttackSummary(BaseModel): - """Summary view of an attack (for list views, omits full message content).""" + last_response: Optional[MessagePieceView] = None + last_score: Optional[ScoreView] = None - attack_result_id: str = Field(..., description="Database-assigned unique ID for this AttackResult") - conversation_id: str = Field(..., description="Primary conversation of this attack result") - attack_type: str = Field("", description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") - attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") - target: Optional[TargetInfo] = Field(None, description="Target information from the stored identifier") - converters: list[str] = Field( - default_factory=list, description="Request converter class names applied in this attack" - ) - objective: str = Field("", description="Natural-language description of the attacker's objective") - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Field( - None, description="Attack outcome (null if not yet determined)" + # Mapper-populated presentation fields (need external stats / metadata). 
+ message_count: int = Field(default=0, description="Total number of messages in the attack") + last_message_preview: Optional[str] = Field(default=None, description="Preview of the last message") + created_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), description="Attack creation timestamp" ) - outcome_reason: str | None = Field(None, description="Reason for the outcome") - last_response: str | None = Field(None, description="Model response from the final turn") - last_message_preview: Optional[str] = Field( - None, description="Preview of the last message (truncated to ~100 chars)" + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), description="Last update timestamp" ) - score_value: str | None = Field(None, description="Score value from the objective scorer") - executed_turns: int = Field(0, ge=0, description="Number of turns executed") - execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") - message_count: int = Field(0, description="Total number of messages in the attack") - related_conversation_ids: list[str] = Field( - default_factory=list, description="IDs of related conversations within this attack" - ) - labels: dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") - created_at: datetime = Field(..., description="Attack creation timestamp") - updated_at: datetime = Field(..., description="Last update timestamp") - # Error information - error_message: str | None = Field(None, description="Error message if the attack failed with an exception") - error_type: str | None = Field(None, description="Exception class name (e.g., 'RateLimitError')") - error_traceback: str | None = Field(None, description="Formatted traceback string") - - # Retry information - total_retries: int = Field(0, ge=0, description="Total number of retries during this attack") - retry_events: list[RetryEventResponse] | None = Field( - None, 
description="Detailed retry events (omitted in list views unless requested)" - ) + @field_serializer("related_conversations") + def _serialize_related_conversations(self, conversations: set[ConversationReference]) -> list[Any]: + """ + Serialize related conversations in a stable (sorted) order for deterministic output. + + Returns: + A list of serialized conversation references ordered by ``conversation_id``. + """ + ordered = sorted(conversations, key=lambda ref: ref.conversation_id) + return [ref.model_dump() for ref in ordered] + + @computed_field # type: ignore[prop-decorator] + @property + def attack_type(self) -> str: + """Return the attack strategy class name, or ``"Unknown"``.""" + identifier = self.get_attack_strategy_identifier() + return identifier.class_name if identifier else "Unknown" + + @computed_field # type: ignore[prop-decorator] + @property + def attack_specific_params(self) -> Optional[dict[str, Any]]: + """Return the attack strategy params, or ``None``.""" + identifier = self.get_attack_strategy_identifier() + return (identifier.params or None) if identifier else None + + @computed_field # type: ignore[prop-decorator] + @property + def target(self) -> Optional[TargetInfo]: + """Return the objective target info extracted from the identifier.""" + identifier = self.get_attack_strategy_identifier() + target_id = identifier.get_child("objective_target") if identifier else None + if not target_id: + return None + return TargetInfo( + target_type=target_id.class_name, + endpoint=target_id.params.get("endpoint") or None, + model_name=target_id.params.get("model_name") or None, + ) + + @computed_field # type: ignore[prop-decorator] + @property + def converters(self) -> list[str]: + """Return the request-converter class names applied in this attack.""" + identifier = self.get_attack_strategy_identifier() + converter_ids = identifier.get_child_list("request_converters") if identifier else [] + return [c.class_name for c in converter_ids] + + 
@computed_field # type: ignore[prop-decorator] + @property + def related_conversation_ids(self) -> list[str]: + """Return the IDs of related conversations, sorted for stable output.""" + return sorted(ref.conversation_id for ref in self.related_conversations) + + @classmethod + def from_domain( + cls, + attack_result: AttackResult, + *, + last_response: Optional[MessagePieceView], + last_score: Optional[ScoreView], + message_count: int, + last_message_preview: Optional[str], + labels: dict[str, str], + created_at: datetime, + updated_at: datetime, + ) -> "AttackSummary": + """ + Build an ``AttackSummary`` from a domain ``AttackResult`` and mapper-derived stats. + + Returns: + An ``AttackSummary`` combining the attack result with presentation stats. + """ + data = {name: getattr(attack_result, name) for name in AttackResult.model_fields} + data.update( + last_response=last_response, + last_score=last_score, + labels=labels, + message_count=message_count, + last_message_preview=last_message_preview, + created_at=created_at, + updated_at=updated_at, + ) + return cls.model_construct(**data) # ============================================================================ @@ -156,7 +307,7 @@ class ConversationMessagesResponse(BaseModel): """Response containing all messages for a conversation.""" conversation_id: str = Field(..., description="Conversation identifier") - messages: list[Message] = Field(default_factory=list, description="All messages in order") + messages: list[MessageView] = Field(default_factory=list, description="All messages in order") # ============================================================================ diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 59bf407382..a0ae81ca5b 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -22,8 +22,8 @@ AttackSummary, ConversationMessagesResponse, CreateAttackResponse, - Message, - MessagePiece, + MessagePieceView, + 
MessageView, ) from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.converters import ( @@ -40,6 +40,24 @@ TargetListResponse, ) from pyrit.backend.routes.labels import get_label_options +from pyrit.models import MessagePiece as PyritMessagePiece + + +def _make_message_view(*, role: str = "user", value: str = "hello", sequence: int = 1) -> MessageView: + """Build a ``MessageView`` from a single text piece for route tests.""" + piece = PyritMessagePiece( + role=role, + original_value=value, + converted_value=value, + original_value_data_type="text", + converted_value_data_type="text", + conversation_id="attack-1", + sequence=sequence, + ) + piece_view = MessagePieceView.from_domain( + piece, original_value=piece.original_value, converted_value=piece.converted_value + ) + return MessageView.from_domain(pieces=[piece_view]) @pytest.fixture @@ -210,8 +228,7 @@ def test_get_attack_success(self, client: TestClient) -> None: return_value=AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="TestAttack", - outcome=None, + objective="test objective", last_message_preview=None, message_count=0, created_at=now, @@ -247,7 +264,7 @@ def test_update_attack_success(self, client: TestClient) -> None: return_value=AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="TestAttack", + objective="test objective", outcome="success", last_message_preview=None, message_count=0, @@ -273,8 +290,7 @@ def test_add_message_success(self, client: TestClient) -> None: attack_summary = AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="TestAttack", - outcome=None, + objective="test objective", last_message_preview=None, message_count=2, created_at=now, @@ -284,28 +300,8 @@ def test_add_message_success(self, client: TestClient) -> None: attack_messages = ConversationMessagesResponse( conversation_id="attack-1", messages=[ - Message( - turn_number=1, - 
role="user", - pieces=[ - MessagePiece( - piece_id="piece-1", - converted_value="Hello", - ) - ], - created_at=now, - ), - Message( - turn_number=2, - role="assistant", - pieces=[ - MessagePiece( - piece_id="piece-2", - converted_value="Hi there!", - ) - ], - created_at=now, - ), + _make_message_view(role="user", value="Hello", sequence=1), + _make_message_view(role="assistant", value="Hi there!", sequence=2), ], ) @@ -400,20 +396,13 @@ def test_add_message_internal_error(self, client: TestClient) -> None: def test_get_conversation_messages_success(self, client: TestClient) -> None: """Test getting attack messages.""" - now = datetime.now(timezone.utc) - with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() mock_service.get_conversation_messages_async = AsyncMock( return_value=ConversationMessagesResponse( conversation_id="attack-1", messages=[ - Message( - turn_number=1, - role="user", - pieces=[MessagePiece(piece_id="p1", converted_value="Hello")], - created_at=now, - ) + _make_message_view(role="user", value="Hello", sequence=1), ], ) ) @@ -462,8 +451,7 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="TestAttack", - outcome=None, + objective="test objective", last_message_preview=None, message_count=0, labels={"env": "prod"}, diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index f02367f6f8..d4ea3b3dbe 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -26,7 +26,14 @@ AttackService, get_attack_service, ) -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, build_atomic_attack_identifier +from pyrit.models import ( + AttackOutcome, + AttackResult, + ComponentIdentifier, + Message, + MessagePiece, + build_atomic_attack_identifier, +) from pyrit.models.conversation_stats 
import ConversationStats @@ -1436,23 +1443,18 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] - # Create mock message with pieces - mock_piece = MagicMock() - mock_piece.id = "piece-1" - mock_piece.converted_value_data_type = "text" - mock_piece.original_value_data_type = "text" - mock_piece.original_value = "Hello" - mock_piece.converted_value = "Hello" - mock_piece.response_error = None - mock_piece.sequence = 0 - mock_piece.role = "user" - mock_piece.timestamp = datetime.now(timezone.utc) - mock_piece.scores = None - - mock_msg = MagicMock() - mock_msg.message_pieces = [mock_piece] + piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + original_value_data_type="text", + converted_value_data_type="text", + conversation_id="test-id", + sequence=0, + ) + msg = Message(message_pieces=[piece]) - mock_memory.get_conversation.return_value = [mock_msg] + mock_memory.get_conversation.return_value = [msg] result = await attack_service.get_conversation_messages_async( attack_result_id="test-id", conversation_id="test-id" @@ -1461,8 +1463,8 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic assert result is not None assert len(result.messages) == 1 assert result.messages[0].role == "user" - assert len(result.messages[0].pieces) == 1 - assert result.messages[0].pieces[0].original_value == "Hello" + assert len(result.messages[0].message_pieces) == 1 + assert result.messages[0].message_pieces[0].original_value == "Hello" # ============================================================================ @@ -1965,8 +1967,7 @@ async def test_stores_message_in_target_conversation(self, attack_service, mock_ mock_summary = AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="ManualAttack", - converters=[], + objective="test objective", 
message_count=1, labels={}, created_at=now, @@ -2243,8 +2244,7 @@ async def test_add_message_merges_converter_identifiers_without_duplicates(self, return_value=AttackSummary( attack_result_id="ar-attack-1", conversation_id="attack-1", - attack_type="ManualAttack", - converters=[], + objective="test objective", message_count=0, labels={}, created_at=datetime.now(timezone.utc), @@ -2322,8 +2322,7 @@ async def test_converter_merge_with_flat_atomic_identifier(self, attack_service, return_value=AttackSummary( attack_result_id="ar-flat-1", conversation_id="flat-1", - attack_type="ManualAttack", - converters=[], + objective="test objective", message_count=0, labels={}, created_at=datetime.now(timezone.utc), diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 2d6e4da37c..ff7d912c84 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -12,25 +12,28 @@ import tempfile import uuid from datetime import datetime, timezone +from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.backend.mappers.attack_mappers import ( - _build_filename, - _infer_mime_type, _is_azure_blob_url, _resolve_media_url, _sign_blob_url_async, attack_result_to_summary, pyrit_messages_to_dto_async, - pyrit_scores_to_dto, request_piece_to_pyrit_message_piece, request_to_pyrit_message, ) from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.mappers.target_mappers import target_object_to_instance +from pyrit.backend.models._media import build_filename, infer_mime_type +from pyrit.backend.models.attacks import ScoreView from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier +from pyrit.models import Message as PyritMessage +from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import Score as PyritScore from pyrit.models.conversation_stats import ConversationStats from pyrit.prompt_target import 
PromptTarget, TargetCapabilities @@ -83,42 +86,46 @@ def _make_attack_result( ) -def _make_mock_piece( +def _make_piece( *, sequence: int = 0, converted_value: str = "hello", original_value: str = "hello", -): - """Create a mock message piece for mapper tests.""" - p = MagicMock() - p.id = "piece-1" - p.sequence = sequence - p.converted_value = converted_value - p.original_value = original_value - p.converted_value_data_type = "text" - p.original_value_data_type = "text" - p.response_error = "none" - p.role = "user" - p.timestamp = datetime.now(timezone.utc) - p.scores = [] - return p - - -def _make_mock_score(): - """Create a mock score for mapper tests.""" - s = MagicMock() - s.id = "score-1" - s.scorer_class_identifier = ComponentIdentifier( - class_name="TrueFalseScorer", - class_module="pyrit.score", - params={"scorer_type": "true_false"}, + original_value_data_type: str = "text", + converted_value_data_type: str = "text", + role: str = "user", +) -> PyritMessagePiece: + """Create a real domain message piece for mapper tests.""" + return PyritMessagePiece( + role=role, + original_value=original_value, + converted_value=converted_value, + original_value_data_type=original_value_data_type, + converted_value_data_type=converted_value_data_type, + conversation_id="conv-1", + sequence=sequence, + ) + + +def _make_score( + *, + score_value: str = "1.0", + score_type: str = "float_scale", + score_category: Optional[list[str]] = None, + scorer_name: str = "TrueFalseScorer", +) -> PyritScore: + """Create a real domain score for mapper tests.""" + return PyritScore( + score_value=score_value, + score_type=score_type, + score_category=score_category, + score_rationale="Looks correct", + message_piece_id=str(uuid.uuid4()), + scorer_class_identifier=ComponentIdentifier( + class_name=scorer_name, + class_module="pyrit.score", + ), ) - s.score_value = "1.0" - s.score_type = "float_scale" - s.score_category = None - s.score_rationale = "Looks correct" - s.timestamp = 
datetime.now(timezone.utc) - return s # ============================================================================ @@ -388,7 +395,7 @@ def test_created_at_falls_back_to_now_when_both_absent(self) -> None: assert before <= summary.created_at <= after def test_retry_events_mapped_to_response(self) -> None: - """Test that retry events on an AttackResult are mapped to RetryEventResponse DTOs.""" + """Test that retry events on an AttackResult are inherited by the AttackSummary.""" from pyrit.models.retry_event import RetryEvent now = datetime.now(timezone.utc) @@ -421,36 +428,34 @@ def test_retry_events_mapped_to_response(self) -> None: assert evt.elapsed_seconds == 1.5 assert summary.total_retries == 1 - """Tests for pyrit_scores_to_dto function.""" + """Tests for retry-event passthrough on AttackSummary.""" def test_maps_scores(self) -> None: - """Test that scores are correctly translated.""" - mock_score = _make_mock_score() + """Test that a domain score is exposed as a ScoreView with a flattened scorer_type.""" + score = _make_score() - result = pyrit_scores_to_dto([mock_score]) + view = ScoreView.from_domain(score) - assert len(result) == 1 - assert result[0].score_id == "score-1" - assert result[0].scorer_type == "TrueFalseScorer" - assert result[0].score_value == "1.0" - assert result[0].score_type == "float_scale" - assert result[0].score_rationale == "Looks correct" + assert view.id == score.id + assert view.scorer_type == "TrueFalseScorer" + assert view.score_value == "1.0" + assert view.score_type == "float_scale" + assert view.score_rationale == "Looks correct" - def test_empty_scores(self) -> None: - """Test mapping empty scores list.""" - result = pyrit_scores_to_dto([]) - assert result == [] + def test_scorer_type_unknown_without_identifier(self) -> None: + """Test that scorer_type falls back to 'Unknown' when no identifier is set.""" + score = PyritScore(score_value="0.5", score_type="float_scale", message_piece_id=str(uuid.uuid4())) + + view = 
ScoreView.from_domain(score) + + assert view.scorer_type == "Unknown" def test_true_false_scores_are_included(self) -> None: - """Test that true_false score values are mapped correctly.""" - float_score = _make_mock_score() - bool_score = _make_mock_score() - bool_score.id = "score-bool" - bool_score.score_value = "false" - bool_score.score_type = "true_false" - bool_score.score_category = ["hate"] + """Test that true_false score values and categories are preserved.""" + float_score = _make_score() + bool_score = _make_score(score_value="false", score_type="true_false", score_category=["hate"]) - result = pyrit_scores_to_dto([float_score, bool_score]) + result = [ScoreView.from_domain(float_score), ScoreView.from_domain(bool_score)] assert len(result) == 2 assert result[0].score_value == "1.0" @@ -464,30 +469,31 @@ class TestPyritMessagesToDto: async def test_maps_single_message(self) -> None: """Test mapping a single message with one piece.""" - piece = _make_mock_piece(original_value="hi", converted_value="hi") - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece(original_value="hi", converted_value="hi") + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) assert len(result) == 1 assert result[0].role == "user" - assert len(result[0].pieces) == 1 - assert result[0].pieces[0].original_value == "hi" - assert result[0].pieces[0].converted_value == "hi" + assert len(result[0].message_pieces) == 1 + assert result[0].message_pieces[0].original_value == "hi" + assert result[0].message_pieces[0].converted_value == "hi" async def test_maps_data_types_separately(self) -> None: """Test that original and converted data types are mapped independently.""" - piece = _make_mock_piece(original_value="describe this", converted_value="base64data") - piece.original_value_data_type = "text" - piece.converted_value_data_type = "image" - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece( + 
original_value="describe this", + converted_value="base64data", + original_value_data_type="text", + converted_value_data_type="image_path", + ) + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value_data_type == "text" - assert result[0].pieces[0].converted_value_data_type == "image" + assert result[0].message_pieces[0].original_value_data_type == "text" + assert result[0].message_pieces[0].converted_value_data_type == "image_path" async def test_maps_empty_list(self) -> None: """Test mapping an empty messages list.""" @@ -496,41 +502,44 @@ async def test_maps_empty_list(self) -> None: async def test_populates_mime_type_for_image(self) -> None: """Test that MIME types are inferred for image pieces.""" - piece = _make_mock_piece(original_value="/path/to/photo.png", converted_value="/path/to/photo.jpg") - piece.original_value_data_type = "image" - piece.converted_value_data_type = "image" - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece( + original_value="/path/to/photo.png", + converted_value="/path/to/photo.jpg", + original_value_data_type="image_path", + converted_value_data_type="image_path", + ) + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value_mime_type == "image/png" - assert result[0].pieces[0].converted_value_mime_type == "image/jpeg" + assert result[0].message_pieces[0].original_value_mime_type == "image/png" + assert result[0].message_pieces[0].converted_value_mime_type == "image/jpeg" async def test_mime_type_none_for_text(self) -> None: """Test that MIME type is None for text pieces.""" - piece = _make_mock_piece(original_value="hello", converted_value="hello") - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece(original_value="hello", converted_value="hello") + msg = PyritMessage(message_pieces=[piece]) result = await 
pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value_mime_type is None - assert result[0].pieces[0].converted_value_mime_type is None + assert result[0].message_pieces[0].original_value_mime_type is None + assert result[0].message_pieces[0].converted_value_mime_type is None async def test_mime_type_for_audio(self) -> None: """Test that MIME types are inferred for audio pieces.""" - piece = _make_mock_piece(original_value="/tmp/speech.wav", converted_value="/tmp/speech.mp3") - piece.original_value_data_type = "audio" - piece.converted_value_data_type = "audio" - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece( + original_value="/tmp/speech.wav", + converted_value="/tmp/speech.mp3", + original_value_data_type="audio_path", + converted_value_data_type="audio_path", + ) + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) # Python 3.10 returns "audio/wav", 3.11+ returns "audio/x-wav" - assert result[0].pieces[0].original_value_mime_type in ("audio/wav", "audio/x-wav") - assert result[0].pieces[0].converted_value_mime_type == "audio/mpeg" + assert result[0].message_pieces[0].original_value_mime_type in ("audio/wav", "audio/x-wav") + assert result[0].message_pieces[0].converted_value_mime_type == "audio/mpeg" async def test_local_media_file_returns_media_url(self) -> None: """Test that local media files are converted to /api/media URLs.""" @@ -539,64 +548,63 @@ async def test_local_media_file_returns_media_url(self) -> None: tmp_path = tmp.name try: - piece = _make_mock_piece(original_value=tmp_path, converted_value=tmp_path) - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece( + original_value=tmp_path, + converted_value=tmp_path, + original_value_data_type="image_path", + converted_value_data_type="image_path", + ) + msg = PyritMessage(message_pieces=[piece]) 
result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value is not None - assert result[0].pieces[0].original_value.startswith("/api/media?path=") - assert result[0].pieces[0].converted_value.startswith("/api/media?path=") + assert result[0].message_pieces[0].original_value is not None + assert result[0].message_pieces[0].original_value.startswith("/api/media?path=") + assert result[0].message_pieces[0].converted_value.startswith("/api/media?path=") finally: os.unlink(tmp_path) async def test_data_uri_passthrough(self) -> None: """Test that pre-encoded data URIs are not re-encoded.""" - piece = _make_mock_piece( + piece = _make_piece( original_value="data:image/png;base64,AAAA", converted_value="data:image/jpeg;base64,BBBB", + original_value_data_type="image_path", + converted_value_data_type="image_path", ) - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = [piece] + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value == "data:image/png;base64,AAAA" - assert result[0].pieces[0].converted_value == "data:image/jpeg;base64,BBBB" + assert result[0].message_pieces[0].original_value == "data:image/png;base64,AAAA" + assert result[0].message_pieces[0].converted_value == "data:image/jpeg;base64,BBBB" async def test_non_blob_http_url_passthrough(self) -> None: """Test that non-Azure-Blob HTTP URLs are passed through as-is.""" - piece = _make_mock_piece( + piece = _make_piece( original_value="http://example.com/image.png", converted_value="http://example.com/image.png", + original_value_data_type="image_path", + converted_value_data_type="image_path", ) - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = [piece] + msg = PyritMessage(message_pieces=[piece]) result = await 
pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value == "http://example.com/image.png" - assert result[0].pieces[0].converted_value == "http://example.com/image.png" + assert result[0].message_pieces[0].original_value == "http://example.com/image.png" + assert result[0].message_pieces[0].converted_value == "http://example.com/image.png" async def test_azure_blob_url_is_signed(self) -> None: """Test that Azure Blob Storage URLs are signed with SAS tokens.""" blob_url = "https://myaccount.blob.core.windows.net/dbdata/prompt-memory-entries/images/test.png" signed_url = blob_url + "?sig=abc123" - piece = _make_mock_piece( + piece = _make_piece( original_value=blob_url, converted_value=blob_url, + original_value_data_type="image_path", + converted_value_data_type="image_path", ) - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = [piece] + msg = PyritMessage(message_pieces=[piece]) with patch( "pyrit.backend.mappers.attack_mappers._sign_blob_url_async", @@ -605,20 +613,19 @@ async def test_azure_blob_url_is_signed(self) -> None: ): result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value == signed_url - assert result[0].pieces[0].converted_value == signed_url + assert result[0].message_pieces[0].original_value == signed_url + assert result[0].message_pieces[0].converted_value == signed_url async def test_azure_blob_url_sign_failure_returns_raw_url(self) -> None: """Test that blob sign failure falls back to the raw blob URL.""" blob_url = "https://myaccount.blob.core.windows.net/dbdata/images/test.png" - piece = _make_mock_piece( + piece = _make_piece( original_value=blob_url, converted_value=blob_url, + original_value_data_type="image_path", + converted_value_data_type="image_path", ) - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = 
[piece] + msg = PyritMessage(message_pieces=[piece]) with patch( "pyrit.backend.mappers.attack_mappers._sign_blob_url_async", @@ -627,21 +634,23 @@ async def test_azure_blob_url_sign_failure_returns_raw_url(self) -> None: ): result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value == blob_url - assert result[0].pieces[0].converted_value == blob_url + assert result[0].message_pieces[0].original_value == blob_url + assert result[0].message_pieces[0].converted_value == blob_url async def test_nonexistent_media_file_returns_raw_path(self) -> None: """Test that non-existent local media files fall back to raw path values.""" - piece = _make_mock_piece(original_value="/tmp/nonexistent.png", converted_value="/tmp/nonexistent.png") - piece.original_value_data_type = "image_path" - piece.converted_value_data_type = "image_path" - msg = MagicMock() - msg.message_pieces = [piece] + piece = _make_piece( + original_value="/tmp/nonexistent.png", + converted_value="/tmp/nonexistent.png", + original_value_data_type="image_path", + converted_value_data_type="image_path", + ) + msg = PyritMessage(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) - assert result[0].pieces[0].original_value == "/tmp/nonexistent.png" - assert result[0].pieces[0].converted_value == "/tmp/nonexistent.png" + assert result[0].message_pieces[0].original_value == "/tmp/nonexistent.png" + assert result[0].message_pieces[0].converted_value == "/tmp/nonexistent.png" class TestIsAzureBlobUrl: @@ -1016,79 +1025,79 @@ def test_original_prompt_id_defaults_to_self_when_absent(self) -> None: class TestInferMimeType: - """Tests for _infer_mime_type helper function.""" + """Tests for infer_mime_type helper function.""" def test_returns_none_for_text(self) -> None: """Text data type should always return None.""" - assert _infer_mime_type(value="/path/to/file.png", data_type="text") is None + assert infer_mime_type(value="/path/to/file.png", data_type="text") is 
None def test_returns_none_for_empty_value(self) -> None: """Empty or None value should return None.""" - assert _infer_mime_type(value=None, data_type="image") is None - assert _infer_mime_type(value="", data_type="image") is None + assert infer_mime_type(value=None, data_type="image_path") is None + assert infer_mime_type(value="", data_type="image_path") is None def test_infers_png(self) -> None: """Test MIME type inference for PNG files.""" - assert _infer_mime_type(value="/tmp/photo.png", data_type="image") == "image/png" + assert infer_mime_type(value="/tmp/photo.png", data_type="image_path") == "image/png" def test_infers_jpeg(self) -> None: """Test MIME type inference for JPEG files.""" - assert _infer_mime_type(value="/tmp/photo.jpg", data_type="image") == "image/jpeg" + assert infer_mime_type(value="/tmp/photo.jpg", data_type="image_path") == "image/jpeg" def test_infers_wav(self) -> None: """Test MIME type inference for WAV files.""" - result = _infer_mime_type(value="/tmp/audio.wav", data_type="audio") + result = infer_mime_type(value="/tmp/audio.wav", data_type="audio_path") assert result is not None assert "wav" in result def test_infers_mp3(self) -> None: """Test MIME type inference for MP3 files.""" - assert _infer_mime_type(value="/tmp/audio.mp3", data_type="audio") == "audio/mpeg" + assert infer_mime_type(value="/tmp/audio.mp3", data_type="audio_path") == "audio/mpeg" def test_returns_none_for_unknown_extension(self) -> None: """Test that unrecognized extensions return None.""" - assert _infer_mime_type(value="/tmp/data.xyz123", data_type="image") is None + assert infer_mime_type(value="/tmp/data.xyz123", data_type="image_path") is None def test_infers_mp4(self) -> None: """Test MIME type inference for MP4 video files.""" - assert _infer_mime_type(value="/tmp/video.mp4", data_type="video") == "video/mp4" + assert infer_mime_type(value="/tmp/video.mp4", data_type="video_path") == "video/mp4" class TestBuildFilename: - """Tests for _build_filename 
helper function.""" + """Tests for build_filename helper function.""" def test_image_path_with_hash(self) -> None: - result = _build_filename(data_type="image_path", sha256="abcdef1234567890", value="/tmp/photo.png") + result = build_filename(data_type="image_path", sha256="abcdef1234567890", value="/tmp/photo.png") assert result == "image_abcdef123456.png" def test_audio_path_with_hash(self) -> None: - result = _build_filename(data_type="audio_path", sha256="1234abcd5678efgh", value="/tmp/speech.wav") + result = build_filename(data_type="audio_path", sha256="1234abcd5678efgh", value="/tmp/speech.wav") assert result == "audio_1234abcd5678.wav" def test_video_path_with_hash(self) -> None: - result = _build_filename(data_type="video_path", sha256="deadbeef00000000", value="/tmp/clip.mp4") + result = build_filename(data_type="video_path", sha256="deadbeef00000000", value="/tmp/clip.mp4") assert result == "video_deadbeef0000.mp4" def test_binary_path_with_hash(self) -> None: - result = _build_filename(data_type="binary_path", sha256="cafe0123babe4567", value="/tmp/doc.pdf") + result = build_filename(data_type="binary_path", sha256="cafe0123babe4567", value="/tmp/doc.pdf") assert result == "file_cafe0123babe.pdf" def test_returns_none_for_text(self) -> None: - assert _build_filename(data_type="text", sha256="abc123", value="hello") is None + assert build_filename(data_type="text", sha256="abc123", value="hello") is None def test_returns_none_for_reasoning(self) -> None: - assert _build_filename(data_type="reasoning", sha256="abc123", value="thinking") is None + assert build_filename(data_type="reasoning", sha256="abc123", value="thinking") is None def test_fallback_ext_when_no_value(self) -> None: - result = _build_filename(data_type="image_path", sha256="abcdef1234567890", value=None) + result = build_filename(data_type="image_path", sha256="abcdef1234567890", value=None) assert result == "image_abcdef123456.png" def test_fallback_ext_for_data_uri(self) -> None: - 
result = _build_filename(data_type="audio_path", sha256="abcdef1234567890", value="data:audio/wav;base64,AAA=") + result = build_filename(data_type="audio_path", sha256="abcdef1234567890", value="data:audio/wav;base64,AAA=") assert result == "audio_abcdef123456.wav" def test_random_hash_when_no_sha256(self) -> None: - result = _build_filename(data_type="image_path", sha256=None, value="/tmp/photo.png") + result = build_filename(data_type="image_path", sha256=None, value="/tmp/photo.png") assert result is not None assert result.startswith("image_") assert result.endswith(".png") @@ -1096,7 +1105,7 @@ def test_random_hash_when_no_sha256(self) -> None: def test_blob_url_extension(self) -> None: url = "https://account.blob.core.windows.net/container/images/photo.jpg" - result = _build_filename(data_type="image_path", sha256="abcdef1234567890", value=url) + result = build_filename(data_type="image_path", sha256="abcdef1234567890", value=url) assert result == "image_abcdef123456.jpg" diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py new file mode 100644 index 0000000000..45ca0643a9 --- /dev/null +++ b/tests/unit/backend/test_response_contracts.py @@ -0,0 +1,227 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +JSON/schema contract tests for the backend response views. + +These guard the serialized wire shape of the canonical-model-backed response +DTOs (``ScoreView``/``MessagePieceView``/``MessageView``/``AttackSummary``): +canonical fields plus presentation computed fields must appear in +``model_dump(mode="json")``, ``related_conversations`` must serialize in a +stable (sorted) order, and the deprecated wire aliases (``score_id``, +``scored_at``, ``piece_id``, ``pieces``) must stay populated for back-compat. 
+""" + +import uuid +from datetime import datetime, timezone + +from pyrit.backend.models.attacks import ( + AttackSummary, + MessagePieceView, + MessageView, + ScoreView, +) +from pyrit.models import ( + AttackResult, + ComponentIdentifier, + MessagePiece, + RetryEvent, + Score, + build_atomic_attack_identifier, +) +from pyrit.models.conversation_reference import ConversationReference, ConversationType + + +def _make_score() -> Score: + return Score( + score_value="0.5", + score_type="float_scale", + score_rationale="because", + message_piece_id=str(uuid.uuid4()), + scorer_class_identifier=ComponentIdentifier(class_name="FloatScaleScorer", class_module="pyrit.score"), + ) + + +def _make_piece(*, sequence: int = 0, role: str = "user") -> MessagePiece: + return MessagePiece( + role=role, + original_value="hello", + converted_value="hello", + original_value_data_type="text", + converted_value_data_type="text", + conversation_id="conv-1", + sequence=sequence, + ) + + +def _make_attack_result(*, name: str = "CrescendoAttack") -> AttackResult: + target = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + return AttackResult( + conversation_id="attack-1", + objective="test objective", + attack_result_id="ar-attack-1", + atomic_attack_identifier=build_atomic_attack_identifier( + attack_identifier=ComponentIdentifier( + class_name=name, + class_module="pyrit.attacks", + children={"objective_target": target}, + ), + ), + ) + + +class TestScoreViewContract: + """JSON contract for ScoreView.""" + + def test_dump_has_canonical_and_computed_fields(self) -> None: + """Test that the serialized score exposes canonical fields plus scorer_type.""" + view = ScoreView.from_domain(_make_score()) + dumped = view.model_dump(mode="json") + + assert dumped["score_value"] == "0.5" + assert dumped["score_type"] == "float_scale" + assert dumped["scorer_type"] == "FloatScaleScorer" + assert "scorer_class_identifier" in dumped + + def 
test_schema_builds(self) -> None: + """Test that ScoreView's serialization schema includes the computed field.""" + assert "scorer_type" in ScoreView.model_json_schema(mode="serialization")["properties"] + + +class TestMessagePieceViewContract: + """JSON contract for MessagePieceView.""" + + def test_dump_has_canonical_and_presentation_fields(self) -> None: + """Test that the serialized piece exposes canonical and derived presentation fields.""" + piece = _make_piece() + view = MessagePieceView.from_domain(piece, original_value="hello", converted_value="hello") + dumped = view.model_dump(mode="json") + + assert dumped["role"] == "user" + assert dumped["original_value"] == "hello" + assert "original_value_mime_type" in dumped + assert "converted_value_mime_type" in dumped + assert "original_filename" in dumped + assert "converted_filename" in dumped + assert "response_error_description" in dumped + assert dumped["scores"] == [] + + def test_scores_are_score_views(self) -> None: + """Test that nested scores serialize with the ScoreView computed field.""" + piece = _make_piece() + piece.scores = [_make_score()] + view = MessagePieceView.from_domain(piece, original_value="hello", converted_value="hello") + dumped = view.model_dump(mode="json") + + assert dumped["scores"][0]["scorer_type"] == "FloatScaleScorer" + + +class TestMessageViewContract: + """JSON contract for MessageView.""" + + def test_dump_has_turn_metadata_and_pieces(self) -> None: + """Test that the serialized message exposes turn metadata and piece views.""" + piece = MessagePieceView.from_domain( + _make_piece(sequence=3, role="assistant"), original_value="hello", converted_value="hello" + ) + view = MessageView.from_domain(pieces=[piece]) + dumped = view.model_dump(mode="json") + + assert dumped["turn_number"] == 3 + assert dumped["role"] == "assistant" + assert "created_at" in dumped + assert len(dumped["message_pieces"]) == 1 + assert dumped["message_pieces"][0]["role"] == "assistant" + + +class 
TestAttackSummaryContract: + """JSON contract for AttackSummary, including set-ordering (R1).""" + + def _summary(self, ar: AttackResult) -> AttackSummary: + now = datetime.now(timezone.utc) + return AttackSummary.from_domain( + ar, + last_response=None, + last_score=None, + message_count=2, + last_message_preview="hi", + labels={"env": "prod"}, + created_at=now, + updated_at=now, + ) + + def test_dump_has_canonical_computed_and_stats_fields(self) -> None: + """Test that the serialized summary exposes canonical, computed, and stats fields.""" + dumped = self._summary(_make_attack_result()).model_dump(mode="json") + + assert dumped["conversation_id"] == "attack-1" + assert dumped["objective"] == "test objective" + assert dumped["attack_type"] == "CrescendoAttack" + assert dumped["target"]["target_type"] == "OpenAIChatTarget" + assert dumped["converters"] == [] + assert dumped["message_count"] == 2 + assert dumped["last_message_preview"] == "hi" + assert dumped["labels"] == {"env": "prod"} + assert "retry_events" in dumped + + def test_related_conversations_serialize_sorted(self) -> None: + """Test that related_conversations serialize in a stable, sorted order (R1).""" + ar = _make_attack_result() + ar.related_conversations = { + ConversationReference(conversation_id="zeta", conversation_type=ConversationType.PRUNED), + ConversationReference(conversation_id="alpha", conversation_type=ConversationType.ADVERSARIAL), + ConversationReference(conversation_id="mid", conversation_type=ConversationType.PRUNED), + } + + dumped = self._summary(ar).model_dump(mode="json") + + ordered_ids = [ref["conversation_id"] for ref in dumped["related_conversations"]] + assert ordered_ids == ["alpha", "mid", "zeta"] + assert dumped["related_conversation_ids"] == ["alpha", "mid", "zeta"] + + def test_retry_events_round_trip(self) -> None: + """Test that inherited retry_events serialize with their canonical payload.""" + ar = _make_attack_result() + ar.retry_events = 
[RetryEvent(attempt_number=1, exception_type="RateLimitError")] + + dumped = self._summary(ar).model_dump(mode="json") + + assert dumped["retry_events"][0]["attempt_number"] == 1 + assert dumped["retry_events"][0]["exception_type"] == "RateLimitError" + + +class TestDeprecatedWireAliases: + """Old wire field names stay populated (as deprecated aliases) for backward compat.""" + + def test_score_view_emits_deprecated_aliases(self) -> None: + """Test that ScoreView still emits score_id/scored_at mirroring id/timestamp.""" + view = ScoreView.from_domain(_make_score()) + dumped = view.model_dump(mode="json") + + assert dumped["score_id"] == str(view.id) + assert dumped["scored_at"] == dumped["timestamp"] + + def test_message_piece_view_emits_deprecated_alias(self) -> None: + """Test that MessagePieceView still emits piece_id mirroring id.""" + view = MessagePieceView.from_domain(_make_piece(), original_value="hello", converted_value="hello") + dumped = view.model_dump(mode="json") + + assert dumped["piece_id"] == str(view.id) + + def test_message_view_emits_deprecated_alias(self) -> None: + """Test that MessageView still emits pieces mirroring message_pieces.""" + piece = MessagePieceView.from_domain(_make_piece(), original_value="hello", converted_value="hello") + dumped = MessageView.from_domain(pieces=[piece]).model_dump(mode="json") + + assert dumped["pieces"] == dumped["message_pieces"] + + def test_aliases_marked_deprecated_in_schema(self) -> None: + """Test that the deprecated aliases are flagged deprecated in the OpenAPI schema.""" + score_props = ScoreView.model_json_schema(mode="serialization")["properties"] + piece_props = MessagePieceView.model_json_schema(mode="serialization")["properties"] + message_props = MessageView.model_json_schema(mode="serialization")["properties"] + + assert score_props["score_id"]["deprecated"] is True + assert score_props["scored_at"]["deprecated"] is True + assert piece_props["piece_id"]["deprecated"] is True + assert 
message_props["pieces"]["deprecated"] is True From 9d6743c95d518eab3845fd503796f653ff42486e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 14:05:56 -0700 Subject: [PATCH 2/8] Phase 10: drop remaining 'as PyritX' alias dance in mappers and tests Following the cleanup of attacks.py, remove the legacy 'from pyrit.models import X as PyritX' aliasing in attack_mappers.py, test_mappers.py, and test_api_routes.py. There's no name collision in those files, so import Message/MessagePiece/Score plainly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/mappers/attack_mappers.py | 20 ++++++------- tests/unit/backend/test_api_routes.py | 4 +-- tests/unit/backend/test_mappers.py | 37 ++++++++++++------------- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index dd1a4d9f87..1e5bd28faf 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -33,9 +33,7 @@ ScoreView, ) from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import MEDIA_PATH_DATA_TYPES, AttackResult, ChatMessageRole, PromptDataType -from pyrit.models import Message as PyritMessage -from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import MEDIA_PATH_DATA_TYPES, AttackResult, ChatMessageRole, Message, MessagePiece, PromptDataType logger = logging.getLogger(__name__) @@ -239,7 +237,7 @@ def _resolve_summary_timestamps(ar: AttackResult) -> tuple[datetime, datetime]: return created_at, updated_at -def _summary_last_response(piece: Optional[PyritMessagePiece]) -> Optional[MessagePieceView]: +def _summary_last_response(piece: Optional[MessagePiece]) -> Optional[MessagePieceView]: """ Build a ``MessagePieceView`` for a summary's last response (sync media resolution, no SAS). 
@@ -273,7 +271,7 @@ async def _resolve_and_sign_media_async(*, value: Optional[str], data_type: str) return resolved -async def pyrit_messages_to_dto_async(pyrit_messages: list[PyritMessage]) -> list[MessageView]: +async def pyrit_messages_to_dto_async(pyrit_messages: list[Message]) -> list[MessageView]: """ Translate PyRIT messages to backend MessageView responses. @@ -316,7 +314,7 @@ def request_piece_to_pyrit_message_piece( conversation_id: str, sequence: int, labels: Optional[dict[str, str]] = None, # deprecated -) -> PyritMessagePiece: +) -> MessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -329,7 +327,7 @@ def request_piece_to_pyrit_message_piece( Deprecated: This parameter will be removed in a release 0.16.0. Returns: - PyritMessagePiece domain object. + MessagePiece domain object. """ if labels is not None: print_deprecation_message( @@ -343,7 +341,7 @@ def request_piece_to_pyrit_message_piece( elif piece.mime_type: metadata = {"mime_type": piece.mime_type} original_prompt_id = uuid.UUID(piece.original_prompt_id) if piece.original_prompt_id else None - return PyritMessagePiece( + return MessagePiece( role=role, original_value=piece.original_value, original_value_data_type=cast("PromptDataType", piece.data_type), @@ -363,7 +361,7 @@ def request_to_pyrit_message( conversation_id: str, sequence: int, labels: Optional[dict[str, str]] = None, # deprecated -) -> PyritMessage: +) -> Message: """ Build a PyRIT Message from an AddMessageRequest DTO. @@ -375,7 +373,7 @@ def request_to_pyrit_message( Deprecated: This parameter will be removed in a release 0.16.0. Returns: - PyritMessage ready to send to the target. + Message ready to send to the target. 
""" if labels is not None: print_deprecation_message( @@ -393,7 +391,7 @@ def request_to_pyrit_message( ) for p in request.pieces ] - return PyritMessage(message_pieces=pieces) + return Message(message_pieces=pieces) # ============================================================================ diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index a0ae81ca5b..5c52e4a007 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -40,12 +40,12 @@ TargetListResponse, ) from pyrit.backend.routes.labels import get_label_options -from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import MessagePiece def _make_message_view(*, role: str = "user", value: str = "hello", sequence: int = 1) -> MessageView: """Build a ``MessageView`` from a single text piece for route tests.""" - piece = PyritMessagePiece( + piece = MessagePiece( role=role, original_value=value, converted_value=value, diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index ff7d912c84..b40ce0314a 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -30,10 +30,7 @@ from pyrit.backend.mappers.target_mappers import target_object_to_instance from pyrit.backend.models._media import build_filename, infer_mime_type from pyrit.backend.models.attacks import ScoreView -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier -from pyrit.models import Message as PyritMessage -from pyrit.models import MessagePiece as PyritMessagePiece -from pyrit.models import Score as PyritScore +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, Message, MessagePiece, Score from pyrit.models.conversation_stats import ConversationStats from pyrit.prompt_target import PromptTarget, TargetCapabilities @@ -94,9 +91,9 @@ def _make_piece( original_value_data_type: str = "text", converted_value_data_type: str = "text", role: 
str = "user", -) -> PyritMessagePiece: +) -> MessagePiece: """Create a real domain message piece for mapper tests.""" - return PyritMessagePiece( + return MessagePiece( role=role, original_value=original_value, converted_value=converted_value, @@ -113,9 +110,9 @@ def _make_score( score_type: str = "float_scale", score_category: Optional[list[str]] = None, scorer_name: str = "TrueFalseScorer", -) -> PyritScore: +) -> Score: """Create a real domain score for mapper tests.""" - return PyritScore( + return Score( score_value=score_value, score_type=score_type, score_category=score_category, @@ -444,7 +441,7 @@ def test_maps_scores(self) -> None: def test_scorer_type_unknown_without_identifier(self) -> None: """Test that scorer_type falls back to 'Unknown' when no identifier is set.""" - score = PyritScore(score_value="0.5", score_type="float_scale", message_piece_id=str(uuid.uuid4())) + score = Score(score_value="0.5", score_type="float_scale", message_piece_id=str(uuid.uuid4())) view = ScoreView.from_domain(score) @@ -470,7 +467,7 @@ class TestPyritMessagesToDto: async def test_maps_single_message(self) -> None: """Test mapping a single message with one piece.""" piece = _make_piece(original_value="hi", converted_value="hi") - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -488,7 +485,7 @@ async def test_maps_data_types_separately(self) -> None: original_value_data_type="text", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -508,7 +505,7 @@ async def test_populates_mime_type_for_image(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -518,7 +515,7 @@ async def 
test_populates_mime_type_for_image(self) -> None: async def test_mime_type_none_for_text(self) -> None: """Test that MIME type is None for text pieces.""" piece = _make_piece(original_value="hello", converted_value="hello") - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -533,7 +530,7 @@ async def test_mime_type_for_audio(self) -> None: original_value_data_type="audio_path", converted_value_data_type="audio_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -554,7 +551,7 @@ async def test_local_media_file_returns_media_url(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -572,7 +569,7 @@ async def test_data_uri_passthrough(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -587,7 +584,7 @@ async def test_non_blob_http_url_passthrough(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) @@ -604,7 +601,7 @@ async def test_azure_blob_url_is_signed(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) with patch( "pyrit.backend.mappers.attack_mappers._sign_blob_url_async", @@ -625,7 +622,7 @@ async def test_azure_blob_url_sign_failure_returns_raw_url(self) -> None: original_value_data_type="image_path", 
converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) with patch( "pyrit.backend.mappers.attack_mappers._sign_blob_url_async", @@ -645,7 +642,7 @@ async def test_nonexistent_media_file_returns_raw_path(self) -> None: original_value_data_type="image_path", converted_value_data_type="image_path", ) - msg = PyritMessage(message_pieces=[piece]) + msg = Message(message_pieces=[piece]) result = await pyrit_messages_to_dto_async([msg]) From c852ee281edff78a4ddc92f183346347c4902480 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 15:29:34 -0700 Subject: [PATCH 3/8] Phase 10: split MessagePieceView storage refs from client URLs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously MessagePieceView.from_domain re-used the inherited original_value / converted_value fields to carry the client-fetchable URL (a /api/media path or a SAS-signed blob URL). Two different concepts shared one slot, and the override-style signature gave readers no signal as to whether a kwarg narrowed or transformed the inherited field. Make the two concepts independent: - original_value / converted_value (inherited) keep the raw stored value the database has — text, a file path, a blob URL, or a data URI. - original_value_url / converted_value_url are new optional presentation fields populated by the mapper for media pieces, and None for plain text. _resolve_media_url now returns None for non-media data types (it's a media resolver — text has no fetchable URL). The mapper gates URL population on the piece data type so text never carries a stale URL string. Frontend: BackendMessagePiece gains the two *_value_url fields. pieceToAttachment prefers *_value_url when present and falls back to the existing base64 / raw-value detection so older payload shapes keep working. A new test exercises the URL-precedence path.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- frontend/src/types/index.ts | 2 + frontend/src/utils/messageMapper.test.ts | 32 +++++++ frontend/src/utils/messageMapper.ts | 14 ++- pyrit/backend/mappers/attack_mappers.py | 59 ++++++++----- pyrit/backend/models/attacks.py | 87 ++++++++++--------- tests/unit/backend/test_api_routes.py | 6 +- tests/unit/backend/test_mappers.py | 84 +++++++++++++----- tests/unit/backend/test_response_contracts.py | 25 +++--- 8 files changed, 205 insertions(+), 104 deletions(-) diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 9b9d30b414..012fcf76e1 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -186,8 +186,10 @@ export interface BackendMessagePiece { original_value_data_type: string converted_value_data_type: string original_value?: string | null + original_value_url?: string | null original_value_mime_type?: string | null converted_value: string + converted_value_url?: string | null converted_value_mime_type?: string | null original_filename?: string | null converted_filename?: string | null diff --git a/frontend/src/utils/messageMapper.test.ts b/frontend/src/utils/messageMapper.test.ts index 6e6121b675..93d480ecba 100644 --- a/frontend/src/utils/messageMapper.test.ts +++ b/frontend/src/utils/messageMapper.test.ts @@ -149,6 +149,38 @@ describe("messageMapper", () => { ); }); + it("prefers converted_value_url over the raw converted_value for media", () => { + // New shape (Phase 10+): raw storage path stays in converted_value; + // the client-fetchable URL the mapper resolves goes to converted_value_url. 
+ const msg: BackendMessage = { + turn_number: 1, + role: "assistant", + message_pieces: [ + { + id: "p1", + original_value_data_type: "text", + converted_value_data_type: "image_path", + original_value: "generate an image", + converted_value: "C:\\dbdata\\prompt-memory-entries\\images\\image.png", + converted_value_url: "/api/media?path=C%3A%5Cdbdata%5Cimages%5Cimage.png", + converted_value_mime_type: "image/png", + scores: [], + response_error: "none", + }, + ], + created_at: "2026-02-15T00:00:00Z", + }; + + const result = backendMessageToFrontend(msg); + + expect(result.attachments).toHaveLength(1); + expect(result.attachments![0].url).toBe( + "/api/media?path=C%3A%5Cdbdata%5Cimages%5Cimage.png" + ); + // Path-style URLs don't have a known payload size to display. + expect(result.attachments![0].size).toBeUndefined(); + }); + it("should convert an audio response", () => { const msg: BackendMessage = { turn_number: 1, diff --git a/frontend/src/utils/messageMapper.ts b/frontend/src/utils/messageMapper.ts index 804d0525b1..1868ca4ad0 100644 --- a/frontend/src/utils/messageMapper.ts +++ b/frontend/src/utils/messageMapper.ts @@ -119,6 +119,13 @@ function decodedBase64ByteCount(value: string): number { * * When `source` is `'converted'` (the default), uses `converted_value*` fields. * When `source` is `'original'`, uses `original_value*` fields instead. + * + * The backend exposes media in two forms: the raw stored value + * (`*_value`, which may be a file path, blob URL, base64 string, etc.) and a + * client-fetchable URL (`*_value_url`, populated by the mapper as a + * `/api/media?path=...` link or a SAS-signed blob URL). When `*_value_url` is + * present we use it directly; otherwise we fall back to the legacy detection + * logic so older shaped payloads still render. */ function pieceToAttachment( piece: BackendMessagePiece, @@ -127,6 +134,7 @@ function pieceToAttachment( const isOriginal = source === 'original' const dataType = isOriginal ? 
piece.original_value_data_type : piece.converted_value_data_type const value = isOriginal ? piece.original_value : piece.converted_value + const valueUrl = isOriginal ? piece.original_value_url : piece.converted_value_url const mimeField = isOriginal ? piece.original_value_mime_type : piece.converted_value_mime_type if (!isMediaDataType(dataType) || !value) return null @@ -139,7 +147,9 @@ function pieceToAttachment( /^[a-z][a-z0-9+.-]*:/i.test(value) // URI scheme (file:, blob:, etc.) const isBase64 = !looksLikePathOrScheme && value.length >= 16 && /^[A-Za-z0-9+/=\n]+$/.test(value) - const url = isBase64 ? buildDataUri(value, mime) : value + // Prefer the mapper-resolved URL when present; fall back to existing logic + // (base64 inline data URI or raw value-as-URL) for compatibility. + const url = valueUrl || (isBase64 ? buildDataUri(value, mime) : value) const prefix = isOriginal ? 'original_' : '' const filename = isOriginal ? piece.original_filename : piece.converted_filename const fallbackName = `${prefix}${dataType}_${piece.id.slice(0, 8)}` @@ -147,7 +157,7 @@ function pieceToAttachment( // For base64-inlined media, derive the decoded byte count. For path / URL // values the string length is meaningless (e.g. /api/media?path=... is a // reference, not the payload), so size is omitted and the UI must hide it. - const size = isBase64 ? decodedBase64ByteCount(value) : undefined + const size = isBase64 && !valueUrl ? decodedBase64ByteCount(value) : undefined return { type: dataTypeToAttachmentType(dataType), diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 1e5bd28faf..0eefb8a514 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -150,21 +150,25 @@ async def _sign_blob_url_async(*, blob_url: str) -> str: def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str]: """ - For media path types, convert a local file path to a ``/api/media`` URL. 
+ Resolve a media value to a client-fetchable URL. - Non-media types and Azure Blob URLs are returned as-is (blob URLs are - signed later in ``pyrit_messages_to_dto_async``). + Returns ``None`` for non-media data types or empty values — there's no URL + to expose for plain text. For media values: + + - Local file paths -> ``/api/media?path=...`` + - data URIs and http(s) URLs -> passed through as-is (blob URLs are + signed later in ``pyrit_messages_to_dto_async``) + - Anything else (e.g. nonexistent paths) -> passed through unchanged Args: value: The stored value (file path, blob URL, data URI, or text). data_type: The prompt data type (e.g. ``image_path``, ``text``). Returns: - The value unchanged for non-media types, a ``/api/media?path=...`` - URL for local file paths, or the original value for blob URLs / data URIs. + A client-fetchable URL for media, or ``None`` for text / empty values. """ if not value or data_type not in MEDIA_PATH_DATA_TYPES: - return value + return None # Already a URL or data URI — pass through if value.startswith(("http://", "https://", "data:")): return value @@ -199,19 +203,27 @@ def attack_result_to_summary( labels = dict(ar.labels) if ar.labels else {} labels.update(stats.labels or {}) created_at, updated_at = _resolve_summary_timestamps(ar) - return AttackSummary.from_domain( - ar, + + # Start with every canonical AttackResult field, then overlay view-narrowed + # values (last_response/last_score) and presentation extras (message_count, + # previews, timestamps, merged labels). ``model_construct`` skips validation + # because the source AttackResult is already valid. + data = {name: getattr(ar, name) for name in AttackResult.model_fields} + data.update( + # Overlays — narrow domain types to view types. last_response=_summary_last_response(ar.last_response), last_score=ScoreView.from_domain(ar.last_score) if ar.last_score else None, + labels=labels, + # Presentation extras — not on AttackResult. 
message_count=stats.message_count, last_message_preview=format_last_message_preview( value=stats.last_message_preview, data_type=stats.last_message_data_type, ), - labels=labels, created_at=created_at, updated_at=updated_at, ) + return AttackSummary.model_construct(**data) def _resolve_summary_timestamps(ar: AttackResult) -> tuple[datetime, datetime]: @@ -248,13 +260,12 @@ def _summary_last_response(piece: Optional[MessagePiece]) -> Optional[MessagePie return None return MessagePieceView.from_domain( piece, - original_value=_resolve_media_url( + original_value_url=_resolve_media_url( value=piece.original_value, data_type=piece.original_value_data_type or "text" ), - converted_value=_resolve_media_url( - value=piece.converted_value or "", data_type=piece.converted_value_data_type or "text" - ) - or "", + converted_value_url=_resolve_media_url( + value=piece.converted_value, data_type=piece.converted_value_data_type or "text" + ), ) @@ -275,7 +286,10 @@ async def pyrit_messages_to_dto_async(pyrit_messages: list[Message]) -> list[Mes """ Translate PyRIT messages to backend MessageView responses. - Media file paths are converted to URLs the frontend can fetch directly: + The raw stored ``original_value`` / ``converted_value`` are passed through + unchanged. 
Media file paths are additionally resolved into client-fetchable + URLs and exposed via ``original_value_url`` / ``converted_value_url``: + - Local files -> ``/api/media?path=...`` (served by the media endpoint) - Azure Blob Storage files -> signed URLs with SAS tokens @@ -286,19 +300,18 @@ async def pyrit_messages_to_dto_async(pyrit_messages: list[Message]) -> list[Mes for msg in pyrit_messages: pieces: list[MessagePieceView] = [] for p in msg.message_pieces: - original_value = await _resolve_and_sign_media_async( + original_value_url = await _resolve_and_sign_media_async( value=p.original_value, data_type=p.original_value_data_type or "text" ) - converted_value = ( - await _resolve_and_sign_media_async( - value=p.converted_value or "", data_type=p.converted_value_data_type or "text" - ) - or "" + converted_value_url = await _resolve_and_sign_media_async( + value=p.converted_value, data_type=p.converted_value_data_type or "text" ) pieces.append( - MessagePieceView.from_domain(p, original_value=original_value, converted_value=converted_value) + MessagePieceView.from_domain( + p, original_value_url=original_value_url, converted_value_url=converted_value_url + ) ) - messages.append(MessageView.from_domain(pieces=pieces)) + messages.append(MessageView.from_domain(message_pieces=pieces)) return messages diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 6dea5b57e4..cc5f1aa699 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -83,13 +83,35 @@ class MessagePieceView(MessagePiece): """ API view of a ``pyrit.models.MessagePiece``. - ``original_value`` / ``converted_value`` carry frontend-fetchable URLs for - media pieces (the raw on-disk path is never exposed); text pieces keep their - literal values. MIME types and download filenames are derived from the raw - values at map time. 
+ Inherits the canonical piece fields unchanged: ``original_value`` / + ``converted_value`` carry the raw stored content the server holds (text, a + local file path, a blob URL, or a data URI — whatever the database has). + + Adds presentation-only fields the client needs: + + - ``original_value_url`` / ``converted_value_url`` — client-fetchable URLs + populated by the mapper for media pieces (``/api/media?path=...`` for + local files; SAS-signed URLs for Azure Blob; pass-through for data URIs + and existing http(s) URLs). ``None`` for plain text and empty values. + - MIME types, download filenames and the response-error description are + derived from the raw values at map time. """ scores: list[ScoreView] = Field(default_factory=list) + original_value_url: Optional[str] = Field( + default=None, + description=( + "Client-fetchable URL for the original media value (e.g. " + "/api/media?path=... or a SAS-signed blob URL). None for text pieces." + ), + ) + converted_value_url: Optional[str] = Field( + default=None, + description=( + "Client-fetchable URL for the converted media value (e.g. " + "/api/media?path=... or a SAS-signed blob URL). None for text pieces." + ), + ) original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of the original value") converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of the converted value") original_filename: Optional[str] = Field(default=None, description="Download filename for the original value") @@ -109,16 +131,23 @@ def from_domain( cls, piece: MessagePiece, *, - original_value: Optional[str], - converted_value: str, + original_value_url: Optional[str] = None, + converted_value_url: Optional[str] = None, ) -> "MessagePieceView": """ Build a ``MessagePieceView`` from a domain piece without re-validating. + The canonical piece fields (``original_value``, ``converted_value``, + sha256s, role, ids, etc.) are copied through unchanged. 
The optional URL + kwargs are purely additive — they populate the presentation-only + ``*_value_url`` fields the client uses to fetch media. + Args: piece: The domain message piece. - original_value: Resolved/fetchable original value (a URL for media). - converted_value: Resolved/fetchable converted value (a URL for media). + original_value_url: Client-fetchable URL for ``piece.original_value`` + when it's media; ``None`` for text. + converted_value_url: Client-fetchable URL for ``piece.converted_value`` + when it's media; ``None`` for text. Returns: A ``MessagePieceView`` with derived MIME types, filenames, and views. @@ -127,9 +156,9 @@ def from_domain( orig_dtype = piece.original_value_data_type or "text" conv_dtype = piece.converted_value_data_type or "text" data.update( - original_value=original_value, - converted_value=converted_value, scores=[ScoreView.from_domain(score) for score in piece.scores], + original_value_url=original_value_url, + converted_value_url=converted_value_url, original_value_mime_type=infer_mime_type(value=piece.original_value, data_type=orig_dtype), converted_value_mime_type=infer_mime_type(value=piece.converted_value, data_type=conv_dtype), original_filename=build_filename( @@ -179,14 +208,14 @@ def pieces(self) -> list[MessagePieceView]: return self.message_pieces @classmethod - def from_domain(cls, *, pieces: list[MessagePieceView]) -> "MessageView": + def from_domain(cls, *, message_pieces: list[MessagePieceView]) -> "MessageView": """ Build a ``MessageView`` from already-mapped piece views. Returns: A ``MessageView`` wrapping the provided piece views. 
""" - return cls.model_construct(message_pieces=pieces) + return cls.model_construct(message_pieces=message_pieces) class AttackSummary(AttackResult): @@ -266,36 +295,12 @@ def related_conversation_ids(self) -> list[str]: """Return the IDs of related conversations, sorted for stable output.""" return sorted(ref.conversation_id for ref in self.related_conversations) - @classmethod - def from_domain( - cls, - attack_result: AttackResult, - *, - last_response: Optional[MessagePieceView], - last_score: Optional[ScoreView], - message_count: int, - last_message_preview: Optional[str], - labels: dict[str, str], - created_at: datetime, - updated_at: datetime, - ) -> "AttackSummary": - """ - Build an ``AttackSummary`` from a domain ``AttackResult`` and mapper-derived stats. - Returns: - An ``AttackSummary`` combining the attack result with presentation stats. - """ - data = {name: getattr(attack_result, name) for name in AttackResult.model_fields} - data.update( - last_response=last_response, - last_score=last_score, - labels=labels, - message_count=message_count, - last_message_preview=last_message_preview, - created_at=created_at, - updated_at=updated_at, - ) - return cls.model_construct(**data) +# Note: no ``from_domain`` classmethod here. The mapper assembles ``AttackSummary`` +# directly with ``model_construct`` because the construction overlays view-narrowed +# values (``last_response``, ``last_score``, merged ``labels``) on top of the +# canonical ``AttackResult`` fields — a smell that a ``from_domain`` signature +# couldn't express without competing parameters. 
# ============================================================================ diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 5c52e4a007..c5c3f686aa 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -54,10 +54,8 @@ def _make_message_view(*, role: str = "user", value: str = "hello", sequence: in conversation_id="attack-1", sequence=sequence, ) - piece_view = MessagePieceView.from_domain( - piece, original_value=piece.original_value, converted_value=piece.converted_value - ) - return MessageView.from_domain(pieces=[piece_view]) + piece_view = MessagePieceView.from_domain(piece) + return MessageView.from_domain(message_pieces=[piece_view]) @pytest.fixture diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index b40ce0314a..257f52d725 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -539,7 +539,7 @@ async def test_mime_type_for_audio(self) -> None: assert result[0].message_pieces[0].converted_value_mime_type == "audio/mpeg" async def test_local_media_file_returns_media_url(self) -> None: - """Test that local media files are converted to /api/media URLs.""" + """Local media files surface a /api/media URL via *_value_url; raw value stays unchanged.""" with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: tmp.write(b"PNGDATA") tmp_path = tmp.name @@ -555,14 +555,20 @@ async def test_local_media_file_returns_media_url(self) -> None: result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value is not None - assert result[0].message_pieces[0].original_value.startswith("/api/media?path=") - assert result[0].message_pieces[0].converted_value.startswith("/api/media?path=") + view = result[0].message_pieces[0] + # Raw stored value (inherited from MessagePiece) — unchanged + assert view.original_value == tmp_path + assert view.converted_value == 
tmp_path + # Client-fetchable URL — populated by the mapper + assert view.original_value_url is not None + assert view.original_value_url.startswith("/api/media?path=") + assert view.converted_value_url is not None + assert view.converted_value_url.startswith("/api/media?path=") finally: os.unlink(tmp_path) async def test_data_uri_passthrough(self) -> None: - """Test that pre-encoded data URIs are not re-encoded.""" + """Pre-encoded data URIs surface as both the raw value and the URL.""" piece = _make_piece( original_value="data:image/png;base64,AAAA", converted_value="data:image/jpeg;base64,BBBB", @@ -573,11 +579,14 @@ async def test_data_uri_passthrough(self) -> None: result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value == "data:image/png;base64,AAAA" - assert result[0].message_pieces[0].converted_value == "data:image/jpeg;base64,BBBB" + view = result[0].message_pieces[0] + assert view.original_value == "data:image/png;base64,AAAA" + assert view.converted_value == "data:image/jpeg;base64,BBBB" + assert view.original_value_url == "data:image/png;base64,AAAA" + assert view.converted_value_url == "data:image/jpeg;base64,BBBB" async def test_non_blob_http_url_passthrough(self) -> None: - """Test that non-Azure-Blob HTTP URLs are passed through as-is.""" + """Non-Azure-Blob HTTP URLs surface as both the raw value and the URL.""" piece = _make_piece( original_value="http://example.com/image.png", converted_value="http://example.com/image.png", @@ -588,11 +597,14 @@ async def test_non_blob_http_url_passthrough(self) -> None: result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value == "http://example.com/image.png" - assert result[0].message_pieces[0].converted_value == "http://example.com/image.png" + view = result[0].message_pieces[0] + assert view.original_value == "http://example.com/image.png" + assert view.converted_value == "http://example.com/image.png" + assert 
view.original_value_url == "http://example.com/image.png" + assert view.converted_value_url == "http://example.com/image.png" async def test_azure_blob_url_is_signed(self) -> None: - """Test that Azure Blob Storage URLs are signed with SAS tokens.""" + """Azure Blob URLs are signed into *_value_url; raw value keeps the unsigned URL.""" blob_url = "https://myaccount.blob.core.windows.net/dbdata/prompt-memory-entries/images/test.png" signed_url = blob_url + "?sig=abc123" piece = _make_piece( @@ -610,11 +622,16 @@ async def test_azure_blob_url_is_signed(self) -> None: ): result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value == signed_url - assert result[0].message_pieces[0].converted_value == signed_url + view = result[0].message_pieces[0] + # Raw blob URL — unsigned, as stored + assert view.original_value == blob_url + assert view.converted_value == blob_url + # Signed URL — what the client should fetch + assert view.original_value_url == signed_url + assert view.converted_value_url == signed_url async def test_azure_blob_url_sign_failure_returns_raw_url(self) -> None: - """Test that blob sign failure falls back to the raw blob URL.""" + """Sign failure falls back to the unsigned blob URL on both raw and *_value_url.""" blob_url = "https://myaccount.blob.core.windows.net/dbdata/images/test.png" piece = _make_piece( original_value=blob_url, @@ -631,11 +648,14 @@ async def test_azure_blob_url_sign_failure_returns_raw_url(self) -> None: ): result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value == blob_url - assert result[0].message_pieces[0].converted_value == blob_url + view = result[0].message_pieces[0] + assert view.original_value == blob_url + assert view.converted_value == blob_url + assert view.original_value_url == blob_url + assert view.converted_value_url == blob_url async def test_nonexistent_media_file_returns_raw_path(self) -> None: - """Test that non-existent 
local media files fall back to raw path values.""" + """Non-existent local media paths fall back to the raw path on both fields.""" piece = _make_piece( original_value="/tmp/nonexistent.png", converted_value="/tmp/nonexistent.png", @@ -646,8 +666,24 @@ async def test_nonexistent_media_file_returns_raw_path(self) -> None: result = await pyrit_messages_to_dto_async([msg]) - assert result[0].message_pieces[0].original_value == "/tmp/nonexistent.png" - assert result[0].message_pieces[0].converted_value == "/tmp/nonexistent.png" + view = result[0].message_pieces[0] + assert view.original_value == "/tmp/nonexistent.png" + assert view.converted_value == "/tmp/nonexistent.png" + assert view.original_value_url == "/tmp/nonexistent.png" + assert view.converted_value_url == "/tmp/nonexistent.png" + + async def test_text_piece_url_fields_are_none(self) -> None: + """Text pieces don't have a fetchable URL — *_value_url is None.""" + piece = _make_piece(original_value="hello world", converted_value="hello world") + msg = Message(message_pieces=[piece]) + + result = await pyrit_messages_to_dto_async([msg]) + + view = result[0].message_pieces[0] + assert view.original_value == "hello world" + assert view.converted_value == "hello world" + assert view.original_value_url is None + assert view.converted_value_url is None class TestIsAzureBlobUrl: @@ -724,9 +760,13 @@ async def test_empty_path_returns_original(self) -> None: class TestResolveMediaUrl: """Tests for _resolve_media_url helper.""" - def test_text_value_passes_through(self) -> None: - """Non-media types are returned as-is.""" - assert _resolve_media_url(value="hello world", data_type="text") == "hello world" + def test_text_value_returns_none(self) -> None: + """Non-media types have no fetchable URL — return None.""" + assert _resolve_media_url(value="hello world", data_type="text") is None + + def test_text_empty_value_returns_none(self) -> None: + """Empty values return None even for media data types.""" + assert 
_resolve_media_url(value="", data_type="image_path") is None def test_data_uri_passes_through(self) -> None: """Pre-encoded data URIs are returned as-is.""" diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py index 45ca0643a9..d51bee39ff 100644 --- a/tests/unit/backend/test_response_contracts.py +++ b/tests/unit/backend/test_response_contracts.py @@ -94,11 +94,13 @@ class TestMessagePieceViewContract: def test_dump_has_canonical_and_presentation_fields(self) -> None: """Test that the serialized piece exposes canonical and derived presentation fields.""" piece = _make_piece() - view = MessagePieceView.from_domain(piece, original_value="hello", converted_value="hello") + view = MessagePieceView.from_domain(piece) dumped = view.model_dump(mode="json") assert dumped["role"] == "user" assert dumped["original_value"] == "hello" + assert "original_value_url" in dumped + assert "converted_value_url" in dumped assert "original_value_mime_type" in dumped assert "converted_value_mime_type" in dumped assert "original_filename" in dumped @@ -110,7 +112,7 @@ def test_scores_are_score_views(self) -> None: """Test that nested scores serialize with the ScoreView computed field.""" piece = _make_piece() piece.scores = [_make_score()] - view = MessagePieceView.from_domain(piece, original_value="hello", converted_value="hello") + view = MessagePieceView.from_domain(piece) dumped = view.model_dump(mode="json") assert dumped["scores"][0]["scorer_type"] == "FloatScaleScorer" @@ -121,10 +123,8 @@ class TestMessageViewContract: def test_dump_has_turn_metadata_and_pieces(self) -> None: """Test that the serialized message exposes turn metadata and piece views.""" - piece = MessagePieceView.from_domain( - _make_piece(sequence=3, role="assistant"), original_value="hello", converted_value="hello" - ) - view = MessageView.from_domain(pieces=[piece]) + piece = MessagePieceView.from_domain(_make_piece(sequence=3, role="assistant")) + view = 
MessageView.from_domain(message_pieces=[piece]) dumped = view.model_dump(mode="json") assert dumped["turn_number"] == 3 @@ -139,16 +139,17 @@ class TestAttackSummaryContract: def _summary(self, ar: AttackResult) -> AttackSummary: now = datetime.now(timezone.utc) - return AttackSummary.from_domain( - ar, + data = {name: getattr(ar, name) for name in AttackResult.model_fields} + data.update( last_response=None, last_score=None, + labels={"env": "prod"}, message_count=2, last_message_preview="hi", - labels={"env": "prod"}, created_at=now, updated_at=now, ) + return AttackSummary.model_construct(**data) def test_dump_has_canonical_computed_and_stats_fields(self) -> None: """Test that the serialized summary exposes canonical, computed, and stats fields.""" @@ -203,15 +204,15 @@ def test_score_view_emits_deprecated_aliases(self) -> None: def test_message_piece_view_emits_deprecated_alias(self) -> None: """Test that MessagePieceView still emits piece_id mirroring id.""" - view = MessagePieceView.from_domain(_make_piece(), original_value="hello", converted_value="hello") + view = MessagePieceView.from_domain(_make_piece()) dumped = view.model_dump(mode="json") assert dumped["piece_id"] == str(view.id) def test_message_view_emits_deprecated_alias(self) -> None: """Test that MessageView still emits pieces mirroring message_pieces.""" - piece = MessagePieceView.from_domain(_make_piece(), original_value="hello", converted_value="hello") - dumped = MessageView.from_domain(pieces=[piece]).model_dump(mode="json") + piece = MessagePieceView.from_domain(_make_piece()) + dumped = MessageView.from_domain(message_pieces=[piece]).model_dump(mode="json") assert dumped["pieces"] == dumped["message_pieces"] From daf2854db82514436042093e39dce0c3eab5a373 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 16:06:50 -0700 Subject: [PATCH 4/8] Phase 10: drop chatty refactor-narrating comments Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- 
pyrit/backend/mappers/attack_mappers.py | 8 -------- pyrit/backend/models/attacks.py | 7 ------- 2 files changed, 15 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 0eefb8a514..c16bbb54f2 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -198,23 +198,15 @@ def attack_result_to_summary( Returns: AttackSummary view ready for the API response. """ - # Merge attack-result labels with conversation-level labels; conversation - # labels take precedence on key collision. labels = dict(ar.labels) if ar.labels else {} labels.update(stats.labels or {}) created_at, updated_at = _resolve_summary_timestamps(ar) - # Start with every canonical AttackResult field, then overlay view-narrowed - # values (last_response/last_score) and presentation extras (message_count, - # previews, timestamps, merged labels). ``model_construct`` skips validation - # because the source AttackResult is already valid. data = {name: getattr(ar, name) for name in AttackResult.model_fields} data.update( - # Overlays — narrow domain types to view types. last_response=_summary_last_response(ar.last_response), last_score=ScoreView.from_domain(ar.last_score) if ar.last_score else None, labels=labels, - # Presentation extras — not on AttackResult. message_count=stats.message_count, last_message_preview=format_last_message_preview( value=stats.last_message_preview, diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index cc5f1aa699..c38ff68e7e 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -296,13 +296,6 @@ def related_conversation_ids(self) -> list[str]: return sorted(ref.conversation_id for ref in self.related_conversations) -# Note: no ``from_domain`` classmethod here. 
The mapper assembles ``AttackSummary`` -# directly with ``model_construct`` because the construction overlays view-narrowed -# values (``last_response``, ``last_score``, merged ``labels``) on top of the -# canonical ``AttackResult`` fields — a smell that a ``from_domain`` signature -# couldn't express without competing parameters. - - # ============================================================================ # Conversation Messages Response # ============================================================================ From bf693798bd9123a8028ab02774ff5116b52f4a55 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 4 Jun 2026 16:06:50 -0700 Subject: [PATCH 5/8] Phase 10: update e2e mock fixtures to canonical wire field names Rename pieces -> message_pieces and piece_id -> id in mock response payloads. The frontend types only read the canonical names; the old aliases are deprecated and emitted by the backend serializer for back-compat but unused by the client. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- frontend/e2e/chat.spec.ts | 26 +++++++++++++------------- frontend/e2e/converters.spec.ts | 8 ++++---- frontend/e2e/errors.spec.ts | 8 ++++---- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/frontend/e2e/chat.spec.ts b/frontend/e2e/chat.spec.ts index 6ad9005c77..a58245d757 100644 --- a/frontend/e2e/chat.spec.ts +++ b/frontend/e2e/chat.spec.ts @@ -53,9 +53,9 @@ async function mockBackendAPIs(page: Page) { turn_number: turnNumber, role: "user", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: `piece-u-${turnNumber}`, + id: `piece-u-${turnNumber}`, original_value_data_type: "text", converted_value_data_type: "text", original_value: userText, @@ -69,9 +69,9 @@ async function mockBackendAPIs(page: Page) { turn_number: turnNumber, role: "assistant", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: `piece-a-${turnNumber}`, + id: 
`piece-a-${turnNumber}`, original_value_data_type: "text", converted_value_data_type: "text", original_value: `Mock response for: ${userText}`, @@ -362,9 +362,9 @@ function buildModalityMock( turn_number: 0, role: "user", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: "u1", + id: "u1", original_value_data_type: "text", converted_value_data_type: "text", original_value: userText, @@ -378,7 +378,7 @@ function buildModalityMock( turn_number: 1, role: "assistant", created_at: new Date().toISOString(), - pieces: assistantPieces, + message_pieces: assistantPieces, }, ]; postSeen = true; @@ -423,7 +423,7 @@ function buildModalityMock( test.describe("Multi-modal: Image response", () => { const setupImageMock = buildModalityMock([ { - piece_id: "img-1", + id: "img-1", original_value_data_type: "text", converted_value_data_type: "image_path", original_value: "generated image", @@ -457,7 +457,7 @@ test.describe("Multi-modal: Image response", () => { test.describe("Multi-modal: Audio response", () => { const setupAudioMock = buildModalityMock([ { - piece_id: "aud-1", + id: "aud-1", original_value_data_type: "text", converted_value_data_type: "audio_path", original_value: "spoken text", @@ -488,7 +488,7 @@ test.describe("Multi-modal: Audio response", () => { test.describe("Multi-modal: Video response", () => { const setupVideoMock = buildModalityMock([ { - piece_id: "vid-1", + id: "vid-1", original_value_data_type: "text", converted_value_data_type: "video_path", original_value: "generated video", @@ -520,7 +520,7 @@ test.describe("Multi-modal: Video response", () => { test.describe("Multi-modal: Mixed text + image response", () => { const setupMixedMock = buildModalityMock([ { - piece_id: "txt-1", + id: "txt-1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "Here is the analysis:", @@ -529,7 +529,7 @@ test.describe("Multi-modal: Mixed text + image response", () => { response_error: "none", }, { - 
piece_id: "img-2", + id: "img-2", original_value_data_type: "text", converted_value_data_type: "image_path", original_value: "chart image", @@ -559,7 +559,7 @@ test.describe("Multi-modal: Mixed text + image response", () => { test.describe("Multi-modal: Error response from target", () => { const setupErrorMock = buildModalityMock([ { - piece_id: "err-1", + id: "err-1", original_value_data_type: "text", converted_value_data_type: "text", original_value: "", diff --git a/frontend/e2e/converters.spec.ts b/frontend/e2e/converters.spec.ts index c149c335e4..74dfb91148 100644 --- a/frontend/e2e/converters.spec.ts +++ b/frontend/e2e/converters.spec.ts @@ -242,9 +242,9 @@ async function mockBackendAPIs(page: Page) { turn_number: turnNumber, role: "user", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: `piece-u-${turnNumber}`, + id: `piece-u-${turnNumber}`, original_value_data_type: "text", converted_value_data_type: "text", original_value: userText, @@ -258,9 +258,9 @@ async function mockBackendAPIs(page: Page) { turn_number: turnNumber, role: "assistant", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: `piece-a-${turnNumber}`, + id: `piece-a-${turnNumber}`, original_value_data_type: "text", converted_value_data_type: "text", original_value: `Mock response for: ${displayText}`, diff --git a/frontend/e2e/errors.spec.ts b/frontend/e2e/errors.spec.ts index 8fe79d1eb7..6a08676951 100644 --- a/frontend/e2e/errors.spec.ts +++ b/frontend/e2e/errors.spec.ts @@ -15,9 +15,9 @@ function buildSuccessMessageMock(userText: string) { turn_number: 1, role: "user", created_at: new Date().toISOString(), - pieces: [ + message_pieces: [ { - piece_id: "p-u", + id: "p-u", original_value_data_type: "text", converted_value_data_type: "text", original_value: userText, @@ -31,9 +31,9 @@ function buildSuccessMessageMock(userText: string) { turn_number: 1, role: "assistant", created_at: new Date().toISOString(), - pieces: [ + 
message_pieces: [ { - piece_id: "p-a", + id: "p-a", original_value_data_type: "text", converted_value_data_type: "text", original_value: `Reply to: ${userText}`, From f25f28ec86a4a5bcb7a28156378b503e0ba6e459 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 5 Jun 2026 16:36:06 -0700 Subject: [PATCH 6/8] Phase 10: address review comments (drop pieces alias, async summary signing, inline MessageView) - Drop `MessageView.from_domain` classmethod: callers (mapper + 2 tests) now use `MessageView.model_construct(message_pieces=...)` directly. Removes the signature ambiguity Roman flagged where `from_domain` took pre-mapped views instead of a domain Message. (#1941 comment 3362992954) - Promote `attack_result_to_summary` to async and SAS-sign blob URLs in the summary's `last_response` (matches the messages path). The sync helper silently returned unsigned blob URLs from list/detail endpoints, which would 403 on Azure-backed deployments. Both call sites in attack_service are already async, so the await is a one-line change. (#1941 comment 3362992959) - Drop the deprecated `pieces` computed_field on `MessageView`: it duplicated the heavy `message_pieces` array per piece (~24% extra wire size on a 2-piece conversation, per Roman's measurement). The wire-shape simplicity wins over the soft-alias break; cheap scalar aliases (`score_id`, `piece_id`, `scored_at`) are kept. Flips the contract test to assert `pieces` is absent. 
(#1941 comment 3362992961) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/mappers/__init__.py | 4 +- pyrit/backend/mappers/attack_mappers.py | 18 ++--- pyrit/backend/models/attacks.py | 18 ----- pyrit/backend/services/attack_service.py | 6 +- tests/unit/backend/test_api_routes.py | 2 +- tests/unit/backend/test_mappers.py | 80 +++++++++---------- tests/unit/backend/test_response_contracts.py | 13 +-- uv.lock | 2 +- 8 files changed, 63 insertions(+), 80 deletions(-) diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py index 8b1892e4ad..15bb298838 100644 --- a/pyrit/backend/mappers/__init__.py +++ b/pyrit/backend/mappers/__init__.py @@ -10,7 +10,7 @@ from pyrit.backend.mappers._preview import format_last_message_preview from pyrit.backend.mappers.attack_mappers import ( - attack_result_to_summary, + attack_result_to_summary_async, pyrit_messages_to_dto_async, request_piece_to_pyrit_message_piece, request_to_pyrit_message, @@ -23,7 +23,7 @@ ) __all__ = [ - "attack_result_to_summary", + "attack_result_to_summary_async", "converter_object_to_instance", "format_last_message_preview", "pyrit_messages_to_dto_async", diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 21a4de45a3..4edb377d9e 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -178,7 +178,7 @@ def _resolve_media_url(*, value: str | None, data_type: str) -> str | None: return value -def attack_result_to_summary( +async def attack_result_to_summary_async( ar: AttackResult, *, stats: ConversationStats, @@ -188,8 +188,8 @@ def attack_result_to_summary( Conversation-level stats (message count, preview, labels, timestamps) are injected here; every other field is inherited from the AttackResult. 
The - summary's ``last_response`` media is resolved to a ``/api/media`` URL but not - SAS-signed — Azure Blob signing only happens on the async ``/messages`` path. + summary's ``last_response`` media is resolved to a ``/api/media`` URL and + Azure Blob URLs are SAS-signed so they're directly fetchable by the client. Args: ar: The domain AttackResult. @@ -204,7 +204,7 @@ def attack_result_to_summary( data = {name: getattr(ar, name) for name in AttackResult.model_fields} data.update( - last_response=_summary_last_response(ar.last_response), + last_response=await _summary_last_response_async(ar.last_response), last_score=ScoreView.from_domain(ar.last_score) if ar.last_score else None, labels=labels, message_count=stats.message_count, @@ -241,9 +241,9 @@ def _resolve_summary_timestamps(ar: AttackResult) -> tuple[datetime, datetime]: return created_at, updated_at -def _summary_last_response(piece: MessagePiece | None) -> MessagePieceView | None: +async def _summary_last_response_async(piece: MessagePiece | None) -> MessagePieceView | None: """ - Build a ``MessagePieceView`` for a summary's last response (sync media resolution, no SAS). + Build a ``MessagePieceView`` for a summary's last response with signed media URLs. Returns: A ``MessagePieceView`` for the piece, or ``None`` when no piece is given. 
@@ -252,10 +252,10 @@ def _summary_last_response(piece: MessagePiece | None) -> MessagePieceView | Non return None return MessagePieceView.from_domain( piece, - original_value_url=_resolve_media_url( + original_value_url=await _resolve_and_sign_media_async( value=piece.original_value, data_type=piece.original_value_data_type or "text" ), - converted_value_url=_resolve_media_url( + converted_value_url=await _resolve_and_sign_media_async( value=piece.converted_value, data_type=piece.converted_value_data_type or "text" ), ) @@ -303,7 +303,7 @@ async def pyrit_messages_to_dto_async(pyrit_messages: list[Message]) -> list[Mes p, original_value_url=original_value_url, converted_value_url=converted_value_url ) ) - messages.append(MessageView.from_domain(message_pieces=pieces)) + messages.append(MessageView.model_construct(message_pieces=pieces)) return messages diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index a919237c97..be22b9714d 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -199,24 +199,6 @@ def created_at(self) -> datetime: """Return the timestamp of the first piece.""" return self.message_pieces[0].timestamp if self.message_pieces else datetime.now(timezone.utc) - @computed_field( # type: ignore[prop-decorator] - deprecated="Use 'message_pieces' instead; 'pieces' is removed in 0.17.0." - ) - @property - def pieces(self) -> list[MessagePieceView]: - """Deprecated alias for ``message_pieces``.""" - return self.message_pieces - - @classmethod - def from_domain(cls, *, message_pieces: list[MessagePieceView]) -> "MessageView": - """ - Build a ``MessageView`` from already-mapped piece views. - - Returns: - A ``MessageView`` wrapping the provided piece views. 
- """ - return cls.model_construct(message_pieces=message_pieces) - class AttackSummary(AttackResult): """ diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 836a783fba..6713cc884b 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -25,7 +25,7 @@ from urllib.parse import parse_qs, urlparse from pyrit.backend.mappers import ( - attack_result_to_summary, + attack_result_to_summary_async, format_last_message_preview, pyrit_messages_to_dto_async, request_piece_to_pyrit_message_piece, @@ -188,7 +188,7 @@ async def list_attacks_async( labels=conv_labels, ) - page.append(attack_result_to_summary(ar, stats=merged)) + page.append(await attack_result_to_summary_async(ar, stats=merged)) return AttackListResponse( items=page, @@ -235,7 +235,7 @@ async def get_attack_async(self, *, attack_result_id: str) -> AttackSummary | No ar = results[0] stats_map = self._memory.get_conversation_stats(conversation_ids=[ar.conversation_id]) stats = stats_map.get(ar.conversation_id, ConversationStats(message_count=0)) - return attack_result_to_summary(ar, stats=stats) + return await attack_result_to_summary_async(ar, stats=stats) async def get_conversation_messages_async( self, diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index c5c3f686aa..29a48bd656 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -55,7 +55,7 @@ def _make_message_view(*, role: str = "user", value: str = "hello", sequence: in sequence=sequence, ) piece_view = MessagePieceView.from_domain(piece) - return MessageView.from_domain(message_pieces=[piece_view]) + return MessageView.model_construct(message_pieces=[piece_view]) @pytest.fixture diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index f2775503f7..890ffe27e8 100644 --- a/tests/unit/backend/test_mappers.py +++ 
b/tests/unit/backend/test_mappers.py @@ -20,7 +20,7 @@ _is_azure_blob_url, _resolve_media_url, _sign_blob_url_async, - attack_result_to_summary, + attack_result_to_summary_async, pyrit_messages_to_dto_async, request_piece_to_pyrit_message_piece, request_to_pyrit_message, @@ -130,14 +130,14 @@ def _make_score( class TestAttackResultToSummary: - """Tests for attack_result_to_summary function.""" + """Tests for attack_result_to_summary_async function.""" - def test_basic_mapping(self) -> None: + async def test_basic_mapping(self) -> None: """Test that all fields are mapped correctly.""" ar = _make_attack_result(name="My Attack") stats = ConversationStats(message_count=2) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.conversation_id == ar.conversation_id assert summary.outcome == "undetermined" @@ -147,23 +147,23 @@ def test_basic_mapping(self) -> None: assert summary.target is not None assert summary.target.target_type == "TextTarget" - def test_empty_pieces_gives_zero_messages(self) -> None: + async def test_empty_pieces_gives_zero_messages(self) -> None: """Test mapping with no message pieces.""" ar = _make_attack_result() stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.message_count == 0 assert summary.last_message_preview is None - def test_last_message_preview_truncates_long_raw_text(self) -> None: + async def test_last_message_preview_truncates_long_raw_text(self) -> None: """The mapper applies the preview formatter, which truncates long raw text.""" ar = _make_attack_result() long_text = "x" * 200 stats = ConversationStats(message_count=1, last_message_preview=long_text, last_message_data_type="text") - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert 
summary.last_message_preview is not None assert len(summary.last_message_preview) == 103 # 100 + "..." @@ -178,7 +178,7 @@ def test_last_message_preview_truncates_long_raw_text(self) -> None: ("binary_path", "[File: 1780010098266691.png]"), ], ) - def test_media_last_message_preview_hides_absolute_path(self, data_type: str, expected: str) -> None: + async def test_media_last_message_preview_hides_absolute_path(self, data_type: str, expected: str) -> None: """The mapper renders media-type previews as friendly labels rather than leaking the raw on-disk path it receives from memory.""" ar = _make_attack_result() @@ -189,21 +189,21 @@ def test_media_last_message_preview_hides_absolute_path(self, data_type: str, ex last_message_data_type=data_type, ) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.last_message_preview == expected assert "C:\\" not in (summary.last_message_preview or "") - def test_labels_are_mapped(self) -> None: + async def test_labels_are_mapped(self) -> None: """Test that labels are derived from stats.""" ar = _make_attack_result() stats = ConversationStats(message_count=1, labels={"env": "prod", "team": "red"}) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} - def test_labels_passed_through_without_normalization(self) -> None: + async def test_labels_passed_through_without_normalization(self) -> None: """Test that labels are passed through as-is (DB stores canonical keys after migration).""" ar = _make_attack_result() stats = ConversationStats( @@ -211,7 +211,7 @@ def test_labels_passed_through_without_normalization(self) -> None: labels={"operator": "alice", "operation": "op_red", "env": "prod"}, ) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, 
stats=stats) assert summary.labels == { "operator": "alice", @@ -220,7 +220,7 @@ def test_labels_passed_through_without_normalization(self) -> None: "test_ar_label": "test_ar_value", } - def test_conversation_labels_take_precedence_on_collision(self) -> None: + async def test_conversation_labels_take_precedence_on_collision(self) -> None: """Test that conversation-level labels override attack-result labels on key collision.""" ar = _make_attack_result() stats = ConversationStats( @@ -228,38 +228,38 @@ def test_conversation_labels_take_precedence_on_collision(self) -> None: labels={"test_ar_label": "conversation_wins"}, ) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.labels["test_ar_label"] == "conversation_wins" - def test_outcome_success(self) -> None: + async def test_outcome_success(self) -> None: """Test that success outcome is mapped.""" ar = _make_attack_result(outcome=AttackOutcome.SUCCESS) stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.outcome == "success" - def test_no_target_returns_none_fields(self) -> None: + async def test_no_target_returns_none_fields(self) -> None: """Test that target fields are None when no target identifier exists.""" ar = _make_attack_result(has_target=False) stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.target is None - def test_attack_specific_params_passed_through(self) -> None: + async def test_attack_specific_params_passed_through(self) -> None: """Test that attack_specific_params are extracted from identifier.""" ar = _make_attack_result() stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await 
attack_result_to_summary_async(ar, stats=stats) assert summary.attack_specific_params == {"source": "gui"} - def test_converters_extracted_from_identifier(self) -> None: + async def test_converters_extracted_from_identifier(self) -> None: """Test that converter class names are extracted into converters list.""" now = datetime.now(timezone.utc) ar = AttackResult( @@ -295,20 +295,20 @@ def test_converters_extracted_from_identifier(self) -> None: labels={"test_label": "test_value"}, ) - summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) + summary = await attack_result_to_summary_async(ar, stats=ConversationStats(message_count=0)) assert summary.converters == ["Base64Converter", "ROT13Converter"] - def test_no_converters_returns_empty_list(self) -> None: + async def test_no_converters_returns_empty_list(self) -> None: """Test that converters is empty list when no converters in identifier.""" ar = _make_attack_result() stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.converters == [] - def test_related_conversation_ids_from_related_conversations(self) -> None: + async def test_related_conversation_ids_from_related_conversations(self) -> None: """Test that related_conversation_ids includes all related conversation IDs.""" from pyrit.models.conversation_reference import ConversationReference, ConversationType @@ -324,29 +324,29 @@ def test_related_conversation_ids_from_related_conversations(self) -> None: ), } - summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) + summary = await attack_result_to_summary_async(ar, stats=ConversationStats(message_count=0)) assert sorted(summary.related_conversation_ids) == ["branch-1", "pruned-1"] - def test_related_conversation_ids_empty_when_no_related(self) -> None: + async def test_related_conversation_ids_empty_when_no_related(self) -> None: """Test 
that related_conversation_ids is empty when no related conversations exist.""" ar = _make_attack_result() stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.related_conversation_ids == [] - def test_message_count_from_stats(self) -> None: + async def test_message_count_from_stats(self) -> None: """Test that message_count comes from stats.""" ar = _make_attack_result() stats = ConversationStats(message_count=5) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.message_count == 5 - def test_created_at_prefers_ar_timestamp_when_metadata_absent(self) -> None: + async def test_created_at_prefers_ar_timestamp_when_metadata_absent(self) -> None: """When metadata['created_at'] is absent but ar.timestamp is set, use ar.timestamp.""" persisted_ts = datetime(2026, 4, 17, 12, 0, 0, tzinfo=timezone.utc) ar = AttackResult( @@ -355,12 +355,12 @@ def test_created_at_prefers_ar_timestamp_when_metadata_absent(self) -> None: outcome=AttackOutcome.SUCCESS, timestamp=persisted_ts, ) - summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) + summary = await attack_result_to_summary_async(ar, stats=ConversationStats(message_count=0)) assert summary.created_at == persisted_ts assert summary.updated_at == persisted_ts - def test_created_at_metadata_still_wins_over_ar_timestamp(self) -> None: + async def test_created_at_metadata_still_wins_over_ar_timestamp(self) -> None: """When both metadata['created_at'] and ar.timestamp are set, metadata wins (backward compat).""" metadata_ts = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc) ar_ts = datetime(2026, 4, 17, 12, 0, 0, tzinfo=timezone.utc) @@ -371,11 +371,11 @@ def test_created_at_metadata_still_wins_over_ar_timestamp(self) -> None: timestamp=ar_ts, metadata={"created_at": metadata_ts.isoformat()}, ) - 
summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) + summary = await attack_result_to_summary_async(ar, stats=ConversationStats(message_count=0)) assert summary.created_at == metadata_ts - def test_created_at_falls_back_to_now_when_both_absent(self) -> None: + async def test_created_at_falls_back_to_now_when_both_absent(self) -> None: """When neither metadata nor ar.timestamp is set, fall back to datetime.now().""" ar = AttackResult( conversation_id="attack-1", @@ -385,12 +385,12 @@ def test_created_at_falls_back_to_now_when_both_absent(self) -> None: ar.timestamp = None # type: ignore[assignment] before = datetime.now(timezone.utc) - summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) + summary = await attack_result_to_summary_async(ar, stats=ConversationStats(message_count=0)) after = datetime.now(timezone.utc) assert before <= summary.created_at <= after - def test_retry_events_mapped_to_response(self) -> None: + async def test_retry_events_mapped_to_response(self) -> None: """Test that retry events on an AttackResult are inherited by the AttackSummary.""" from pyrit.models.retry_event import RetryEvent @@ -412,7 +412,7 @@ def test_retry_events_mapped_to_response(self) -> None: ar.total_retries = 1 stats = ConversationStats(message_count=0) - summary = attack_result_to_summary(ar, stats=stats) + summary = await attack_result_to_summary_async(ar, stats=stats) assert summary.retry_events is not None assert len(summary.retry_events) == 1 diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py index d51bee39ff..6530204cc8 100644 --- a/tests/unit/backend/test_response_contracts.py +++ b/tests/unit/backend/test_response_contracts.py @@ -124,7 +124,7 @@ class TestMessageViewContract: def test_dump_has_turn_metadata_and_pieces(self) -> None: """Test that the serialized message exposes turn metadata and piece views.""" piece = 
MessagePieceView.from_domain(_make_piece(sequence=3, role="assistant")) - view = MessageView.from_domain(message_pieces=[piece]) + view = MessageView.model_construct(message_pieces=[piece]) dumped = view.model_dump(mode="json") assert dumped["turn_number"] == 3 @@ -209,12 +209,13 @@ def test_message_piece_view_emits_deprecated_alias(self) -> None: assert dumped["piece_id"] == str(view.id) - def test_message_view_emits_deprecated_alias(self) -> None: - """Test that MessageView still emits pieces mirroring message_pieces.""" + def test_message_view_does_not_emit_pieces_alias(self) -> None: + """The deprecated ``pieces`` alias was dropped; only ``message_pieces`` is emitted.""" piece = MessagePieceView.from_domain(_make_piece()) - dumped = MessageView.from_domain(message_pieces=[piece]).model_dump(mode="json") + dumped = MessageView.model_construct(message_pieces=[piece]).model_dump(mode="json") - assert dumped["pieces"] == dumped["message_pieces"] + assert "pieces" not in dumped + assert "message_pieces" in dumped def test_aliases_marked_deprecated_in_schema(self) -> None: """Test that the deprecated aliases are flagged deprecated in the OpenAPI schema.""" @@ -225,4 +226,4 @@ def test_aliases_marked_deprecated_in_schema(self) -> None: assert score_props["score_id"]["deprecated"] is True assert score_props["scored_at"]["deprecated"] is True assert piece_props["piece_id"]["deprecated"] is True - assert message_props["pieces"]["deprecated"] is True + assert "pieces" not in message_props diff --git a/uv.lock b/uv.lock index 31454866e2..e4ecfb9887 100644 --- a/uv.lock +++ b/uv.lock @@ -5150,7 +5150,7 @@ wheels = [ [[package]] name = "pyrit" -version = "0.14.0.dev0" +version = "0.15.0.dev0" source = { editable = "." 
} dependencies = [ { name = "aiofiles" }, From fb9bac2471381fcd62756cd51e1e54bb30d25457 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 11 Jun 2026 16:59:58 -0700 Subject: [PATCH 7/8] MAINT: Drop redundant memory kwarg from pyrit_messages_to_dto_async Memory is a singleton everywhere else in the backend (services grab it in __init__, routes call get_memory_instance() inline). The mapper's optional memory kwarg was redundant: the only production caller passed self._memory (the singleton) and the real-object tests passed sqlite_instance (also the singleton, since the fixture calls set_memory_instance). The mapper now grabs CentralMemory.get_memory_instance() inline like the routes do. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/mappers/attack_mappers.py | 17 +++++------------ pyrit/backend/services/attack_service.py | 2 +- tests/unit/backend/test_mappers.py | 6 +++--- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 628106e30b..ffcb019aae 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -48,7 +48,6 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from pyrit.memory import MemoryInterface from pyrit.models.conversation_stats import ConversationStats # ============================================================================ @@ -302,7 +301,6 @@ def _score_lookup_key(*, piece: MessagePiece) -> str: async def _fetch_scores_by_piece_async( *, pyrit_messages: list[Message], - memory: MemoryInterface | None, ) -> dict[str, list[Score]]: """ Batch-fetch scores for every piece in ``pyrit_messages`` and group by piece id. 
@@ -317,9 +315,7 @@ async def _fetch_scores_by_piece_async( if not score_lookup_ids: return {} - if memory is None: - memory = CentralMemory.get_memory_instance() - + memory = CentralMemory.get_memory_instance() fetched = await asyncio.to_thread(memory.get_prompt_scores, prompt_ids=score_lookup_ids) grouped: dict[str, list[Score]] = {} @@ -330,8 +326,6 @@ async def _fetch_scores_by_piece_async( async def pyrit_messages_to_dto_async( pyrit_messages: list[Message], - *, - memory: MemoryInterface | None = None, ) -> list[MessageView]: """ Translate PyRIT messages to backend MessageView responses. @@ -343,15 +337,14 @@ async def pyrit_messages_to_dto_async( - Local files -> ``/api/media?path=...`` (served by the media endpoint) - Azure Blob Storage files -> signed URLs with SAS tokens - Scores are fetched from memory (``MessagePiece`` no longer carries them) - via a single batched ``get_prompt_scores`` call and attached to their - originating piece. Pass ``memory`` explicitly to avoid the ``CentralMemory`` - singleton lookup or to inject a fake in tests. + Scores are fetched from ``CentralMemory`` (``MessagePiece`` no longer carries + them) via a single batched ``get_prompt_scores`` call and attached to their + originating piece. Returns: List of MessageView responses for the API. 
""" - scores_by_piece = await _fetch_scores_by_piece_async(pyrit_messages=pyrit_messages, memory=memory) + scores_by_piece = await _fetch_scores_by_piece_async(pyrit_messages=pyrit_messages) messages: list[MessageView] = [] for msg in pyrit_messages: diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index eaf9ede368..866a9df364 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -268,7 +268,7 @@ async def get_conversation_messages_async( # Get messages for this conversation pyrit_messages = self._memory.get_conversation_messages(conversation_id=conversation_id) - backend_messages = await pyrit_messages_to_dto_async(list(pyrit_messages), memory=self._memory) + backend_messages = await pyrit_messages_to_dto_async(list(pyrit_messages)) return ConversationMessagesResponse( conversation_id=conversation_id, diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 0c445d814f..9170f031df 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -738,7 +738,7 @@ async def test_scores_are_fetched_from_memory_and_attached(self, sqlite_instance sqlite_instance.add_scores_to_memory(scores=[score]) reloaded = sqlite_instance.get_conversation_messages(conversation_id=piece.conversation_id) - result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) + result = await pyrit_messages_to_dto_async(list(reloaded)) assert len(result) == 1 dto_pieces = result[0].message_pieces @@ -759,7 +759,7 @@ async def test_empty_scores_when_none_recorded(self, sqlite_instance) -> None: sqlite_instance.add_message_to_memory(request=RealPyritMessage(message_pieces=[piece])) reloaded = sqlite_instance.get_conversation_messages(conversation_id=piece.conversation_id) - result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) + result = await pyrit_messages_to_dto_async(list(reloaded)) 
assert result[0].message_pieces[0].scores == [] @@ -793,7 +793,7 @@ async def test_scores_are_grouped_per_piece_across_multiple_pieces(self, sqlite_ ) reloaded = sqlite_instance.get_conversation_messages(conversation_id=conv_id) - result = await pyrit_messages_to_dto_async(list(reloaded), memory=sqlite_instance) + result = await pyrit_messages_to_dto_async(list(reloaded)) by_role = {msg.role: msg for msg in result} assert by_role["user"].message_pieces[0].scores == [] From 340c3456611bf54a883bdbba46a28be61f558616 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 11 Jun 2026 17:26:08 -0700 Subject: [PATCH 8/8] MAINT: Avoid self-inflicted deprecation warnings + fix MessagePieceView docstring Two review fixes on the canonical-model-backed response views: - The deprecated wire aliases (score_id, scored_at, piece_id) used @computed_field(deprecated="..."), which emits a DeprecationWarning every time the field is serialized. Since computed fields are dumped on every response, the server warned against its own serialization on the hot path (consumers read JSON, not server logs). Switched to json_schema_extra={"deprecated": True}, which keeps the OpenAPI deprecated flag (and the removal note in the description) without the runtime warning. - MessagePieceView's docstring claimed response_error_description is "derived from the raw values at map time"; it is never populated (defaults to None, matching prior behavior). Corrected the docstring to say so. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/models/attacks.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 8d38f907cf..45e674f8ea 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -50,18 +50,16 @@ def scorer_type(self) -> str: return identifier.class_name return "Unknown" - @computed_field(deprecated="Use 'id' instead; 'score_id' is removed in 0.17.0.") # type: ignore[prop-decorator] + @computed_field(json_schema_extra={"deprecated": True}) # type: ignore[prop-decorator] @property def score_id(self) -> str: - """Deprecated alias for ``id``.""" + """Deprecated alias for ``id``; use ``id`` instead (removed in 0.17.0).""" return str(self.id) - @computed_field( # type: ignore[prop-decorator] - deprecated="Use 'timestamp' instead; 'scored_at' is removed in 0.17.0." - ) + @computed_field(json_schema_extra={"deprecated": True}) # type: ignore[prop-decorator] @property def scored_at(self) -> datetime | None: - """Deprecated alias for ``timestamp``.""" + """Deprecated alias for ``timestamp``; use ``timestamp`` instead (removed in 0.17.0).""" return self.timestamp @classmethod @@ -93,8 +91,11 @@ class MessagePieceView(MessagePiece): populated by the mapper for media pieces (``/api/media?path=...`` for local files; SAS-signed URLs for Azure Blob; pass-through for data URIs and existing http(s) URLs). ``None`` for plain text and empty values. - - MIME types, download filenames and the response-error description are + - ``*_mime_type`` / ``*_filename`` — MIME types and download filenames derived from the raw values at map time. + + ``response_error_description`` is an optional error detail that defaults to + ``None``; the canonical piece carries no separate description. 
""" scores: list[ScoreView] = Field(default_factory=list) @@ -120,10 +121,10 @@ class MessagePieceView(MessagePiece): default=None, description="Description of the error if response_error is not 'none'" ) - @computed_field(deprecated="Use 'id' instead; 'piece_id' is removed in 0.17.0.") # type: ignore[prop-decorator] + @computed_field(json_schema_extra={"deprecated": True}) # type: ignore[prop-decorator] @property def piece_id(self) -> str: - """Deprecated alias for ``id``.""" + """Deprecated alias for ``id``; use ``id`` instead (removed in 0.17.0).""" return str(self.id) @classmethod