diff --git a/clarifai/runners/models/vllm_openai_class.py b/clarifai/runners/models/vllm_openai_class.py
new file mode 100644
index 00000000..b440e7c1
--- /dev/null
+++ b/clarifai/runners/models/vllm_openai_class.py
@@ -0,0 +1,158 @@
+import threading
+from typing import Iterator
+
+import httpx
+from clarifai_protocol import get_item_id, register_item_abort_callback
+
+from clarifai.runners.models.openai_class import OpenAIModelClass
+
+
+class VLLMCancellationHandler:
+    # Important: closing the httpx response kills the TCP connection;
+    # vLLM detects is_disconnected(), triggers engine.abort() and frees KV cache.
+    def __init__(self):
+        self._cancel_events = {}
+        self._responses = {}
+        self._early_aborts = set()
+        self._lock = threading.Lock()
+        register_item_abort_callback(self._handle_abort)
+
+    def _handle_abort(self, item_id: str) -> None:
+        with self._lock:
+            event = self._cancel_events.get(item_id)
+            response = self._responses.get(item_id)
+            if event:
+                event.set()
+                if response:
+                    try:
+                        response.close()
+                    except Exception:
+                        pass
+            else:
+                self._early_aborts.add(item_id)
+
+    def register_request(self, item_id: str, response=None) -> threading.Event:
+        cancel_event = threading.Event()
+        with self._lock:
+            self._cancel_events[item_id] = cancel_event
+            if response is not None:
+                self._responses[item_id] = response
+            if item_id in self._early_aborts:
+                cancel_event.set()
+                self._early_aborts.discard(item_id)
+                if response is not None:
+                    try:
+                        response.close()
+                    except Exception:
+                        pass
+        return cancel_event
+
+    def unregister_request(self, item_id: str) -> None:
+        with self._lock:
+            self._cancel_events.pop(item_id, None)
+            self._responses.pop(item_id, None)
+            self._early_aborts.discard(item_id)
+
+
+class VLLMOpenAIModelClass(OpenAIModelClass):
+    """vLLM-backed OpenAI model with /health probes and cancellation support.
+
+    Subclasses must set client, model, server and cancellation_handler in load_model(), for example:
+
+        def load_model(self):
+            self.server = vllm_openai_server(checkpoints, **server_args)
+            self.client = OpenAI(base_url=f"http://{self.server.host}:{self.server.port}/v1", api_key="x")
+            self.model = self.client.models.list().data[0].id
+            self.cancellation_handler = VLLMCancellationHandler()
+
+    For cancellation in generate() or custom streaming methods, follow this pattern:
+
+        def generate(self, prompt, ...) -> Iterator[str]:
+            item_id = None
+            cancel_event = None
+            try:
+                item_id = get_item_id()
+            except Exception:
+                pass
+            try:
+                response = self.client.chat.completions.create(..., stream=True)
+                if item_id:
+                    cancel_event = self.cancellation_handler.register_request(item_id, response=response.response)
+                for chunk in response:
+                    if cancel_event and cancel_event.is_set():
+                        return
+                    yield ...
+            except httpx.ReadError:
+                pass
+            finally:
+                if item_id:
+                    self.cancellation_handler.unregister_request(item_id)
+    """
+
+    server = None
+    cancellation_handler = None
+
+    def handle_liveness_probe(self) -> bool:
+        if self.server is None:
+            return super().handle_liveness_probe()
+        # /health is a non-blocking, fast endpoint dedicated to health checks
+        try:
+            resp = httpx.get(f"http://{self.server.host}:{self.server.port}/health", timeout=5.0)
+            return resp.status_code == 200
+        except Exception:
+            return False
+
+    def handle_readiness_probe(self) -> bool:
+        if self.server is None:
+            return super().handle_readiness_probe()
+        # /health is a non-blocking, fast endpoint dedicated to health checks
+        try:
+            resp = httpx.get(f"http://{self.server.host}:{self.server.port}/health", timeout=10.0)
+            return resp.status_code == 200
+        except Exception:
+            return False
+
+    @OpenAIModelClass.method
+    def openai_stream_transport(self, msg: str) -> Iterator[str]:
+        from pydantic_core import from_json
+
+        item_id = None
+        try:
+            item_id = get_item_id()
+        except Exception:
+            pass
+        cancel_event = None
+        try:
+            request_data = from_json(msg)
+            request_data = self._update_old_fields(request_data)
+            endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT)
+            if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]:
+                raise ValueError(
+                    f"Only {self.ENDPOINT_CHAT_COMPLETIONS} and {self.ENDPOINT_RESPONSES} endpoints are supported for streaming."
+                )
+
+            if endpoint == self.ENDPOINT_RESPONSES:
+                # /responses endpoint: direct call (no retry), same Stream[T] interface
+                response_args = {**request_data}
+                response_args.update({"model": self.model})
+                response = self.client.responses.create(**response_args)
+            else:
+                # /chat/completions endpoint
+                completion_args = self._create_completion_args(request_data)
+                response = self.client.chat.completions.create(**completion_args)
+
+            if item_id and self.cancellation_handler:
+                cancel_event = self.cancellation_handler.register_request(
+                    item_id, response=response.response
+                )
+
+            for chunk in response:
+                if cancel_event and cancel_event.is_set():
+                    return
+                self._set_usage(chunk)
+                yield chunk.model_dump_json()
+        except httpx.ReadError:
+            pass
+        finally:
+            if item_id and self.cancellation_handler:
+                self.cancellation_handler.unregister_request(item_id)
diff --git a/tests/runners/test_vllm_openai_class.py b/tests/runners/test_vllm_openai_class.py
new file mode 100644
index 00000000..58fdac45
--- /dev/null
+++ b/tests/runners/test_vllm_openai_class.py
@@ -0,0 +1,230 @@
+"""Unit tests for VLLMOpenAIModelClass and VLLMCancellationHandler."""
+
+import json
+import threading
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from clarifai.runners.models.dummy_openai_model import MockOpenAIClient
+from clarifai.runners.models.vllm_openai_class import VLLMCancellationHandler, VLLMOpenAIModelClass
+
+
+# ---------------------------------------------------------------------------
+# Minimal concrete subclass - no real vLLM server needed
+# ---------------------------------------------------------------------------
+class DummyVLLMModel(VLLMOpenAIModelClass):
+    client = MockOpenAIClient()
+    model = "dummy-model"
+
+
+# ---------------------------------------------------------------------------
+# VLLMCancellationHandler
+# ---------------------------------------------------------------------------
+class TestVLLMCancellationHandler:
+    def _make_handler(self):
patch("clarifai.runners.models.vllm_openai_class.register_item_abort_callback"): + return VLLMCancellationHandler() + + def test_register_request_returns_unset_event(self): + handler = self._make_handler() + event = handler.register_request("item-1") + assert isinstance(event, threading.Event) + assert not event.is_set() + + def test_handle_abort_sets_event_for_registered_item(self): + handler = self._make_handler() + event = handler.register_request("item-1") + handler._handle_abort("item-1") + assert event.is_set() + + def test_handle_abort_closes_response(self): + handler = self._make_handler() + mock_response = MagicMock() + handler.register_request("item-1", response=mock_response) + handler._handle_abort("item-1") + mock_response.close.assert_called_once() + + def test_early_abort_sets_event_on_late_register(self): + """Abort arrives before register_request — event is immediately set on registration.""" + handler = self._make_handler() + handler._handle_abort("item-early") + event = handler.register_request("item-early") + assert event.is_set() + + def test_handle_abort_unknown_item_recorded_as_early_abort(self): + handler = self._make_handler() + handler._handle_abort("unknown-item") + assert "unknown-item" in handler._early_aborts + + def test_unregister_removes_all_state(self): + handler = self._make_handler() + mock_response = MagicMock() + handler.register_request("item-1", response=mock_response) + handler.unregister_request("item-1") + assert "item-1" not in handler._cancel_events + assert "item-1" not in handler._responses + assert "item-1" not in handler._early_aborts + + +# --------------------------------------------------------------------------- +# VLLMOpenAIModelClass — health probes +# --------------------------------------------------------------------------- +class TestVLLMOpenAIModelClassProbes: + def test_liveness_probe_no_server_delegates_to_super(self): + model = DummyVLLMModel() + # server is None → falls back to OpenAIModelClass.handle_liveness_probe() which returns True + assert model.handle_liveness_probe() is True + + def test_readiness_probe_no_server_delegates_to_super(self): + model = DummyVLLMModel() + assert model.handle_readiness_probe() is True + + def test_liveness_probe_returns_true_on_http_200(self): + model = DummyVLLMModel() + model.server = MagicMock(host="localhost", port=8000) + mock_resp = MagicMock(status_code=200) + with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp): + assert model.handle_liveness_probe() is True + + def test_liveness_probe_returns_false_on_non_200(self): + model = DummyVLLMModel() + model.server = MagicMock(host="localhost", port=8000) + mock_resp = MagicMock(status_code=503) + with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp): + assert model.handle_liveness_probe() is False + + def test_liveness_probe_returns_false_on_exception(self): + model = DummyVLLMModel() + model.server = MagicMock(host="localhost", port=8000) + with patch( + "clarifai.runners.models.vllm_openai_class.httpx.get", side_effect=Exception("timeout") + ): + assert model.handle_liveness_probe() is False + + def test_readiness_probe_returns_true_on_http_200(self): + model = DummyVLLMModel() + model.server = MagicMock(host="localhost", port=8000) + mock_resp = MagicMock(status_code=200) + with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp): + assert model.handle_readiness_probe() is True + + def 
+    def test_readiness_probe_returns_false_on_exception(self):
+        model = DummyVLLMModel()
+        model.server = MagicMock(host="localhost", port=8000)
+        with patch(
+            "clarifai.runners.models.vllm_openai_class.httpx.get",
+            side_effect=Exception("conn refused"),
+        ):
+            assert model.handle_readiness_probe() is False
+
+
+# ---------------------------------------------------------------------------
+# VLLMOpenAIModelClass - openai_stream_transport with cancellation
+# ---------------------------------------------------------------------------
+def _make_mock_stream(*chunk_texts):
+    """Return a mock streaming response whose chunks have the expected interface.
+
+    _set_usage asserts that a chunk doesn't have both .usage and .response.usage set,
+    so we explicitly set both to None on each chunk.
+    """
+    chunks = []
+    for text in chunk_texts:
+        chunk = MagicMock()
+        chunk.usage = None
+        chunk.response = None
+        chunk.model_dump_json.return_value = json.dumps(
+            {"choices": [{"delta": {"content": text}}], "usage": None}
+        )
+        chunks.append(chunk)
+    mock_stream = MagicMock()
+    mock_stream.__iter__ = MagicMock(return_value=iter(chunks))
+    mock_stream.response = MagicMock()
+    return mock_stream
+
+
+class TestVLLMStreamTransportCancellation:
+    def _model_with_mock_client_and_handler(self, cancel_event):
+        model = DummyVLLMModel()
+        mock_handler = MagicMock()
+        mock_handler.register_request.return_value = cancel_event
+        model.cancellation_handler = mock_handler
+        mock_stream = _make_mock_stream("Hello", " world")
+        model.client = MagicMock()
+        model.client.chat.completions.create.return_value = mock_stream
+        return model, mock_handler
+
+    def test_cancel_before_iteration_yields_no_chunks(self):
+        cancel_event = threading.Event()
+        cancel_event.set()  # already cancelled
+        model, mock_handler = self._model_with_mock_client_and_handler(cancel_event)
+
+        request = json.dumps(
+            {
+                "model": "dummy-model",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "stream": True,
+            }
+        )
+        with patch(
+            "clarifai.runners.models.vllm_openai_class.get_item_id", return_value="item-abc"
+        ):
+            chunks = list(model.openai_stream_transport(request))
+
+        assert chunks == []
+        mock_handler.unregister_request.assert_called_once_with("item-abc")
+
+    def test_no_cancel_yields_all_chunks(self):
+        cancel_event = threading.Event()  # never set
+        model, mock_handler = self._model_with_mock_client_and_handler(cancel_event)
+
+        request = json.dumps(
+            {
+                "model": "dummy-model",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "stream": True,
+            }
+        )
+        with patch(
+            "clarifai.runners.models.vllm_openai_class.get_item_id", return_value="item-xyz"
+        ):
+            chunks = list(model.openai_stream_transport(request))
+
+        assert len(chunks) == 2
+        mock_handler.unregister_request.assert_called_once_with("item-xyz")
+
+    def test_unregister_called_even_when_get_item_id_fails(self):
+        """If get_item_id raises, no cancellation handler is used but stream still works."""
+        model = DummyVLLMModel()
+        mock_stream = _make_mock_stream("chunk1")
+        model.client = MagicMock()
+        model.client.chat.completions.create.return_value = mock_stream
+
+        request = json.dumps(
+            {
+                "model": "dummy-model",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "stream": True,
+            }
+        )
+        with patch(
+            "clarifai.runners.models.vllm_openai_class.get_item_id",
+            side_effect=Exception("no context"),
+        ):
+            chunks = list(model.openai_stream_transport(request))
+
+        assert len(chunks) == 1
+
+    def test_invalid_endpoint_raises_value_error(self):
+        model = DummyVLLMModel()
+        request = json.dumps(
+            {
+                "model": "dummy-model",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "stream": True,
+                "openai_endpoint": "/unsupported",
+            }
+        )
+        with patch("clarifai.runners.models.vllm_openai_class.get_item_id", side_effect=Exception):
+            with pytest.raises(ValueError, match="Only"):
+                list(model.openai_stream_transport(request))