diff --git a/clarifai/runners/models/model_class.py b/clarifai/runners/models/model_class.py index 82c0c524..55712148 100644 --- a/clarifai/runners/models/model_class.py +++ b/clarifai/runners/models/model_class.py @@ -25,6 +25,7 @@ from clarifai.utils.logging import logger _METHOD_INFO_ATTR = '_cf_method_info' +_MODEL_METHODS_REGISTRY_ATTR = '_cf_model_methods_registry' _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1") @@ -120,11 +121,12 @@ def predict_wrapper( # first we look for a PostModelOutputs method that is implemented as protos and use that # if it exists. # if not we default to 'predict'. + method_infos = self._get_method_infos() method_name = None if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata: method_name = request.inputs[0].data.metadata['_method_name'] - if method_name is None and FALLBACK_METHOD_PROTO in self._get_method_infos(): - _info = self._get_method_infos(FALLBACK_METHOD_PROTO) + if method_name is None and FALLBACK_METHOD_PROTO in method_infos: + _info = method_infos[FALLBACK_METHOD_PROTO] if _info.proto_method: method_name = FALLBACK_METHOD_PROTO if method_name is None: @@ -133,10 +135,10 @@ def predict_wrapper( method_name == '_GET_SIGNATURES' ): # special case to fetch signatures, TODO add endpoint for this return self._handle_get_signatures_request() - if method_name not in self._get_method_infos(): + if method_name not in method_infos: raise ValueError(f"Method {method_name} not found in model class") method = getattr(self, method_name) - method_info = self._get_method_infos(method_name) + method_info = method_infos[method_name] signature = method_info.signature proto_method = method_info.proto_method @@ -171,7 +173,7 @@ def predict_wrapper( ) return out_proto - python_param_types = method_info.python_param_types + cast_types = method_info.cast_types for input in request.inputs: # check if input is in old format is_convert = DataConverter.is_old_format(input.data) @@ -184,7 +186,7 @@ def predict_wrapper( # convert inputs to python types inputs = self._convert_input_protos_to_python( - request.inputs, signature.input_fields, python_param_types + request.inputs, signature.input_fields, cast_types ) if len(inputs) == 1: inputs = inputs[0] @@ -223,13 +225,14 @@ def generate_wrapper( ) -> Iterator[service_pb2.MultiOutputResponse]: try: assert len(request.inputs) == 1, "Generate requires exactly one input" + method_infos = self._get_method_infos() method_name = 'generate' if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata: method_name = request.inputs[0].data.metadata['_method_name'] method = getattr(self, method_name) - method_info = self._get_method_infos(method_name) + method_info = method_infos[method_name] signature = method_info.signature - python_param_types = method_info.python_param_types + cast_types = method_info.cast_types for input in request.inputs: # check if input is in old format is_convert = DataConverter.is_old_format(input.data) @@ -240,7 +243,7 @@ def generate_wrapper( ) input.data.CopyFrom(new_data) inputs = self._convert_input_protos_to_python( - request.inputs, signature.input_fields, python_param_types + request.inputs, signature.input_fields, cast_types ) if len(inputs) == 1: inputs = inputs[0] @@ -320,13 +323,14 @@ def stream_wrapper( request = next(request_iterator) # get first request to determine method assert len(request.inputs) == 1, "Streaming requires exactly one input" + method_infos = self._get_method_infos() method_name = 'stream' if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata: method_name = request.inputs[0].data.metadata['_method_name'] method = getattr(self, method_name) - method_info = self._get_method_infos(method_name) + method_info = method_infos[method_name] signature = method_info.signature - python_param_types = method_info.python_param_types + cast_types = method_info.cast_types # find the streaming vars in the signature stream_sig = get_stream_from_signature(signature.input_fields) @@ -345,7 +349,7 @@ def stream_wrapper( input.data.CopyFrom(new_data) # convert all inputs for the first request, including the first stream value inputs = self._convert_input_protos_to_python( - request.inputs, signature.input_fields, python_param_types + request.inputs, signature.input_fields, cast_types ) kwargs = inputs[0] @@ -358,7 +362,7 @@ def InputStream(): # subsequent streaming items contain only the streaming input for request in request_iterator: item = self._convert_input_protos_to_python( - request.inputs, [stream_sig], python_param_types + request.inputs, [stream_sig], cast_types ) item = item[0][stream_argname] yield item @@ -392,26 +396,15 @@ def _convert_input_protos_to_python( self, inputs: List[resources_pb2.Input], variables_signature: List[resources_pb2.ModelTypeField], - python_param_types, + cast_types: Dict[str, Any], ) -> List[Dict[str, Any]]: result = [] for input in inputs: kwargs = deserialize(input.data, variables_signature) # dynamic cast to annotated types for k, v in kwargs.items(): - if k not in python_param_types: - continue - - if hasattr(python_param_types[k], "__args__") and ( - getattr(python_param_types[k], "__origin__", None) - in [abc.Iterator, abc.Generator, abc.Iterable] - ): - # get the type of the items in the stream - stream_type = python_param_types[k].__args__[0] - - kwargs[k] = data_types.cast(v, stream_type) - else: - kwargs[k] = data_types.cast(v, python_param_types[k]) + if k in cast_types: + kwargs[k] = data_types.cast(v, cast_types[k]) result.append(kwargs) return result @@ -476,11 +469,9 @@ def _register_model_methods(cls): @classmethod def _get_method_infos(cls, func_name=None): - # FIXME: this is a re-use of the _METHOD_INFO_ATTR attribute to store the method info - # for all methods on the class. Should use a different attribute name to avoid confusion. - if not hasattr(cls, _METHOD_INFO_ATTR): - setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods()) - method_infos = getattr(cls, _METHOD_INFO_ATTR) + if not hasattr(cls, _MODEL_METHODS_REGISTRY_ATTR): + setattr(cls, _MODEL_METHODS_REGISTRY_ATTR, cls._register_model_methods()) + method_infos = getattr(cls, _MODEL_METHODS_REGISTRY_ATTR) if func_name: return method_infos[func_name] return method_infos @@ -506,3 +497,12 @@ def __init__(self, method, proto_method=False): if p.annotation != inspect.Parameter.empty } self.python_param_types.pop('self', None) + + self.cast_types = {} + for k, v in self.python_param_types.items(): + if hasattr(v, "__args__") and ( + getattr(v, "__origin__", None) in [abc.Iterator, abc.Generator, abc.Iterable] + ): + self.cast_types[k] = v.__args__[0] + else: + self.cast_types[k] = v diff --git a/clarifai/runners/models/model_runner.py b/clarifai/runners/models/model_runner.py index e49841e3..34353b42 100644 --- a/clarifai/runners/models/model_runner.py +++ b/clarifai/runners/models/model_runner.py @@ -1,5 +1,6 @@ +import itertools import time -from typing import Iterator, Optional, Union +from typing import Iterable, Iterator, Optional, Union from clarifai_grpc.grpc.api import service_pb2 from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2 @@ -45,6 +46,7 @@ def __init__( pat, token, num_parallel_polls, + health_check_port=health_check_port, **kwargs, ) self.model = model @@ -183,33 +185,31 @@ def runner_item_predict( duration_ms = (time.time() - start_time) * 1000 logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}") return service_pb2.RunnerItemOutput(multi_output_response=resp) - successes = [] + + num_success = 0 + num_total = len(resp.outputs) for output in resp.outputs: if not output.HasField('status') or not output.status.code: raise Exception( "Output must have a status code, please check the model implementation." ) - successes.append(output.status.code == status_code_pb2.SUCCESS) - if all(successes): - status = status_pb2.Status( - code=status_code_pb2.SUCCESS, - description="Success", - ) + if output.status.code == status_code_pb2.SUCCESS: + num_success += 1 + + resp.status.Clear() + if num_success == num_total: + resp.status.code = status_code_pb2.SUCCESS + resp.status.description = "Success" status_str = STATUS_OK - elif any(successes): - status = status_pb2.Status( - code=status_code_pb2.MIXED_STATUS, - description="Mixed Status", - ) + elif num_success > 0: + resp.status.code = status_code_pb2.MIXED_STATUS + resp.status.description = "Mixed Status" status_str = STATUS_MIXED else: - status = status_pb2.Status( - code=status_code_pb2.FAILURE, - description="Failed", - ) + resp.status.code = status_code_pb2.FAILURE + resp.status.description = "Failed" status_str = STATUS_FAIL - resp.status.CopyFrom(status) if logging: duration_ms = (time.time() - start_time) * 1000 logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}") @@ -270,35 +270,35 @@ def runner_item_generate( yield service_pb2.RunnerItemOutput(multi_output_response=resp) continue # Single output with non-SUCCESS status - resp.status.CopyFrom( - status_pb2.Status(code=status_code_pb2.FAILURE, description="Failed") - ) + resp.status.Clear() + resp.status.code = status_code_pb2.FAILURE + resp.status.description = "Failed" status_str = STATUS_FAIL yield service_pb2.RunnerItemOutput(multi_output_response=resp) continue - # Multi-output path (batch generate) - successes = [] + # Multi-output path (batch generate) - use counters to avoid list allocation + num_success = 0 for output in outputs: if not output.HasField('status') or not output.status.code: raise Exception( "Output must have a status code, please check the model implementation." ) - successes.append(output.status.code == status_code_pb2.SUCCESS) - if all(successes): + if output.status.code == status_code_pb2.SUCCESS: + num_success += 1 + + if num_success == num_outputs: resp.status.CopyFrom(_success_status) status_str = STATUS_OK - elif any(successes): - resp.status.CopyFrom( - status_pb2.Status( - code=status_code_pb2.MIXED_STATUS, description="Mixed Status" - ) - ) + elif num_success > 0: + resp.status.Clear() + resp.status.code = status_code_pb2.MIXED_STATUS + resp.status.description = "Mixed Status" status_str = STATUS_MIXED else: - resp.status.CopyFrom( - status_pb2.Status(code=status_code_pb2.FAILURE, description="Failed") - ) + resp.status.Clear() + resp.status.code = status_code_pb2.FAILURE + resp.status.description = "Failed" status_str = STATUS_FAIL yield service_pb2.RunnerItemOutput(multi_output_response=resp) @@ -307,7 +307,7 @@ def runner_item_generate( logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}") def runner_item_stream( - self, runner_item_iterator: Iterator[service_pb2.RunnerItem] + self, runner_item_iterator: Iterable[service_pb2.RunnerItem] ) -> Iterator[service_pb2.RunnerItemOutput]: # Call the generate() method the underlying model implements. start_time = time.time() @@ -317,14 +317,24 @@ def runner_item_stream( # Get the first request to establish secrets context first_request = None - runner_items = list(runner_item_iterator) # Convert to list to avoid consuming iterator - if runner_items: - first_request = runner_items[0].post_model_outputs_request + try: + runner_item_iterator = iter(runner_item_iterator) + first_runner_item = next(runner_item_iterator) + if not first_runner_item.HasField('post_model_outputs_request'): + raise Exception("Unexpected work item type: {}".format(first_runner_item)) + first_request = first_runner_item.post_model_outputs_request + # Reconstruct the iterator using itertools.chain to avoid consuming the whole stream into memory + runner_items = itertools.chain((first_runner_item,), runner_item_iterator) + except StopIteration: + # No items in the stream: short-circuit and yield nothing. + duration_ms = (time.time() - start_time) * 1000 + logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}") + return # Use req_secrets_context based on the first request (secrets should be consistent across stream) with req_secrets_context(first_request): for resp in self.model.stream_wrapper( - pmo_iterator(iter(runner_items), auth_helper=self._auth_helper) + pmo_iterator(runner_items, auth_helper=self._auth_helper) ): # if we have any non-successful code already it's an error we can return. if ( @@ -338,32 +348,30 @@ def runner_item_stream( ) yield service_pb2.RunnerItemOutput(multi_output_response=resp) continue - successes = [] + + num_success = 0 + num_total = len(resp.outputs) for output in resp.outputs: if not output.HasField('status') or not output.status.code: raise Exception( "Output must have a status code, please check the model implementation." ) - successes.append(output.status.code == status_code_pb2.SUCCESS) - if all(successes): - status = status_pb2.Status( - code=status_code_pb2.SUCCESS, - description="Success", - ) + if output.status.code == status_code_pb2.SUCCESS: + num_success += 1 + + resp.status.Clear() + if num_success == num_total: + resp.status.code = status_code_pb2.SUCCESS + resp.status.description = "Success" status_str = STATUS_OK - elif any(successes): - status = status_pb2.Status( - code=status_code_pb2.MIXED_STATUS, - description="Mixed Status", - ) + elif num_success > 0: + resp.status.code = status_code_pb2.MIXED_STATUS + resp.status.description = "Mixed Status" status_str = STATUS_MIXED else: - status = status_pb2.Status( - code=status_code_pb2.RUNNER_PROCESSING_FAILED, - description="Runner Processing Failed", - ) + resp.status.code = status_code_pb2.RUNNER_PROCESSING_FAILED + resp.status.description = "Runner Processing Failed" status_str = STATUS_FAIL - resp.status.CopyFrom(status) yield service_pb2.RunnerItemOutput(multi_output_response=resp) diff --git a/tests/runners/test_model_runner_unit.py b/tests/runners/test_model_runner_unit.py new file mode 100644 index 00000000..9cb8b4fe --- /dev/null +++ b/tests/runners/test_model_runner_unit.py @@ -0,0 +1,196 @@ +"""Unit tests for ModelRunner status aggregation and iterator handling.""" + +import types +from unittest.mock import MagicMock + +from clarifai_grpc.grpc.api import service_pb2 +from clarifai_grpc.grpc.api.status import status_code_pb2 + +from clarifai.runners.models.model_runner import ModelRunner + + +def _make_runner(mock_model): + """Create a minimal ModelRunner-like object bound to mock_model, bypassing __init__.""" + + class FakeRunner: + pass + + runner = FakeRunner() + runner.model = mock_model + runner._auth_helper = None + runner.runner_item_predict = types.MethodType(ModelRunner.runner_item_predict, runner) + runner.runner_item_generate = types.MethodType(ModelRunner.runner_item_generate, runner) + runner.runner_item_stream = types.MethodType(ModelRunner.runner_item_stream, runner) + return runner + + +def _make_runner_item(): + """Return a RunnerItem with an empty PostModelOutputsRequest.""" + return service_pb2.RunnerItem(post_model_outputs_request=service_pb2.PostModelOutputsRequest()) + + +def _make_resp(*output_codes): + """Build a MultiOutputResponse whose outputs have the given status codes.""" + resp = service_pb2.MultiOutputResponse() + for code in output_codes: + output = resp.outputs.add() + output.status.code = code + return resp + + +class TestRunnerItemPredictStatus: + """Tests for runner_item_predict status aggregation.""" + + def test_empty_outputs_treated_as_success(self): + """Empty resp.outputs should map to SUCCESS (consistent with all([]) == True).""" + mock_model = MagicMock() + mock_model.predict_wrapper.return_value = service_pb2.MultiOutputResponse() + runner = _make_runner(mock_model) + + result = runner.runner_item_predict(_make_runner_item()) + assert result.multi_output_response.status.code == status_code_pb2.SUCCESS + + def test_all_success_outputs(self): + mock_model = MagicMock() + mock_model.predict_wrapper.return_value = _make_resp( + status_code_pb2.SUCCESS, status_code_pb2.SUCCESS + ) + runner = _make_runner(mock_model) + + result = runner.runner_item_predict(_make_runner_item()) + assert result.multi_output_response.status.code == status_code_pb2.SUCCESS + assert result.multi_output_response.status.description == "Success" + + def test_mixed_outputs(self): + mock_model = MagicMock() + mock_model.predict_wrapper.return_value = _make_resp( + status_code_pb2.SUCCESS, status_code_pb2.FAILURE + ) + runner = _make_runner(mock_model) + + result = runner.runner_item_predict(_make_runner_item()) + assert result.multi_output_response.status.code == status_code_pb2.MIXED_STATUS + + def test_all_failed_outputs(self): + mock_model = MagicMock() + mock_model.predict_wrapper.return_value = _make_resp(status_code_pb2.FAILURE) + runner = _make_runner(mock_model) + + result = runner.runner_item_predict(_make_runner_item()) + assert result.multi_output_response.status.code == status_code_pb2.FAILURE + + def test_stale_status_fields_are_cleared(self): + """resp.status.Clear() must wipe stale fields (details, internal_details, etc.).""" + resp = _make_resp(status_code_pb2.SUCCESS) + resp.status.code = status_code_pb2.SUCCESS + resp.status.description = "old" + resp.status.details = "stale details" + resp.status.internal_details = "stale internal" + + mock_model = MagicMock() + mock_model.predict_wrapper.return_value = resp + runner = _make_runner(mock_model) + + result = runner.runner_item_predict(_make_runner_item()) + out = result.multi_output_response.status + assert out.code == status_code_pb2.SUCCESS + assert out.description == "Success" + assert out.details == "" + assert out.internal_details == "" + + +class TestRunnerItemGenerateStatus: + """Tests for runner_item_generate status aggregation.""" + + def test_empty_outputs_treated_as_success(self): + mock_model = MagicMock() + mock_model.generate_wrapper.return_value = iter([service_pb2.MultiOutputResponse()]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_generate(_make_runner_item())) + assert len(results) == 1 + assert results[0].multi_output_response.status.code == status_code_pb2.SUCCESS + + def test_stale_status_fields_are_cleared(self): + resp = _make_resp(status_code_pb2.SUCCESS) + resp.status.details = "stale" + resp.status.internal_details = "stale internal" + + mock_model = MagicMock() + mock_model.generate_wrapper.return_value = iter([resp]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_generate(_make_runner_item())) + out = results[0].multi_output_response.status + assert out.details == "" + assert out.internal_details == "" + + +class TestRunnerItemStreamStatus: + """Tests for runner_item_stream status aggregation and iterator handling.""" + + def test_empty_outputs_treated_as_success(self): + """Empty resp.outputs in stream should map to SUCCESS (aligned with predict/generate).""" + mock_model = MagicMock() + mock_model.stream_wrapper.return_value = iter([service_pb2.MultiOutputResponse()]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_stream(iter([_make_runner_item()]))) + assert len(results) == 1 + assert results[0].multi_output_response.status.code == status_code_pb2.SUCCESS + + def test_accepts_iterable_not_just_iterator(self): + """runner_item_stream must accept any iterable (e.g. list), not only an iterator.""" + mock_model = MagicMock() + resp = _make_resp(status_code_pb2.SUCCESS) + mock_model.stream_wrapper.return_value = iter([resp]) + runner = _make_runner(mock_model) + + # Pass a plain list (iterable) instead of an iterator + results = list(runner.runner_item_stream([_make_runner_item()])) + assert len(results) == 1 + assert results[0].multi_output_response.status.code == status_code_pb2.SUCCESS + + def test_empty_stream_yields_nothing(self): + """An empty input stream should produce no output items.""" + mock_model = MagicMock() + mock_model.stream_wrapper.return_value = iter([]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_stream(iter([]))) + assert results == [] + + def test_stale_status_fields_are_cleared(self): + resp = _make_resp(status_code_pb2.SUCCESS) + resp.status.details = "stale" + resp.status.internal_details = "stale internal" + + mock_model = MagicMock() + mock_model.stream_wrapper.return_value = iter([resp]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_stream([_make_runner_item()])) + out = results[0].multi_output_response.status + assert out.details == "" + assert out.internal_details == "" + + def test_mixed_outputs(self): + mock_model = MagicMock() + resp = _make_resp(status_code_pb2.SUCCESS, status_code_pb2.FAILURE) + mock_model.stream_wrapper.return_value = iter([resp]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_stream([_make_runner_item()])) + assert results[0].multi_output_response.status.code == status_code_pb2.MIXED_STATUS + + def test_all_failed_outputs(self): + mock_model = MagicMock() + resp = _make_resp(status_code_pb2.FAILURE) + mock_model.stream_wrapper.return_value = iter([resp]) + runner = _make_runner(mock_model) + + results = list(runner.runner_item_stream([_make_runner_item()])) + assert ( + results[0].multi_output_response.status.code + == status_code_pb2.RUNNER_PROCESSING_FAILED + )