Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 85 additions & 67 deletions transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import itertools
import os
import threading
import time
import weakref
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -45,6 +46,15 @@
TQ_STORAGE_HANDSHAKE_MAX_RETRIES = int(os.environ.get("TQ_STORAGE_HANDSHAKE_MAX_RETRIES", 3))
TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 30))


def _run_notify_loop(notify_loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(notify_loop)
try:
notify_loop.run_forever()
finally:
notify_loop.close()


LIMIT_THREADS_PER_MANAGER_IN_DRIVER = 8
LIMIT_THREADS_PER_MANAGER_IN_RAY_ACTOR = 4

Expand All @@ -61,9 +71,19 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
# Handshake socket is sync (used only during initialization)
self.controller_handshake_socket: zmq.Socket | None = None

self.zmq_context: zmq.asyncio.Context | None = None
self.zmq_context = zmq.asyncio.Context()
self._connect_to_controller()

# Dedicated asyncio loop for ZMQ notify traffic, isolated from the caller's loop
self._notify_loop = asyncio.new_event_loop()
self._notify_thread = threading.Thread(
target=_run_notify_loop,
args=(self._notify_loop,),
daemon=True,
name=f"{self.storage_manager_id}-notify_data_status_loop",
)
self._notify_thread.start()

def _connect_to_controller(self) -> None:
"""Initialize ZMQ sockets between storage unit and controller for handshake."""
if not isinstance(self.controller_info, ZMQServerInfo):
Expand All @@ -90,9 +110,6 @@ def _connect_to_controller(self) -> None:
self.controller_handshake_socket = None
sync_zmq_context.term()

# create async context for data status update
self.zmq_context = zmq.asyncio.Context()

except Exception as e:
logger.error(f"Failed to connect to controller: {e}")
raise
Expand Down Expand Up @@ -205,54 +222,51 @@ async def notify_data_update(
logger.warning(f"No controller connected for storage manager {self.storage_manager_id}")
return

# create dynamic socket
identity = f"{self.storage_manager_id}-data_update-{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(self.zmq_context, zmq.DEALER, self.controller_info.ip, identity)
normalized_field_schema = {}
for field_name, field in field_schema.items():
field_copy = field.copy()
per_sample_shapes = field_copy.get("per_sample_shapes", None)
if isinstance(per_sample_shapes, list | tuple):
if len(per_sample_shapes) != len(global_indexes):
raise ValueError(
f"per_sample_shapes length ({len(per_sample_shapes)}) does not match "
f"number of global_indexes ({len(global_indexes)}) for field '{field_name}'. "
)
Comment thread
0oshowero0 marked this conversation as resolved.
field_copy["per_sample_shapes"] = {
global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes))
}
normalized_field_schema[field_name] = field_copy

try:
sock.connect(self.controller_info.to_addr("request_handle_socket"))

normalized_field_schema = {}
for field_name, field in field_schema.items():
# Work on a shallow copy to avoid mutating caller-provided schema
field_copy = field.copy()
per_sample_shapes = field_copy.get("per_sample_shapes", None)
if isinstance(per_sample_shapes, list | tuple):
if len(per_sample_shapes) != len(global_indexes):
raise ValueError(
f"per_sample_shapes length ({len(per_sample_shapes)}) does not match "
f"number of global_indexes ({len(global_indexes)}) for field '{field_name}'; "
f"skipping per_sample_shapes normalization."
)
else:
field_copy["per_sample_shapes"] = {
global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes))
}

normalized_field_schema[field_name] = field_copy

# convert per_sample_shapes into dict
for field in field_schema.values():
per_sample_shapes = field.get("per_sample_shapes", None)
if per_sample_shapes:
per_sample_shapes = {global_indexes[i]: per_sample_shapes[i] for i in range(len(global_indexes))}
field["per_sample_shapes"] = per_sample_shapes

request_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
"global_indexes": global_indexes,
"field_schema": normalized_field_schema,
"custom_backend_meta": custom_backend_meta,
},
).serialize()
request_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={
"partition_id": partition_id,
"global_indexes": global_indexes,
"field_schema": normalized_field_schema,
"custom_backend_meta": custom_backend_meta,
},
).serialize()

