Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions clarifai/runners/models/model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from clarifai.utils.logging import logger

_METHOD_INFO_ATTR = '_cf_method_info'
_MODEL_METHODS_REGISTRY_ATTR = '_cf_model_methods_registry'

_RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1")

Expand Down Expand Up @@ -120,11 +121,12 @@ def predict_wrapper(
# first we look for a PostModelOutputs method that is implemented as protos and use that
# if it exists.
# if not we default to 'predict'.
method_infos = self._get_method_infos()
method_name = None
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
method_name = request.inputs[0].data.metadata['_method_name']
if method_name is None and FALLBACK_METHOD_PROTO in self._get_method_infos():
_info = self._get_method_infos(FALLBACK_METHOD_PROTO)
if method_name is None and FALLBACK_METHOD_PROTO in method_infos:
_info = method_infos[FALLBACK_METHOD_PROTO]
if _info.proto_method:
method_name = FALLBACK_METHOD_PROTO
if method_name is None:
Expand All @@ -133,10 +135,10 @@ def predict_wrapper(
method_name == '_GET_SIGNATURES'
): # special case to fetch signatures, TODO add endpoint for this
return self._handle_get_signatures_request()
if method_name not in self._get_method_infos():
if method_name not in method_infos:
raise ValueError(f"Method {method_name} not found in model class")
method = getattr(self, method_name)
method_info = self._get_method_infos(method_name)
method_info = method_infos[method_name]
signature = method_info.signature
proto_method = method_info.proto_method

Expand Down Expand Up @@ -171,7 +173,7 @@ def predict_wrapper(
)
return out_proto

python_param_types = method_info.python_param_types
cast_types = method_info.cast_types
for input in request.inputs:
# check if input is in old format
is_convert = DataConverter.is_old_format(input.data)
Expand All @@ -184,7 +186,7 @@ def predict_wrapper(

# convert inputs to python types
inputs = self._convert_input_protos_to_python(
request.inputs, signature.input_fields, python_param_types
request.inputs, signature.input_fields, cast_types
)
if len(inputs) == 1:
inputs = inputs[0]
Expand Down Expand Up @@ -223,13 +225,14 @@ def generate_wrapper(
) -> Iterator[service_pb2.MultiOutputResponse]:
try:
assert len(request.inputs) == 1, "Generate requires exactly one input"
method_infos = self._get_method_infos()
method_name = 'generate'
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
method_name = request.inputs[0].data.metadata['_method_name']
method = getattr(self, method_name)
method_info = self._get_method_infos(method_name)
method_info = method_infos[method_name]
signature = method_info.signature
python_param_types = method_info.python_param_types
cast_types = method_info.cast_types
for input in request.inputs:
# check if input is in old format
is_convert = DataConverter.is_old_format(input.data)
Expand All @@ -240,7 +243,7 @@ def generate_wrapper(
)
input.data.CopyFrom(new_data)
inputs = self._convert_input_protos_to_python(
request.inputs, signature.input_fields, python_param_types
request.inputs, signature.input_fields, cast_types
)
if len(inputs) == 1:
inputs = inputs[0]
Expand Down Expand Up @@ -320,13 +323,14 @@ def stream_wrapper(
request = next(request_iterator) # get first request to determine method
assert len(request.inputs) == 1, "Streaming requires exactly one input"

method_infos = self._get_method_infos()
method_name = 'stream'
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
method_name = request.inputs[0].data.metadata['_method_name']
method = getattr(self, method_name)
method_info = self._get_method_infos(method_name)
method_info = method_infos[method_name]
signature = method_info.signature
python_param_types = method_info.python_param_types
cast_types = method_info.cast_types

# find the streaming vars in the signature
stream_sig = get_stream_from_signature(signature.input_fields)
Expand All @@ -345,7 +349,7 @@ def stream_wrapper(
input.data.CopyFrom(new_data)
# convert all inputs for the first request, including the first stream value
inputs = self._convert_input_protos_to_python(
request.inputs, signature.input_fields, python_param_types
request.inputs, signature.input_fields, cast_types
)
kwargs = inputs[0]

Expand All @@ -358,7 +362,7 @@ def InputStream():
# subsequent streaming items contain only the streaming input
for request in request_iterator:
item = self._convert_input_protos_to_python(
request.inputs, [stream_sig], python_param_types
request.inputs, [stream_sig], cast_types
)
item = item[0][stream_argname]
yield item
Expand Down Expand Up @@ -392,26 +396,15 @@ def _convert_input_protos_to_python(
self,
inputs: List[resources_pb2.Input],
variables_signature: List[resources_pb2.ModelTypeField],
python_param_types,
cast_types: Dict[str, Any],
) -> List[Dict[str, Any]]:
result = []
for input in inputs:
kwargs = deserialize(input.data, variables_signature)
# dynamic cast to annotated types
for k, v in kwargs.items():
if k not in python_param_types:
continue

if hasattr(python_param_types[k], "__args__") and (
getattr(python_param_types[k], "__origin__", None)
in [abc.Iterator, abc.Generator, abc.Iterable]
):
# get the type of the items in the stream
stream_type = python_param_types[k].__args__[0]

kwargs[k] = data_types.cast(v, stream_type)
else:
kwargs[k] = data_types.cast(v, python_param_types[k])
if k in cast_types:
kwargs[k] = data_types.cast(v, cast_types[k])
result.append(kwargs)
return result

Expand Down Expand Up @@ -476,11 +469,9 @@ def _register_model_methods(cls):

@classmethod
def _get_method_infos(cls, func_name=None):
    """Return the cached registry of model methods for this class.

    The registry is produced once per class by ``cls._register_model_methods()``
    and stored on the class under ``_MODEL_METHODS_REGISTRY_ATTR``. That
    attribute is dedicated to the registry and is distinct from
    ``_METHOD_INFO_ATTR``, so the two uses cannot be confused.

    Args:
        func_name: Optional method name. When truthy, only the registry
            entry for that name is returned.

    Returns:
        The whole registry when ``func_name`` is falsy, otherwise the entry
        stored under ``func_name``.

    Raises:
        Whatever the registry raises for a missing key (``KeyError`` if it
        is a plain dict) when ``func_name`` is not registered.
    """
    # Check the class's own namespace instead of using hasattr(): hasattr()
    # follows the MRO, so a subclass would otherwise pick up the registry
    # cached on its parent and never register its own methods.
    if _MODEL_METHODS_REGISTRY_ATTR not in cls.__dict__:
        setattr(cls, _MODEL_METHODS_REGISTRY_ATTR, cls._register_model_methods())
    method_infos = cls.__dict__[_MODEL_METHODS_REGISTRY_ATTR]
    if func_name:
        return method_infos[func_name]
    return method_infos
Expand All @@ -506,3 +497,12 @@ def __init__(self, method, proto_method=False):
if p.annotation != inspect.Parameter.empty
}
self.python_param_types.pop('self', None)

self.cast_types = {}
for k, v in self.python_param_types.items():
if hasattr(v, "__args__") and (
getattr(v, "__origin__", None) in [abc.Iterator, abc.Generator, abc.Iterable]
):
self.cast_types[k] = v.__args__[0]
else:
self.cast_types[k] = v
120 changes: 64 additions & 56 deletions clarifai/runners/models/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import time
from typing import Iterator, Optional, Union
from typing import Iterable, Iterator, Optional, Union

from clarifai_grpc.grpc.api import service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
pat,
token,
num_parallel_polls,
health_check_port=health_check_port,
**kwargs,
)
self.model = model
Expand Down Expand Up @@ -183,33 +185,31 @@ def runner_item_predict(
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
return service_pb2.RunnerItemOutput(multi_output_response=resp)
successes = []

num_success = 0
num_total = len(resp.outputs)
for output in resp.outputs:
if not output.HasField('status') or not output.status.code:
raise Exception(
"Output must have a status code, please check the model implementation."
)
successes.append(output.status.code == status_code_pb2.SUCCESS)
if all(successes):
status = status_pb2.Status(
code=status_code_pb2.SUCCESS,
description="Success",
)
if output.status.code == status_code_pb2.SUCCESS:
num_success += 1

resp.status.Clear()
if num_success == num_total:
resp.status.code = status_code_pb2.SUCCESS
resp.status.description = "Success"
status_str = STATUS_OK
elif any(successes):
status = status_pb2.Status(
code=status_code_pb2.MIXED_STATUS,
description="Mixed Status",
)
elif num_success > 0:
resp.status.code = status_code_pb2.MIXED_STATUS
resp.status.description = "Mixed Status"
status_str = STATUS_MIXED
else:
status = status_pb2.Status(
code=status_code_pb2.FAILURE,
description="Failed",
)
resp.status.code = status_code_pb2.FAILURE
resp.status.description = "Failed"
status_str = STATUS_FAIL

resp.status.CopyFrom(status)
if logging:
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
Expand Down Expand Up @@ -270,35 +270,35 @@ def runner_item_generate(
yield service_pb2.RunnerItemOutput(multi_output_response=resp)
continue
# Single output with non-SUCCESS status
resp.status.CopyFrom(
status_pb2.Status(code=status_code_pb2.FAILURE, description="Failed")
)
resp.status.Clear()
resp.status.code = status_code_pb2.FAILURE
resp.status.description = "Failed"
status_str = STATUS_FAIL
yield service_pb2.RunnerItemOutput(multi_output_response=resp)
continue

# Multi-output path (batch generate)
successes = []
# Multi-output path (batch generate) - use counters to avoid list allocation
num_success = 0
for output in outputs:
if not output.HasField('status') or not output.status.code:
raise Exception(
"Output must have a status code, please check the model implementation."
)
successes.append(output.status.code == status_code_pb2.SUCCESS)
if all(successes):
if output.status.code == status_code_pb2.SUCCESS:
num_success += 1

if num_success == num_outputs:
resp.status.CopyFrom(_success_status)
status_str = STATUS_OK
elif any(successes):
resp.status.CopyFrom(
status_pb2.Status(
code=status_code_pb2.MIXED_STATUS, description="Mixed Status"
)
)
elif num_success > 0:
resp.status.Clear()
resp.status.code = status_code_pb2.MIXED_STATUS
resp.status.description = "Mixed Status"
status_str = STATUS_MIXED
else:
resp.status.CopyFrom(
status_pb2.Status(code=status_code_pb2.FAILURE, description="Failed")
)
resp.status.Clear()
resp.status.code = status_code_pb2.FAILURE
resp.status.description = "Failed"
status_str = STATUS_FAIL

yield service_pb2.RunnerItemOutput(multi_output_response=resp)
Expand All @@ -307,7 +307,7 @@ def runner_item_generate(
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")

def runner_item_stream(
self, runner_item_iterator: Iterator[service_pb2.RunnerItem]
self, runner_item_iterator: Iterable[service_pb2.RunnerItem]
) -> Iterator[service_pb2.RunnerItemOutput]:
# Call the stream_wrapper() method of the underlying model.
start_time = time.time()
Expand All @@ -317,14 +317,24 @@ def runner_item_stream(

# Get the first request to establish secrets context
first_request = None
runner_items = list(runner_item_iterator) # Convert to list to avoid consuming iterator
if runner_items:
first_request = runner_items[0].post_model_outputs_request
try:
runner_item_iterator = iter(runner_item_iterator)
first_runner_item = next(runner_item_iterator)
if not first_runner_item.HasField('post_model_outputs_request'):
raise Exception("Unexpected work item type: {}".format(first_runner_item))
first_request = first_runner_item.post_model_outputs_request
# Reconstruct the iterator using itertools.chain to avoid consuming the whole stream into memory
runner_items = itertools.chain((first_runner_item,), runner_item_iterator)
except StopIteration:
# No items in the stream: short-circuit and yield nothing.
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
return

# Use req_secrets_context based on the first request (secrets should be consistent across stream)
with req_secrets_context(first_request):
for resp in self.model.stream_wrapper(
pmo_iterator(iter(runner_items), auth_helper=self._auth_helper)
pmo_iterator(runner_items, auth_helper=self._auth_helper)
):
# if we have any non-successful code already it's an error we can return.
if (
Expand All @@ -338,32 +348,30 @@ def runner_item_stream(
)
yield service_pb2.RunnerItemOutput(multi_output_response=resp)
continue
successes = []

num_success = 0
num_total = len(resp.outputs)
for output in resp.outputs:
if not output.HasField('status') or not output.status.code:
raise Exception(
"Output must have a status code, please check the model implementation."
)
successes.append(output.status.code == status_code_pb2.SUCCESS)
if all(successes):
status = status_pb2.Status(
code=status_code_pb2.SUCCESS,
description="Success",
)
if output.status.code == status_code_pb2.SUCCESS:
num_success += 1

resp.status.Clear()
if num_success == num_total:
resp.status.code = status_code_pb2.SUCCESS
resp.status.description = "Success"
status_str = STATUS_OK
elif any(successes):
status = status_pb2.Status(
code=status_code_pb2.MIXED_STATUS,
description="Mixed Status",
)
elif num_success > 0:
resp.status.code = status_code_pb2.MIXED_STATUS
resp.status.description = "Mixed Status"
status_str = STATUS_MIXED
else:
status = status_pb2.Status(
code=status_code_pb2.RUNNER_PROCESSING_FAILED,
description="Runner Processing Failed",
)
resp.status.code = status_code_pb2.RUNNER_PROCESSING_FAILED
resp.status.description = "Runner Processing Failed"
status_str = STATUS_FAIL
resp.status.CopyFrom(status)

yield service_pb2.RunnerItemOutput(multi_output_response=resp)

Expand Down
Loading
Loading