Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions tests/unit/prompt_target/test_batch_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import AsyncMock, MagicMock

import pytest

from pyrit.prompt_target.batch_helper import (
_get_chunks,
_validate_rate_limit_parameters,
batch_task_async,
)


def test_get_chunks_single_list():
    """A single list is split into batch_size-sized sub-lists, remainder last."""
    values = [1, 2, 3, 4, 5]
    result = list(_get_chunks(values, batch_size=2))
    assert result == [[[1, 2]], [[3, 4]], [[5]]]


def test_get_chunks_multiple_lists():
    """Parallel lists are chunked in lockstep, one sub-list per input list."""
    numbers = [1, 2, 3, 4]
    letters = ["a", "b", "c", "d"]
    result = list(_get_chunks(numbers, letters, batch_size=2))
    assert result == [[[1, 2], ["a", "b"]], [[3, 4], ["c", "d"]]]


def test_get_chunks_no_args_raises():
    """Calling with no lists at all is rejected."""
    with pytest.raises(ValueError, match="No arguments provided"):
        list(_get_chunks(batch_size=2))


def test_get_chunks_mismatched_lengths_raises():
    """All lists passed together must have equal lengths."""
    with pytest.raises(ValueError, match="same length"):
        list(_get_chunks([1, 2], [1], batch_size=2))


def test_get_chunks_batch_size_larger_than_list():
    """A batch size exceeding the list length yields one chunk holding everything."""
    values = [1, 2]
    result = list(_get_chunks(values, batch_size=10))
    assert result == [[[1, 2]]]


def test_validate_rate_limit_no_target():
    """With no target there is nothing to validate; no exception is expected."""
    _validate_rate_limit_parameters(prompt_target=None, batch_size=5)


def test_validate_rate_limit_no_rpm():
    """A target without an RPM limit accepts any batch size."""
    mock_target = MagicMock()
    mock_target._max_requests_per_minute = None
    _validate_rate_limit_parameters(prompt_target=mock_target, batch_size=5)


def test_validate_rate_limit_rpm_with_batch_1():
    """batch_size=1 is compatible with an RPM-limited target."""
    mock_target = MagicMock()
    mock_target._max_requests_per_minute = 10
    _validate_rate_limit_parameters(prompt_target=mock_target, batch_size=1)


def test_validate_rate_limit_rpm_with_batch_gt_1_raises():
    """An RPM-limited target combined with batch_size > 1 is rejected."""
    mock_target = MagicMock()
    mock_target._max_requests_per_minute = 10
    with pytest.raises(ValueError, match="Batch size must be configured to 1"):
        _validate_rate_limit_parameters(prompt_target=mock_target, batch_size=5)


async def _invoke_batch(items, arguments):
    """Helper: run batch_task_async with a throwaway task and fixed batch size."""
    await batch_task_async(
        batch_size=2,
        items_to_batch=items,
        task_func=AsyncMock(),
        task_arguments=arguments,
    )


@pytest.mark.asyncio
async def test_batch_task_async_empty_items_raises():
    """An empty outer list of items is rejected."""
    with pytest.raises(ValueError, match="No items to batch"):
        await _invoke_batch([], ["arg"])


@pytest.mark.asyncio
async def test_batch_task_async_empty_inner_list_raises():
    """An outer list whose inner list is empty is rejected as well."""
    with pytest.raises(ValueError, match="No items to batch"):
        await _invoke_batch([[]], ["arg"])


@pytest.mark.asyncio
async def test_batch_task_async_mismatched_args_raises():
    """Each item list must have exactly one matching entry in task_arguments."""
    with pytest.raises(ValueError, match="Number of lists of items to batch must match"):
        await _invoke_batch([[1, 2]], ["arg1", "arg2"])


@pytest.mark.asyncio
async def test_batch_task_async_calls_func():
    """The task function runs once per item; one result is collected per call."""
    task = AsyncMock(return_value="result")
    outputs = await batch_task_async(
        batch_size=2,
        items_to_batch=[[1, 2, 3]],
        task_func=task,
        task_arguments=["item"],
    )
    assert len(outputs) == 3
    assert task.call_count == 3


@pytest.mark.asyncio
async def test_batch_task_async_multiple_item_lists():
    """With two parallel item lists, the task runs once per aligned pair."""
    task = AsyncMock(return_value="ok")
    outputs = await batch_task_async(
        batch_size=2,
        items_to_batch=[[1, 2], ["a", "b"]],
        task_func=task,
        task_arguments=["num", "letter"],
    )
    assert len(outputs) == 2
    assert task.call_count == 2


