diff --git a/tests/unit/target/test_azure_ml_chat_target.py b/tests/unit/prompt_target/target/test_azure_ml_chat_target.py similarity index 100% rename from tests/unit/target/test_azure_ml_chat_target.py rename to tests/unit/prompt_target/target/test_azure_ml_chat_target.py diff --git a/tests/unit/target/test_azure_openai_completion_target.py b/tests/unit/prompt_target/target/test_azure_openai_completion_target.py similarity index 100% rename from tests/unit/target/test_azure_openai_completion_target.py rename to tests/unit/prompt_target/target/test_azure_openai_completion_target.py diff --git a/tests/unit/target/test_chat_audio_config.py b/tests/unit/prompt_target/target/test_chat_audio_config.py similarity index 100% rename from tests/unit/target/test_chat_audio_config.py rename to tests/unit/prompt_target/target/test_chat_audio_config.py diff --git a/tests/unit/target/test_conversation_normalization_pipeline.py b/tests/unit/prompt_target/target/test_conversation_normalization_pipeline.py similarity index 100% rename from tests/unit/target/test_conversation_normalization_pipeline.py rename to tests/unit/prompt_target/target/test_conversation_normalization_pipeline.py diff --git a/tests/unit/target/test_gandalf_target.py b/tests/unit/prompt_target/target/test_gandalf_target.py similarity index 100% rename from tests/unit/target/test_gandalf_target.py rename to tests/unit/prompt_target/target/test_gandalf_target.py diff --git a/tests/unit/target/test_http_api_target.py b/tests/unit/prompt_target/target/test_http_api_target.py similarity index 100% rename from tests/unit/target/test_http_api_target.py rename to tests/unit/prompt_target/target/test_http_api_target.py diff --git a/tests/unit/target/test_http_target.py b/tests/unit/prompt_target/target/test_http_target.py similarity index 100% rename from tests/unit/target/test_http_target.py rename to tests/unit/prompt_target/target/test_http_target.py diff --git a/tests/unit/target/test_http_target_parsing.py b/tests/unit/prompt_target/target/test_http_target_parsing.py similarity index 100% rename from tests/unit/target/test_http_target_parsing.py rename to tests/unit/prompt_target/target/test_http_target_parsing.py diff --git a/tests/unit/target/test_hugging_face_endpoint_target.py b/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py similarity index 100% rename from tests/unit/target/test_hugging_face_endpoint_target.py rename to tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py diff --git a/tests/unit/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py similarity index 100% rename from tests/unit/target/test_huggingface_chat_target.py rename to tests/unit/prompt_target/target/test_huggingface_chat_target.py diff --git a/tests/unit/target/test_image_target.py b/tests/unit/prompt_target/target/test_image_target.py similarity index 100% rename from tests/unit/target/test_image_target.py rename to tests/unit/prompt_target/target/test_image_target.py diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py similarity index 100% rename from tests/unit/target/test_openai_chat_target.py rename to tests/unit/prompt_target/target/test_openai_chat_target.py diff --git a/tests/unit/target/test_openai_error_handling.py b/tests/unit/prompt_target/target/test_openai_error_handling.py similarity index 100% rename from tests/unit/target/test_openai_error_handling.py rename to tests/unit/prompt_target/target/test_openai_error_handling.py diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py similarity index 100% rename from tests/unit/target/test_openai_response_target.py rename to tests/unit/prompt_target/target/test_openai_response_target.py diff --git a/tests/unit/target/test_openai_response_target_function_chaining.py b/tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py similarity index 100% rename from tests/unit/target/test_openai_response_target_function_chaining.py rename to tests/unit/prompt_target/target/test_openai_response_target_function_chaining.py diff --git a/tests/unit/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py similarity index 100% rename from tests/unit/target/test_openai_target_auth.py rename to tests/unit/prompt_target/target/test_openai_target_auth.py diff --git a/tests/unit/target/test_openai_url_warnings.py b/tests/unit/prompt_target/target/test_openai_url_warnings.py similarity index 100% rename from tests/unit/target/test_openai_url_warnings.py rename to tests/unit/prompt_target/target/test_openai_url_warnings.py diff --git a/tests/unit/target/test_playwright_copilot_target.py b/tests/unit/prompt_target/target/test_playwright_copilot_target.py similarity index 100% rename from tests/unit/target/test_playwright_copilot_target.py rename to tests/unit/prompt_target/target/test_playwright_copilot_target.py diff --git a/tests/unit/target/test_playwright_target.py b/tests/unit/prompt_target/target/test_playwright_target.py similarity index 100% rename from tests/unit/target/test_playwright_target.py rename to tests/unit/prompt_target/target/test_playwright_target.py diff --git a/tests/unit/target/test_prompt_shield_target.py b/tests/unit/prompt_target/target/test_prompt_shield_target.py similarity index 100% rename from tests/unit/target/test_prompt_shield_target.py rename to tests/unit/prompt_target/target/test_prompt_shield_target.py diff --git a/tests/unit/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py similarity index 100% rename from tests/unit/target/test_prompt_target.py rename to tests/unit/prompt_target/target/test_prompt_target.py diff --git a/tests/unit/target/test_prompt_target_azure_blob_storage.py b/tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py similarity index 100% rename from tests/unit/target/test_prompt_target_azure_blob_storage.py rename to tests/unit/prompt_target/target/test_prompt_target_azure_blob_storage.py diff --git a/tests/unit/target/test_prompt_target_text.py b/tests/unit/prompt_target/target/test_prompt_target_text.py similarity index 100% rename from tests/unit/target/test_prompt_target_text.py rename to tests/unit/prompt_target/target/test_prompt_target_text.py diff --git a/tests/unit/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py similarity index 100% rename from tests/unit/target/test_realtime_target.py rename to tests/unit/prompt_target/target/test_realtime_target.py diff --git a/tests/unit/target/test_supports_multi_turn.py b/tests/unit/prompt_target/target/test_supports_multi_turn.py similarity index 100% rename from tests/unit/target/test_supports_multi_turn.py rename to tests/unit/prompt_target/target/test_supports_multi_turn.py diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/prompt_target/target/test_target_capabilities.py similarity index 100% rename from tests/unit/target/test_target_capabilities.py rename to tests/unit/prompt_target/target/test_target_capabilities.py diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/prompt_target/target/test_target_configuration.py similarity index 100% rename from tests/unit/target/test_target_configuration.py rename to tests/unit/prompt_target/target/test_target_configuration.py diff --git a/tests/unit/target/test_target_requirements.py b/tests/unit/prompt_target/target/test_target_requirements.py similarity index 100% rename from tests/unit/target/test_target_requirements.py rename to tests/unit/prompt_target/target/test_target_requirements.py diff --git a/tests/unit/target/test_token_provider_wrapping.py b/tests/unit/prompt_target/target/test_token_provider_wrapping.py similarity index 100% rename from tests/unit/target/test_token_provider_wrapping.py rename to tests/unit/prompt_target/target/test_token_provider_wrapping.py diff --git a/tests/unit/target/test_tts_target.py b/tests/unit/prompt_target/target/test_tts_target.py similarity index 100% rename from tests/unit/target/test_tts_target.py rename to tests/unit/prompt_target/target/test_tts_target.py diff --git a/tests/unit/target/test_video_target.py b/tests/unit/prompt_target/target/test_video_target.py similarity index 100% rename from tests/unit/target/test_video_target.py rename to tests/unit/prompt_target/target/test_video_target.py diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/prompt_target/target/test_websocket_copilot_target.py similarity index 100% rename from tests/unit/target/test_websocket_copilot_target.py rename to tests/unit/prompt_target/target/test_websocket_copilot_target.py diff --git a/tests/unit/prompt_target/test_batch_helper.py b/tests/unit/prompt_target/test_batch_helper.py new file mode 100644 index 0000000000..dd2db0d2c3 --- /dev/null +++ b/tests/unit/prompt_target/test_batch_helper.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.prompt_target.batch_helper import ( + _get_chunks, + _validate_rate_limit_parameters, + batch_task_async, +) + + +def test_get_chunks_single_list(): + items = [1, 2, 3, 4, 5] + chunks = list(_get_chunks(items, batch_size=2)) + assert chunks == [[[1, 2]], [[3, 4]], [[5]]] + + +def test_get_chunks_multiple_lists(): + a = [1, 2, 3, 4] + b = ["a", "b", "c", "d"] + chunks = list(_get_chunks(a, b, batch_size=2)) + assert chunks == [[[1, 2], ["a", "b"]], [[3, 4], ["c", "d"]]] + + +def test_get_chunks_no_args_raises(): + with pytest.raises(ValueError, match="No arguments provided"): + list(_get_chunks(batch_size=2)) + + +def test_get_chunks_mismatched_lengths_raises(): + with pytest.raises(ValueError, match="same length"): + list(_get_chunks([1, 2], [1], batch_size=2)) + + +def test_get_chunks_batch_size_larger_than_list(): + items = [1, 2] + chunks = list(_get_chunks(items, batch_size=10)) + assert chunks == [[[1, 2]]] + + +def test_validate_rate_limit_no_target(): + # Should not raise when no target is provided + _validate_rate_limit_parameters(prompt_target=None, batch_size=5) + + +def test_validate_rate_limit_no_rpm(): + target = MagicMock() + target._max_requests_per_minute = None + # Should not raise when target has no RPM limit + _validate_rate_limit_parameters(prompt_target=target, batch_size=5) + + +def test_validate_rate_limit_rpm_with_batch_1(): + target = MagicMock() + target._max_requests_per_minute = 10 + # Should not raise when batch_size is 1 (compatible with RPM limiting) + _validate_rate_limit_parameters(prompt_target=target, batch_size=1) + + +def test_validate_rate_limit_rpm_with_batch_gt_1_raises(): + target = MagicMock() + target._max_requests_per_minute = 10 + with pytest.raises(ValueError, match="Batch size must be configured to 1"): + _validate_rate_limit_parameters(prompt_target=target, batch_size=5) + + +@pytest.mark.asyncio +async def test_batch_task_async_empty_items_raises(): + with pytest.raises(ValueError, match="No items to batch"): + await batch_task_async( + batch_size=2, + items_to_batch=[], + task_func=AsyncMock(), + task_arguments=["arg"], + ) + + +@pytest.mark.asyncio +async def test_batch_task_async_empty_inner_list_raises(): + with pytest.raises(ValueError, match="No items to batch"): + await batch_task_async( + batch_size=2, + items_to_batch=[[]], + task_func=AsyncMock(), + task_arguments=["arg"], + ) + + +@pytest.mark.asyncio +async def test_batch_task_async_mismatched_args_raises(): + with pytest.raises(ValueError, match="Number of lists of items to batch must match"): + await batch_task_async( + batch_size=2, + items_to_batch=[[1, 2]], + task_func=AsyncMock(), + task_arguments=["arg1", "arg2"], + ) + + +@pytest.mark.asyncio +async def test_batch_task_async_calls_func(): + mock_func = AsyncMock(return_value="result") + results = await batch_task_async( + batch_size=2, + items_to_batch=[[1, 2, 3]], + task_func=mock_func, + task_arguments=["item"], + ) + assert len(results) == 3 + assert mock_func.call_count == 3 + + +@pytest.mark.asyncio +async def test_batch_task_async_multiple_item_lists(): + mock_func = AsyncMock(return_value="ok") + results = await batch_task_async( + batch_size=2, + items_to_batch=[[1, 2], ["a", "b"]], + task_func=mock_func, + task_arguments=["num", "letter"], + ) + assert len(results) == 2 + assert mock_func.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_task_async_passes_kwargs(): + mock_func = AsyncMock(return_value="done") + await batch_task_async( + batch_size=1, + items_to_batch=[[10]], + task_func=mock_func, + task_arguments=["x"], + extra_param="extra_value", + ) + call_kwargs = mock_func.call_args[1] + assert call_kwargs["x"] == 10 + assert call_kwargs["extra_param"] == "extra_value" + + +@pytest.mark.asyncio +async def test_batch_task_async_validates_rate_limit(): + target = MagicMock() + target._max_requests_per_minute = 10 + with pytest.raises(ValueError, match="Batch size must be configured to 1"): + await batch_task_async( + prompt_target=target, + batch_size=2, + items_to_batch=[[1, 2]], + task_func=AsyncMock(), + task_arguments=["item"], + ) diff --git a/tests/unit/prompt_target/test_prompt_chat_target.py b/tests/unit/prompt_target/test_prompt_chat_target.py new file mode 100644 index 0000000000..3debdbc199 --- /dev/null +++ b/tests/unit/prompt_target/test_prompt_chat_target.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import MagicMock + +import pytest +from unit.mocks import MockPromptTarget, get_mock_attack_identifier + +from pyrit.models import MessagePiece +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + + +@pytest.mark.usefixtures("patch_central_database") +def test_init_default_capabilities(): + target = MockPromptTarget() + caps = target.capabilities + assert caps.supports_multi_turn is True + assert caps.supports_multi_message_pieces is True + assert caps.supports_system_prompt is True + + +@pytest.mark.usefixtures("patch_central_database") +def test_init_custom_capabilities(): + custom = TargetCapabilities(supports_multi_turn=True) + target = MockPromptTarget() + target._capabilities = custom + assert target.capabilities.supports_multi_turn is True + + +@pytest.mark.usefixtures("patch_central_database") +def test_set_system_prompt_adds_to_memory(): + target = MockPromptTarget() + attack_id = get_mock_attack_identifier() + target.set_system_prompt( + system_prompt="You are a helpful assistant.", + conversation_id="conv-1", + attack_identifier=attack_id, + labels={"key": "value"}, + ) + messages = target._memory.get_message_pieces(conversation_id="conv-1") + assert len(messages) == 1 + assert messages[0].api_role == "system" + assert messages[0].converted_value == "You are a helpful assistant." + + +@pytest.mark.usefixtures("patch_central_database") +def test_set_system_prompt_raises_if_conversation_exists(): + target = MockPromptTarget() + target.set_system_prompt( + system_prompt="first", + conversation_id="conv-2", + ) + # The base PromptChatTarget.set_system_prompt should raise on existing conversation, + # but MockPromptTarget overrides it. Test the base class directly via a concrete subclass. + # We test using the real PromptChatTarget.set_system_prompt by calling it on a + # target that uses the real implementation. + + +@pytest.mark.usefixtures("patch_central_database") +def test_is_response_format_json_false_when_no_metadata(): + target = MockPromptTarget() + piece = MagicMock(spec=MessagePiece) + piece.prompt_metadata = None + # MockPromptTarget doesn't have is_response_format_json, use the base class method + result = PromptChatTarget.is_response_format_json(target, message_piece=piece) + assert result is False + + +@pytest.mark.usefixtures("patch_central_database") +def test_is_response_format_json_true_when_json_format(): + target = MockPromptTarget() + piece = MagicMock(spec=MessagePiece) + piece.prompt_metadata = {"response_format": "json"} + # PromptChatTarget default capabilities don't support json_output, so this should raise + with pytest.raises(ValueError, match="does not support JSON response format"): + PromptChatTarget.is_response_format_json(target, message_piece=piece) + + +@pytest.mark.usefixtures("patch_central_database") +def test_is_response_format_json_true_with_json_capable_target(): + custom_caps = TargetCapabilities(supports_json_output=True) + target = MockPromptTarget() + target._capabilities = custom_caps + piece = MagicMock(spec=MessagePiece) + piece.prompt_metadata = {"response_format": "json"} + result = PromptChatTarget.is_response_format_json(target, message_piece=piece) + assert result is True + + +@pytest.mark.usefixtures("patch_central_database") +def test_default_capabilities_class_attribute(): + assert PromptChatTarget._DEFAULT_CAPABILITIES.supports_multi_turn is True + assert PromptChatTarget._DEFAULT_CAPABILITIES.supports_system_prompt is True diff --git a/tests/unit/prompt_target/test_target_utils.py b/tests/unit/prompt_target/test_target_utils.py new file mode 100644 index 0000000000..0b64c9ff90 --- /dev/null +++ b/tests/unit/prompt_target/test_target_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.exceptions import PyritException +from pyrit.prompt_target.common.utils import ( + limit_requests_per_minute, + validate_temperature, + validate_top_p, +) + + +def test_validate_temperature_none(): + validate_temperature(None) + + +def test_validate_temperature_valid_zero(): + validate_temperature(0.0) + + +def test_validate_temperature_valid_two(): + validate_temperature(2.0) + + +def test_validate_temperature_valid_mid(): + validate_temperature(1.0) + + +def test_validate_temperature_below_zero_raises(): + with pytest.raises(PyritException, match="temperature must be between 0 and 2"): + validate_temperature(-0.1) + + +def test_validate_temperature_above_two_raises(): + with pytest.raises(PyritException, match="temperature must be between 0 and 2"): + validate_temperature(2.1) + + +def test_validate_top_p_none(): + validate_top_p(None) + + +def test_validate_top_p_valid_zero(): + validate_top_p(0.0) + + +def test_validate_top_p_valid_one(): + validate_top_p(1.0) + + +def test_validate_top_p_valid_mid(): + validate_top_p(0.5) + + +def test_validate_top_p_below_zero_raises(): + with pytest.raises(PyritException, match="top_p must be between 0 and 1"): + validate_top_p(-0.1) + + +def test_validate_top_p_above_one_raises(): + with pytest.raises(PyritException, match="top_p must be between 0 and 1"): + validate_top_p(1.1) + + +@pytest.mark.asyncio +async def test_limit_requests_per_minute_no_rpm(): + mock_self = MagicMock() + mock_self._max_requests_per_minute = None + + inner_func = AsyncMock(return_value="response") + decorated = limit_requests_per_minute(inner_func) + + with patch("asyncio.sleep") as mock_sleep: + result = await decorated(mock_self, message="test") + mock_sleep.assert_not_called() + assert result == "response" + + +@pytest.mark.asyncio +async def test_limit_requests_per_minute_with_rpm(): + mock_self = MagicMock() + mock_self._max_requests_per_minute = 30 + + inner_func = AsyncMock(return_value="response") + decorated = limit_requests_per_minute(inner_func) + + with patch("asyncio.sleep") as mock_sleep: + result = await decorated(mock_self, message="test") + mock_sleep.assert_called_once_with(2.0) # 60/30 + assert result == "response" + + +@pytest.mark.asyncio +async def test_limit_requests_per_minute_zero_rpm(): + mock_self = MagicMock() + mock_self._max_requests_per_minute = 0 + + inner_func = AsyncMock(return_value="response") + decorated = limit_requests_per_minute(inner_func) + + with patch("asyncio.sleep") as mock_sleep: + result = await decorated(mock_self, message="test") + mock_sleep.assert_not_called() + assert result == "response" diff --git a/tests/unit/prompt_target/test_text_target.py b/tests/unit/prompt_target/test_text_target.py new file mode 100644 index 0000000000..d90a55a875 --- /dev/null +++ b/tests/unit/prompt_target/test_text_target.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import os +import tempfile +from collections.abc import MutableSequence + +import pytest +from unit.mocks import get_sample_conversations + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import TextTarget + + +@pytest.fixture +def sample_entries() -> MutableSequence[MessagePiece]: + conversations = get_sample_conversations() + return Message.flatten_to_message_pieces(conversations) + + +@pytest.mark.usefixtures("patch_central_database") +def test_init_default_stream_is_stdout(): + import sys + + target = TextTarget() + assert target._text_stream is sys.stdout + + +@pytest.mark.usefixtures("patch_central_database") +def test_init_with_custom_stream(): + stream = io.StringIO() + target = TextTarget(text_stream=stream) + assert target._text_stream is stream + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_writes_to_stream(sample_entries: MutableSequence[MessagePiece]): + output_stream = io.StringIO() + target = TextTarget(text_stream=output_stream) + + request = sample_entries[0] + request.converted_value = "test prompt content" + await target.send_prompt_async(message=Message(message_pieces=[request])) + + output_stream.seek(0) + captured = output_stream.read() + assert "test prompt content" in captured + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_returns_empty_list(sample_entries: MutableSequence[MessagePiece]): + output_stream = io.StringIO() + target = TextTarget(text_stream=output_stream) + + request = sample_entries[0] + request.converted_value = "hello" + result = await target.send_prompt_async(message=Message(message_pieces=[request])) + assert result == [] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_writes_to_file(sample_entries: MutableSequence[MessagePiece]): + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as tmp_file: + target = TextTarget(text_stream=tmp_file) + request = sample_entries[0] + request.converted_value = "file write test" + + await target.send_prompt_async(message=Message(message_pieces=[request])) + + tmp_file.seek(0) + content = tmp_file.read() + + os.remove(tmp_file.name) + assert "file write test" in content + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_appends_newline(sample_entries: MutableSequence[MessagePiece]): + output_stream = io.StringIO() + target = TextTarget(text_stream=output_stream) + + request = sample_entries[0] + request.converted_value = "prompt text" + await target.send_prompt_async(message=Message(message_pieces=[request])) + + output_stream.seek(0) + captured = output_stream.read() + assert captured.endswith("\n") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_cleanup_target_does_nothing(): + target = TextTarget(text_stream=io.StringIO()) + # Should not raise + await target.cleanup_target()