thread_future = asyncio.run_coroutine_threadsafe(
self._notify_and_wait(request_msg),
self._notify_loop,
)
await asyncio.wrap_future(thread_future)
Comment on lines +251 to +255

async def _notify_and_wait(self, request_msg: list) -> None:
"""Send a data status notification to the controller and block until ACK is received."""
identity = f"{self.storage_manager_id}-notify-{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(
ctx=self.zmq_context, socket_type=zmq.DEALER, ip=self.controller_info.ip, identity=identity
)
sock.setsockopt(zmq.LINGER, 0)
sock.connect(self.controller_info.to_addr("request_handle_socket"))
Comment thread
0oshowero0 marked this conversation as resolved.

try:
await sock.send_multipart(request_msg)
logger.debug(
f"[{self.storage_manager_id}]: Send data status update request "
f"from storage manager id #{self.storage_manager_id} "
f"[{self.storage_manager_id}]: Sent data status update request "
f"to controller id #{self.controller_info.id} successfully."
)

Expand All @@ -262,7 +276,10 @@ async def notify_data_update(
while not response_received and timeout > 0:
try:
poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout)
messages = await asyncio.wait_for(sock.recv_multipart(copy=False), timeout=poll_interval)
messages = await asyncio.wait_for(
sock.recv_multipart(copy=False),
timeout=poll_interval,
)
response_msg = ZMQMessage.deserialize(messages)

if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type]
Expand All @@ -271,30 +288,22 @@ async def notify_data_update(
f"[{self.storage_manager_id}]: Get data status update ACK response "
f"from controller id #{response_msg.sender_id} successfully."
)
break
except asyncio.TimeoutError:
timeout -= poll_interval
except Exception as e:
Comment thread
0oshowero0 marked this conversation as resolved.
logger.warning(f"[{self.storage_manager_id}]: Error receiving response: {e}")
break

if not response_received:
logger.error(f"[{self.storage_manager_id}]: Did not receive data status update ACK.")

except Exception as e:
logger.error(f"[{self.storage_manager_id}]: Error during notify_data_update: {e}")
try:
error_msg = ZMQMessage.create(
request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
body={"message": f"Failed to notify: {str(e)}"},
).serialize()
await sock.send_multipart(error_msg)
except Exception:
pass
logger.error(
f"[{self.storage_manager_id}]: Timeout waiting for data status update ACK "
f"from controller after {TQ_DATA_UPDATE_RESPONSE_TIMEOUT}s."
)
Comment thread
0oshowero0 marked this conversation as resolved.
finally:
try:
if not sock.closed:
sock.close(linger=-1)
sock.close(linger=0)
except Exception:
pass

Expand Down Expand Up @@ -344,17 +353,26 @@ async def clear_data(self, metadata: BatchMeta) -> None:
raise NotImplementedError("Subclasses must implement clear_data")

def close(self) -> None:
"""Close all ZMQ sockets and context to prevent resource leaks."""
# Close handshake socket if it exists
"""Close all ZMQ sockets/contexts and stop the notify loop."""

if self.controller_handshake_socket:
try:
if not self.controller_handshake_socket.closed:
self.controller_handshake_socket.close(linger=0)
except Exception as e:
logger.error(f"[{self.storage_manager_id}]: Error closing controller_handshake_socket: {str(e)}")

if self.zmq_context:
self.zmq_context.term()
if hasattr(self, "_notify_loop") and self._notify_loop.is_running():
self._notify_loop.call_soon_threadsafe(self._notify_loop.stop)

if hasattr(self, "_notify_thread") and self._notify_thread is not None:
self._notify_thread.join(timeout=5.0)
if self._notify_thread.is_alive():
logger.warning(f"[{self.storage_manager_id}]: Notify ZMQ thread did not stop within 5 second timeout.")
else:
logger.debug(f"[{self.storage_manager_id}]: Notify ZMQ thread shut down.")

self.zmq_context.term()

def __del__(self):
"""Destructor to ensure resources are cleaned up."""
Expand Down
Loading