@pytest.mark.asyncio
async def test_batch_task_async_passes_kwargs():
    """Items are bound to their task_arguments names and extra kwargs pass through."""
    task = AsyncMock(return_value="done")
    await batch_task_async(
        batch_size=1,
        items_to_batch=[[10]],
        task_func=task,
        task_arguments=["x"],
        extra_param="extra_value",
    )
    passed_kwargs = task.call_args.kwargs
    assert passed_kwargs["x"] == 10
    assert passed_kwargs["extra_param"] == "extra_value"


@pytest.mark.asyncio
async def test_batch_task_async_validates_rate_limit():
    """batch_task_async enforces the RPM/batch-size check before doing any work."""
    rpm_limited_target = MagicMock()
    rpm_limited_target._max_requests_per_minute = 10
    with pytest.raises(ValueError, match="Batch size must be configured to 1"):
        await batch_task_async(
            prompt_target=rpm_limited_target,
            batch_size=2,
            items_to_batch=[[1, 2]],
            task_func=AsyncMock(),
            task_arguments=["item"],
        )
94 changes: 94 additions & 0 deletions tests/unit/prompt_target/test_prompt_chat_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import MagicMock

import pytest
from unit.mocks import MockPromptTarget, get_mock_attack_identifier

from pyrit.models import MessagePiece
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities


@pytest.mark.usefixtures("patch_central_database")
def test_init_default_capabilities():
    """A freshly constructed target advertises the default capability set."""
    capabilities = MockPromptTarget().capabilities
    assert capabilities.supports_multi_turn is True
    assert capabilities.supports_multi_message_pieces is True
    assert capabilities.supports_system_prompt is True


@pytest.mark.usefixtures("patch_central_database")
def test_init_custom_capabilities():
    """Overriding _capabilities is reflected by the capabilities property."""
    target = MockPromptTarget()
    target._capabilities = TargetCapabilities(supports_multi_turn=True)
    assert target.capabilities.supports_multi_turn is True


@pytest.mark.usefixtures("patch_central_database")
def test_set_system_prompt_adds_to_memory():
    """set_system_prompt persists exactly one system-role piece to memory."""
    target = MockPromptTarget()
    target.set_system_prompt(
        system_prompt="You are a helpful assistant.",
        conversation_id="conv-1",
        attack_identifier=get_mock_attack_identifier(),
        labels={"key": "value"},
    )
    stored = target._memory.get_message_pieces(conversation_id="conv-1")
    assert len(stored) == 1
    assert stored[0].api_role == "system"
    assert stored[0].converted_value == "You are a helpful assistant."


@pytest.mark.usefixtures("patch_central_database")
def test_set_system_prompt_raises_if_conversation_exists():
    """Setting a system prompt for a new conversation succeeds and is persisted.

    NOTE(review): the base PromptChatTarget.set_system_prompt should raise when
    the conversation already exists, but MockPromptTarget overrides it, so that
    raising path is not exercised here. TODO: cover the base-class behavior with
    a concrete subclass that does not override set_system_prompt.
    """
    target = MockPromptTarget()
    target.set_system_prompt(
        system_prompt="first",
        conversation_id="conv-2",
    )
    # The original test body contained no assertions and passed vacuously; at
    # minimum, verify the prompt was actually stored for the conversation.
    stored = target._memory.get_message_pieces(conversation_id="conv-2")
    assert len(stored) == 1
    assert stored[0].converted_value == "first"


@pytest.mark.usefixtures("patch_central_database")
def test_is_response_format_json_false_when_no_metadata():
    """Without prompt metadata there is no JSON request, so the answer is False."""
    piece = MagicMock(spec=MessagePiece)
    piece.prompt_metadata = None
    # MockPromptTarget does not define is_response_format_json; invoke the
    # base-class implementation directly with the mock target as self.
    assert PromptChatTarget.is_response_format_json(MockPromptTarget(), message_piece=piece) is False


@pytest.mark.usefixtures("patch_central_database")
def test_is_response_format_json_true_when_json_format():
    """Requesting JSON from a target that cannot produce it raises.

    NOTE(review): despite the "_true_" in the test name, this exercises the
    error path — default capabilities lack supports_json_output.
    """
    piece = MagicMock(spec=MessagePiece)
    piece.prompt_metadata = {"response_format": "json"}
    with pytest.raises(ValueError, match="does not support JSON response format"):
        PromptChatTarget.is_response_format_json(MockPromptTarget(), message_piece=piece)


