diff --git a/steps/src/toxicity_guardrail/toxicity_guardrail.py b/steps/src/toxicity_guardrail/toxicity_guardrail.py index def0616b..525866c0 100644 --- a/steps/src/toxicity_guardrail/toxicity_guardrail.py +++ b/steps/src/toxicity_guardrail/toxicity_guardrail.py @@ -14,21 +14,9 @@ # from typing import Any, Dict - +from transformers import pipeline class ToxicityGuardrailStep: - """ - A serving graph step that filters out toxic requests using a pre-trained - text classification model. - - If the toxicity score of the input text meets or exceeds the threshold, - the request is blocked with a ValueError. Safe requests are passed through - unchanged. - - The classifier label "toxic" maps directly to the toxicity score; any - other label (e.g. "non-toxic") inverts the model's confidence score. - """ - def __init__( self, context=None, @@ -37,13 +25,23 @@ def __init__( model_name: str = "unitary/toxic-bert", **kwargs, ): + """ + A serving graph step that filters out toxic requests using a pre-trained + text classification model. + + :param context: MLRun context object, injected automatically by the serving graph. + :param name: Name of this step in the serving graph. + :param threshold: Toxicity score threshold; requests whose toxicity score meets or + exceeds this value are blocked with a ValueError. Defaults to 0.5. + :param model_name: HuggingFace model identifier used for text classification. + Defaults to "unitary/toxic-bert". + :param kwargs: Additional keyword arguments; accepted but not used by this step. + """ self.threshold = threshold self.model_name = model_name self._classifier = None def post_init(self, mode="sync", **kwargs): - from transformers import pipeline - self._classifier = pipeline("text-classification", model=self.model_name) def do(self, event: Dict[str, Any]) -> Dict[str, Any]: