diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 7057cc1..f4d545d 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -16,6 +16,7 @@ import asyncio import itertools import os +import threading import time import weakref from abc import ABC, abstractmethod @@ -45,6 +46,15 @@ TQ_STORAGE_HANDSHAKE_MAX_RETRIES = int(os.environ.get("TQ_STORAGE_HANDSHAKE_MAX_RETRIES", 3)) TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 30)) + +def _run_notify_loop(notify_loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(notify_loop) + try: + notify_loop.run_forever() + finally: + notify_loop.close() + + LIMIT_THREADS_PER_MANAGER_IN_DRIVER = 8 LIMIT_THREADS_PER_MANAGER_IN_RAY_ACTOR = 4 @@ -61,9 +71,19 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): # Handshake socket is sync (used only during initialization) self.controller_handshake_socket: zmq.Socket | None = None - self.zmq_context: zmq.asyncio.Context | None = None + self.zmq_context = zmq.asyncio.Context() self._connect_to_controller() + # Dedicated asyncio loop for ZMQ notify traffic, isolated from the caller's loop + self._notify_loop = asyncio.new_event_loop() + self._notify_thread = threading.Thread( + target=_run_notify_loop, + args=(self._notify_loop,), + daemon=True, + name=f"{self.storage_manager_id}-notify_data_status_loop", + ) + self._notify_thread.start() + def _connect_to_controller(self) -> None: """Initialize ZMQ sockets between storage unit and controller for handshake.""" if not isinstance(self.controller_info, ZMQServerInfo): @@ -90,9 +110,6 @@ def _connect_to_controller(self) -> None: self.controller_handshake_socket = None sync_zmq_context.term() - # create async context for data status update - self.zmq_context = zmq.asyncio.Context() - except Exception as e: logger.error(f"Failed to connect to controller: {e}") raise @@ -205,54 +222,51 @@ async def notify_data_update( logger.warning(f"No controller connected for storage manager {self.storage_manager_id}") return - # create dynamic socket - identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity) + normalized_field_schema = {} + for field_name, field in field_schema.items(): + field_copy = field.copy() + per_sample_shapes = field_copy.get("per_sample_shapes", None) + if isinstance(per_sample_shapes, list | tuple): + if len(per_sample_shapes) != len(global_indexes): + raise ValueError( + f"per_sample_shapes length ({len(per_sample_shapes)}) does not match " + f"number of global_indexes ({len(global_indexes)}) for field '{field_name}'. " + ) + field_copy["per_sample_shapes"] = { + global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes)) + } + normalized_field_schema[field_name] = field_copy - try: - sock.connect(self.controller_info.to_addr("request_handle_socket")) - - normalized_field_schema = {} - for field_name, field in field_schema.items(): - # Work on a shallow copy to avoid mutating caller-provided schema - field_copy = field.copy() - per_sample_shapes = field_copy.get("per_sample_shapes", None) - if isinstance(per_sample_shapes, list | tuple): - if len(per_sample_shapes) != len(global_indexes): - raise ValueError( - f"per_sample_shapes length ({len(per_sample_shapes)}) does not match " - f"number of global_indexes ({len(global_indexes)}) for field '{field_name}'; " - f"skipping per_sample_shapes normalization." - ) - else: - field_copy["per_sample_shapes"] = { - global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes)) - } - - normalized_field_schema[field_name] = field_copy - - # convert per_sample_shapes into dict - for field in field_schema.values(): - per_sample_shapes = field.get("per_sample_shapes", None) - if per_sample_shapes: - per_sample_shapes = {global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes))} - field["per_sample_shapes"] = per_sample_shapes - - request_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] - sender_id=self.storage_manager_id, - body={ - "partition_id": partition_id, - "global_indexes": global_indexes, - "field_schema": normalized_field_schema, - "custom_backend_meta": custom_backend_meta, - }, - ).serialize() + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type] + sender_id=self.storage_manager_id, + body={ + "partition_id": partition_id, + "global_indexes": global_indexes, + "field_schema": normalized_field_schema, + "custom_backend_meta": custom_backend_meta, + }, + ).serialize() + + thread_future = asyncio.run_coroutine_threadsafe( + self._notify_and_wait(request_msg), + self._notify_loop, + ) + await asyncio.wrap_future(thread_future) + + async def _notify_and_wait(self, request_msg: list) -> None: + """Send a data status notification to the controller and block until ACK is received.""" + identity = f"{self.storage_manager_id}-notify-{uuid4().hex[:8]}".encode() + sock = create_zmq_socket( + ctx=self.zmq_context, socket_type=zmq.DEALER, ip=self.controller_info.ip, identity=identity + ) + sock.setsockopt(zmq.LINGER, 0) + sock.connect(self.controller_info.to_addr("request_handle_socket")) + try: await sock.send_multipart(request_msg) logger.debug( - f"[{self.storage_manager_id}]: Send data status update request " - f"from storage manager id #{self.storage_manager_id} " + f"[{self.storage_manager_id}]: Sent data status update request " f"to controller id #{self.controller_info.id} successfully." ) @@ -262,7 +276,10 @@ async def notify_data_update( while not response_received and timeout > 0: try: poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) - messages = await asyncio.wait_for(sock.recv_multipart(copy=False), timeout=poll_interval) + messages = await asyncio.wait_for( + sock.recv_multipart(copy=False), + timeout=poll_interval, + ) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type] @@ -271,6 +288,7 @@ async def notify_data_update( f"[{self.storage_manager_id}]: Get data status update ACK response " f"from controller id #{response_msg.sender_id} successfully." ) + break except asyncio.TimeoutError: timeout -= poll_interval except Exception as e: @@ -278,23 +296,14 @@ async def notify_data_update( break if not response_received: - logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.") - - except Exception as e: - logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}") - try: - error_msg = ZMQMessage.create( - request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type] - sender_id=self.storage_manager_id, - body={"message": f"Failed to notify: {str(e)}"}, - ).serialize() - await sock.send_multipart(error_msg) - except Exception: - pass + logger.error( + f"[{self.storage_manager_id}]: Timeout waiting for data status update ACK " + f"from controller after {TQ_DATA_UPDATE_RESPONSE_TIMEOUT}s." + ) finally: try: if not sock.closed: - sock.close(linger=-1) + sock.close(linger=0) except Exception: pass @@ -344,8 +353,8 @@ async def clear_data(self, metadata: BatchMeta) -> None: raise NotImplementedError("Subclasses must implement clear_data") def close(self) -> None: - """Close all ZMQ sockets and context to prevent resource leaks.""" - # Close handshake socket if it exists + """Close all ZMQ sockets/contexts and stop the notify loop.""" + if self.controller_handshake_socket: try: if not self.controller_handshake_socket.closed: @@ -353,8 +362,17 @@ def close(self) -> None: except Exception as e: logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}") - if self.zmq_context: - self.zmq_context.term() + if hasattr(self, "_notify_loop") and self._notify_loop.is_running(): + self._notify_loop.call_soon_threadsafe(self._notify_loop.stop) + + if hasattr(self, "_notify_thread") and self._notify_thread is not None: + self._notify_thread.join(timeout=5.0) + if self._notify_thread.is_alive(): + logger.warning(f"[{self.storage_manager_id}]: Notify ZMQ thread did not stop within 5 second timeout.") + else: + logger.debug(f"[{self.storage_manager_id}]: Notify ZMQ thread shut down.") + + self.zmq_context.term() def __del__(self): """Destructor to ensure resources are cleaned up."""