@pytest.mark.usefixtures("patch_central_database")
def test_is_response_format_json_true_with_json_capable_target():
    """A JSON-capable target reports True for a json response_format request."""
    target = MockPromptTarget()
    target._capabilities = TargetCapabilities(supports_json_output=True)
    piece = MagicMock(spec=MessagePiece)
    piece.prompt_metadata = {"response_format": "json"}
    assert PromptChatTarget.is_response_format_json(target, message_piece=piece) is True


@pytest.mark.usefixtures("patch_central_database")
def test_default_capabilities_class_attribute():
    """The class-level default capabilities allow multi-turn and system prompts."""
    defaults = PromptChatTarget._DEFAULT_CAPABILITIES
    assert defaults.supports_multi_turn is True
    assert defaults.supports_system_prompt is True
107 changes: 107 additions & 0 deletions tests/unit/prompt_target/test_target_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from pyrit.exceptions import PyritException
from pyrit.prompt_target.common.utils import (
limit_requests_per_minute,
validate_temperature,
validate_top_p,
)


def _expect_temperature_rejected(value):
    """Helper: the given temperature must be rejected with the standard message."""
    with pytest.raises(PyritException, match="temperature must be between 0 and 2"):
        validate_temperature(value)


def test_validate_temperature_none():
    """None means "not set" and is always accepted."""
    validate_temperature(None)


def test_validate_temperature_valid_zero():
    """The lower bound 0.0 is inclusive."""
    validate_temperature(0.0)


def test_validate_temperature_valid_two():
    """The upper bound 2.0 is inclusive."""
    validate_temperature(2.0)


def test_validate_temperature_valid_mid():
    """A mid-range value passes validation."""
    validate_temperature(1.0)


def test_validate_temperature_below_zero_raises():
    """Values just below 0 are rejected."""
    _expect_temperature_rejected(-0.1)


def test_validate_temperature_above_two_raises():
    """Values just above 2 are rejected."""
    _expect_temperature_rejected(2.1)


def _expect_top_p_rejected(value):
    """Helper: the given top_p must be rejected with the standard message."""
    with pytest.raises(PyritException, match="top_p must be between 0 and 1"):
        validate_top_p(value)


def test_validate_top_p_none():
    """None means "not set" and is always accepted."""
    validate_top_p(None)


def test_validate_top_p_valid_zero():
    """The lower bound 0.0 is inclusive."""
    validate_top_p(0.0)


def test_validate_top_p_valid_one():
    """The upper bound 1.0 is inclusive."""
    validate_top_p(1.0)


def test_validate_top_p_valid_mid():
    """A mid-range value passes validation."""
    validate_top_p(0.5)


def test_validate_top_p_below_zero_raises():
    """Values just below 0 are rejected."""
    _expect_top_p_rejected(-0.1)


def test_validate_top_p_above_one_raises():
    """Values just above 1 are rejected."""
    _expect_top_p_rejected(1.1)


@pytest.mark.asyncio
async def test_limit_requests_per_minute_no_rpm():
    """No RPM limit configured: the wrapped call runs without sleeping."""
    owner = MagicMock()
    owner._max_requests_per_minute = None
    wrapped = limit_requests_per_minute(AsyncMock(return_value="response"))
    with patch("asyncio.sleep") as mock_sleep:
        assert await wrapped(owner, message="test") == "response"
    mock_sleep.assert_not_called()


@pytest.mark.asyncio
async def test_limit_requests_per_minute_with_rpm():
    """With RPM=30 the decorator sleeps 60 / 30 = 2.0 seconds before the call."""
    owner = MagicMock()
    owner._max_requests_per_minute = 30
    wrapped = limit_requests_per_minute(AsyncMock(return_value="response"))
    with patch("asyncio.sleep") as mock_sleep:
        assert await wrapped(owner, message="test") == "response"
    mock_sleep.assert_called_once_with(2.0)


@pytest.mark.asyncio
async def test_limit_requests_per_minute_zero_rpm():
    """An RPM of 0 triggers no throttling: no sleep happens."""
    owner = MagicMock()
    owner._max_requests_per_minute = 0
    wrapped = limit_requests_per_minute(AsyncMock(return_value="response"))
    with patch("asyncio.sleep") as mock_sleep:
        assert await wrapped(owner, message="test") == "response"
    mock_sleep.assert_not_called()
Loading
Loading