Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tests/unit/identifiers/test_class_name_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

from pyrit.identifiers.class_name_utils import class_name_to_snake_case, snake_case_to_class_name

# --- class_name_to_snake_case ---


def test_class_name_to_snake_case_simple():
assert class_name_to_snake_case("MyClass") == "my_class"


def test_class_name_to_snake_case_single_word():
assert class_name_to_snake_case("Scorer") == "scorer"


def test_class_name_to_snake_case_multiple_words():
assert class_name_to_snake_case("SelfAskRefusalScorer") == "self_ask_refusal_scorer"


def test_class_name_to_snake_case_with_suffix_stripped():
assert class_name_to_snake_case("SelfAskRefusalScorer", suffix="Scorer") == "self_ask_refusal"


def test_class_name_to_snake_case_suffix_not_present():
assert class_name_to_snake_case("MyClass", suffix="Scorer") == "my_class"


def test_class_name_to_snake_case_with_acronym():
assert class_name_to_snake_case("XMLParser") == "xml_parser"


def test_class_name_to_snake_case_with_consecutive_uppercase():
assert class_name_to_snake_case("getHTTPResponse") == "get_http_response"


def test_class_name_to_snake_case_empty_string():
assert class_name_to_snake_case("") == ""


def test_class_name_to_snake_case_already_lowercase():
assert class_name_to_snake_case("already") == "already"


def test_class_name_to_snake_case_suffix_equals_class_name():
assert class_name_to_snake_case("Scorer", suffix="Scorer") == ""


def test_class_name_to_snake_case_with_numbers():
assert class_name_to_snake_case("Base64Converter") == "base64_converter"


# --- snake_case_to_class_name ---


def test_snake_case_to_class_name_simple():
assert snake_case_to_class_name("my_class") == "MyClass"


def test_snake_case_to_class_name_single_word():
assert snake_case_to_class_name("scorer") == "Scorer"


def test_snake_case_to_class_name_with_suffix():
assert snake_case_to_class_name("my_custom", suffix="Scenario") == "MyCustomScenario"


def test_snake_case_to_class_name_no_suffix():
assert snake_case_to_class_name("self_ask_refusal") == "SelfAskRefusal"


def test_snake_case_to_class_name_empty_string():
assert snake_case_to_class_name("") == ""


def test_snake_case_to_class_name_empty_string_with_suffix():
assert snake_case_to_class_name("", suffix="Scorer") == "Scorer"


def test_snake_case_to_class_name_single_char_parts():
assert snake_case_to_class_name("a_b_c") == "ABC"


# --- round-trip tests ---


@pytest.mark.parametrize(
"class_name",
["MyClass", "SelfAskRefusal", "Base"],
)
def test_round_trip_snake_to_class(class_name):
snake = class_name_to_snake_case(class_name)
result = snake_case_to_class_name(snake)
assert result == class_name
125 changes: 125 additions & 0 deletions tests/unit/identifiers/test_identifier_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest

from pyrit.identifiers.identifier_filters import IdentifierFilter, IdentifierType

# --- IdentifierType enum ---


def test_identifier_type_values():
assert IdentifierType.ATTACK.value == "attack"
assert IdentifierType.TARGET.value == "target"
assert IdentifierType.SCORER.value == "scorer"
assert IdentifierType.CONVERTER.value == "converter"


def test_identifier_type_member_count():
assert len(IdentifierType) == 4


# --- IdentifierFilter creation ---


def test_identifier_filter_defaults():
f = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="openai")
assert f.identifier_type == IdentifierType.TARGET
assert f.property_path == "$.name"
assert f.value == "openai"
assert f.array_element_path is None
assert f.partial_match is False
assert f.case_sensitive is False


def test_identifier_filter_with_partial_match():
f = IdentifierFilter(
identifier_type=IdentifierType.SCORER,
property_path="$.class_name",
value="Refusal",
partial_match=True,
)
assert f.partial_match is True


def test_identifier_filter_with_case_sensitive():
f = IdentifierFilter(
identifier_type=IdentifierType.CONVERTER,
property_path="$.class_name",
value="Base64",
case_sensitive=True,
)
assert f.case_sensitive is True


def test_identifier_filter_with_array_element_path():
f = IdentifierFilter(
identifier_type=IdentifierType.ATTACK,
property_path="$.converters",
value="Base64Converter",
array_element_path="$.class_name",
)
assert f.array_element_path == "$.class_name"


# --- IdentifierFilter validation ---


def test_identifier_filter_raises_array_element_path_with_partial_match():
with pytest.raises(ValueError, match="Cannot use array_element_path with partial_match"):
IdentifierFilter(
identifier_type=IdentifierType.TARGET,
property_path="$.items",
value="test",
array_element_path="$.name",
partial_match=True,
)


def test_identifier_filter_raises_array_element_path_with_case_sensitive():
with pytest.raises(ValueError, match="Cannot use array_element_path with partial_match or case_sensitive"):
IdentifierFilter(
identifier_type=IdentifierType.TARGET,
property_path="$.items",
value="test",
array_element_path="$.name",
case_sensitive=True,
)


def test_identifier_filter_raises_partial_match_with_case_sensitive():
with pytest.raises(ValueError, match="case_sensitive is not reliably supported with partial_match"):
IdentifierFilter(
identifier_type=IdentifierType.TARGET,
property_path="$.name",
value="test",
partial_match=True,
case_sensitive=True,
)


# --- Frozen dataclass behavior ---


def test_identifier_filter_is_frozen():
f = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="x")
with pytest.raises(AttributeError):
f.value = "y"


def test_identifier_filter_equality():
f1 = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="x")
f2 = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="x")
assert f1 == f2


def test_identifier_filter_inequality():
f1 = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="x")
f2 = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="y")
assert f1 != f2


def test_identifier_filter_hashable():
f = IdentifierFilter(identifier_type=IdentifierType.TARGET, property_path="$.name", value="x")
s = {f}
assert f in s
Loading