diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py new file mode 100644 index 0000000..c41432a --- /dev/null +++ b/tests/e2e/test_checkpoint_e2e.py @@ -0,0 +1,395 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for save_checkpoint and load_checkpoint. + +Run with: + pytest tests/e2e/test_checkpoint_e2e.py -v +""" + +import json +import os + +import pytest +import ray +import torch +from omegaconf import OmegaConf +from tensordict import TensorDict + +import transfer_queue as tq + +os.environ["RAY_DEDUP_LOGS"] = "0" + +_TQ_CONFIG = OmegaConf.create( + { + "controller": {"polling_mode": True}, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + } +) + + +@pytest.fixture(scope="module") +def ray_init(): + if not ray.is_initialized(): + ray.init(namespace="TestCheckpointE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + tq.init(_TQ_CONFIG) + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + return ray.get_actor("TransferQueueController", namespace="transfer_queue") + + +@pytest.fixture(autouse=True) +def cleanup_partitions(controller): + yield + try: + for pid in ray.get(controller.list_partitions.remote()): + 
ray.get(controller.clear_partition.remote(pid)) + except Exception: + pass + + +@pytest.fixture +def checkpoint_dir(tmp_path): + return tmp_path / "checkpoint" + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _put_batch(keys, partition_id, input_ids, attention_mask, tags=None): + fields = TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask}, + batch_size=len(keys), + ) + if tags is None: + tags = [{} for _ in keys] + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + +def _get_batch(keys, partition_id): + return tq.kv_batch_get(keys=keys, partition_id=partition_id) + + +def assert_tensor_equal(tensor_a, tensor_b, msg=""): + """Assert two tensors are equal, handling nested vs dense comparisons.""" + if (isinstance(tensor_a, torch.Tensor) and tensor_a.is_nested) or ( + isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested + ): + seq_a = list(tensor_a) + seq_b = list(tensor_b) + assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}" + for t1, t2 in zip(seq_a, seq_b, strict=True): + assert torch.equal(t1, t2), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + else: + assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + + +# --------------------------------------------------------------------------- +# basic save / load roundtrip +# --------------------------------------------------------------------------- + + +class TestCheckpointRoundtrip: + def test_save_creates_expected_files(self, tq_system, checkpoint_dir): + keys = ["k0", "k1"] + partition_id = "p0" + _put_batch(keys, partition_id, torch.tensor([[1, 2], [3, 4]]), torch.ones(2, 2)) + + tq.save_checkpoint(checkpoint_dir) + + assert (checkpoint_dir / "metadata.json").exists() + assert (checkpoint_dir / "controller_state.pkl").exists() + + 
with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) + + # two storage units configured + assert len(info["storage_units"]) == 2 + su_dir = checkpoint_dir / "storage_units" + for entry in info["storage_units"]: + assert (su_dir / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl").exists() + + def test_metadata_json_content(self, tq_system, checkpoint_dir): + keys = ["m0"] + _put_batch(keys, "p_meta", torch.tensor([[10, 20]]), torch.ones(1, 2)) + + tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) + + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + + assert meta["user_metadata"]["iteration"] == 42 + assert meta["user_metadata"]["loss"] == pytest.approx(0.5) + assert "storage_units" in meta + + def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, controller): + keys = ["a0", "a1", "a2"] + partition_id = "p_ctrl" + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tags = [{"idx": i} for i in range(3)] + _put_batch(keys, partition_id, input_ids, torch.ones(3, 3), tags) + + tq.save_checkpoint(checkpoint_dir) + + # wipe controller state + ray.get(controller.clear_partition.remote(partition_id)) + assert ray.get(controller.list_partitions.remote()) == [] + + tq.load_checkpoint(checkpoint_dir) + + # partition must be back + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + # key-to-global-index mapping must be intact + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + + # tags must be intact + for i, key in enumerate(keys): + gidx = snapshot.keys_mapping[key] + assert snapshot.custom_meta[gidx]["idx"] == i + + def test_load_restores_storage_data(self, tq_system, checkpoint_dir, controller): + keys = ["s0", "s1"] + partition_id = "p_storage" + input_ids = torch.tensor([[10, 20], [30, 40]]) + attention_mask = torch.ones(2, 2) + 
_put_batch(keys, partition_id, input_ids, attention_mask) + + tq.save_checkpoint(checkpoint_dir) + + # clear both controller and storage state so load has to restore from scratch + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) + assert_tensor_equal(retrieved["attention_mask"], attention_mask) + + def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, controller): + for i in range(3): + _put_batch( + [f"p{i}_k0", f"p{i}_k1"], + f"part_{i}", + torch.full((2, 4), i, dtype=torch.long), + torch.ones(2, 4), + ) + + tq.save_checkpoint(checkpoint_dir) + + for i in range(3): + ray.get(controller.clear_partition.remote(f"part_{i}")) + + tq.load_checkpoint(checkpoint_dir) + + for i in range(3): + retrieved = tq.kv_batch_get( + keys=[f"p{i}_k0", f"p{i}_k1"], + partition_id=f"part_{i}", + select_fields=["input_ids"], + ) + assert_tensor_equal(retrieved["input_ids"], torch.full((2, 4), i, dtype=torch.long)) + + +# --------------------------------------------------------------------------- +# include_storage=False +# --------------------------------------------------------------------------- + + +class TestCheckpointMetadataOnly: + def test_save_simplestorage_always_includes_storage(self, tq_system, checkpoint_dir): + """SimpleStorage is in-memory, so storage data is always saved regardless of include_storage.""" + _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) + + # SimpleStorage backend: storage_checkpoint_required=True, so data is always saved + assert len(info["storage_units"]) == 2 + assert (checkpoint_dir / "storage_units").exists() + + def test_load_after_include_storage_false_simplestorage(self, tq_system, checkpoint_dir, controller): + 
"""Even with include_storage=False, SimpleStorage data is saved and can be restored.""" + keys = ["n0", "n1"] + partition_id = "p_nometa2" + input_ids = torch.tensor([[5, 6], [7, 8]]) + _put_batch(keys, partition_id, input_ids, torch.ones(2, 2)) + + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) + + +# --------------------------------------------------------------------------- +# error handling +# --------------------------------------------------------------------------- + + +class TestCheckpointErrors: + def test_save_raises_if_not_initialized(self, tmp_path): + # call save_checkpoint before tq.init() in a fresh module state + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.save_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_not_initialized(self, tmp_path): + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.load_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_dir_missing(self, tq_system, tmp_path): + with pytest.raises(FileNotFoundError): + tq.load_checkpoint(tmp_path / "nonexistent") + + def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): + ck = tmp_path / "ck" + ck.mkdir() + with pytest.raises(FileNotFoundError, match="metadata.json"): + 
tq.load_checkpoint(ck) + + def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir): + _put_batch(["e0"], "p_err", torch.tensor([[1, 2]]), torch.ones(1, 2)) + tq.save_checkpoint(checkpoint_dir) + + # tamper: add a fake extra entry so count differs + meta_path = checkpoint_dir / "metadata.json" + with open(meta_path) as f: + meta = json.load(f) + meta["storage_units"].append({"position": 99, "storage_unit_id": "fake", "file_size": 0}) + with open(meta_path, "w") as f: + json.dump(meta, f) + + with pytest.raises(ValueError, match="count mismatch"): + tq.load_checkpoint(checkpoint_dir) + + def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): + """A failed save must not leave a partial directory.""" + _put_batch(["f0"], "p_fail", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + ck = tmp_path / "ck" + + import unittest.mock as mock + + with mock.patch( + "transfer_queue.client.TransferQueueClient.dump_storage_checkpoint", + side_effect=RuntimeError("simulated dump failure"), + ): + with pytest.raises(RuntimeError, match="simulated dump failure"): + tq.save_checkpoint(ck) + + assert not ck.exists(), "Partial checkpoint directory should have been cleaned up" + assert not (tmp_path / "ck.tmp").exists(), "Temp directory should have been cleaned up" + + +# --------------------------------------------------------------------------- +# data variety +# --------------------------------------------------------------------------- + + +class TestCheckpointDataVariety: + def test_non_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): + """String fields should survive save/load.""" + from tensordict import NonTensorStack + + keys = ["t0", "t1"] + partition_id = "p_str" + fields = TensorDict( + { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + "text": NonTensorStack("hello", "world"), + }, + batch_size=2, + ) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + 
tq.save_checkpoint(checkpoint_dir)
+
+        ray.get(controller.clear_partition.remote(partition_id))
+
+        tq.load_checkpoint(checkpoint_dir)
+
+        retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"])
+        assert_tensor_equal(retrieved["input_ids"], torch.tensor([[1, 2], [3, 4]]))
+
+    def test_nested_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller):
+        """Variable-length (jagged) tensor fields should survive save/load."""
+        keys = ["j0", "j1", "j2"]
+        partition_id = "p_jagged"
+        for i, key in enumerate(keys):
+            seq = torch.arange(i + 1, dtype=torch.float).unsqueeze(0)
+            tq.kv_put(
+                key=key,
+                partition_id=partition_id,
+                fields=TensorDict({"seq": seq}, batch_size=1),
+                tag=None,
+            )
+
+        tq.save_checkpoint(checkpoint_dir)
+
+        ray.get(controller.clear_partition.remote(partition_id))
+
+        tq.load_checkpoint(checkpoint_dir)
+
+        retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["seq"])
+        for i, component in enumerate(retrieved["seq"].unbind()):
+            assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float))
diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py
index 0675427..ad8e8ab 100644
--- a/transfer_queue/__init__.py
+++ b/transfer_queue/__init__.py
@@ -34,6 +34,8 @@
     kv_clear,
     kv_list,
     kv_put,
+    load_checkpoint,
+    save_checkpoint,
 )
 from .metadata import BatchMeta, KVBatchMeta
 from .sampler import BaseSampler
@@ -62,6 +64,11 @@
     "async_kv_clear",
     "KVBatchMeta",
 ]
+    + [
+        # Checkpoint Interface
+        "save_checkpoint",
+        "load_checkpoint",
+    ]
     + [
         # High-Level StreamingDataLoader Interface
         "StreamingDataset",
diff --git a/transfer_queue/client.py b/transfer_queue/client.py
index 45e6303..9ba558b 100644
--- a/transfer_queue/client.py
+++ b/transfer_queue/client.py
@@ -1061,6 +1061,123 @@ def close(self) -> None:
             except Exception as e:
                 logger.warning(f"Error closing storage manager: {e}")
 
+    # ==================== Checkpoint API ====================
+    @with_controller_socket
+    async def async_dump_controller_checkpoint(
+        self,
+        path: str,
+        socket: zmq.asyncio.Socket | None = None,
+    ) -> None:
+        """Ask the controller to dump its state to ``path`` and await the reply.
+
+        Args:
+            path: File path placed in the CHECKPOINT_DUMP request body.
+            socket: Controller socket for the exchange; asserted non-None
+                (presumably injected by ``with_controller_socket`` - confirm).
+
+        Raises:
+            RuntimeError: The reply type is not CHECKPOINT_DUMP_RESPONSE, or
+                the reply body has no truthy ``success`` entry.
+        """
+        assert socket is not None
+        expected_reply = ZMQRequestType.CHECKPOINT_DUMP_RESPONSE
+        request = ZMQMessage.create(
+            request_type=ZMQRequestType.CHECKPOINT_DUMP,  # type: ignore[arg-type]
+            sender_id=self.client_id,
+            receiver_id=self._controller.id,
+            body={"path": path},
+        )
+        await socket.send_multipart(request.serialize())
+        reply = ZMQMessage.deserialize(await socket.recv_multipart(copy=False))
+
+        # The type check must precede the success check so a foreign reply is
+        # reported as such instead of as a failed dump.
+        if reply.request_type != expected_reply:
+            raise RuntimeError(
+                f"[{self.client_id}]: Unexpected response type {reply.request_type} "
+                f"from controller during checkpoint dump"
+            )
+        if reply.body.get("success"):
+            return
+        raise RuntimeError(f"[{self.client_id}]: Controller failed to dump checkpoint to {path}")
+
+    @with_controller_socket
+    async def async_restore_controller_checkpoint(
+        self,
+        path: str,
+        socket: zmq.asyncio.Socket | None = None,
+    ) -> None:
+        """Ask the controller to restore its state from ``path`` and await the reply.
+
+        Args:
+            path: File path placed in the CHECKPOINT_RESTORE request body.
+            socket: Controller socket for the exchange; asserted non-None
+                (presumably injected by ``with_controller_socket`` - confirm).
+
+        Raises:
+            RuntimeError: The reply type is not CHECKPOINT_RESTORE_RESPONSE, or
+                the reply body has no truthy ``success`` entry.
+        """
+        assert socket is not None
+        expected_reply = ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE
+        request = ZMQMessage.create(
+            request_type=ZMQRequestType.CHECKPOINT_RESTORE,  # type: ignore[arg-type]
+            sender_id=self.client_id,
+            receiver_id=self._controller.id,
+            body={"path": path},
+        )
+        await socket.send_multipart(request.serialize())
+        reply = ZMQMessage.deserialize(await socket.recv_multipart(copy=False))
+
+        # The type check must precede the success check so a foreign reply is
+        # reported as such instead of as a failed restore.
+        if reply.request_type != expected_reply:
+            raise RuntimeError(
+                f"[{self.client_id}]: Unexpected response type {reply.request_type} "
+                f"from controller during checkpoint restore"
+            )
+        if reply.body.get("success"):
+            return
+        raise RuntimeError(f"[{self.client_id}]: Controller failed to restore checkpoint from {path}")
+
+    async def async_dump_storage_checkpoint(self, output_dir: str) -> list[dict]:
+        """Dump every storage unit into ``output_dir`` through the storage manager.
+
+        Args:
+            output_dir: Directory handed to ``storage_manager.dump_checkpoint``.
+
+        Returns:
+            The list returned by ``storage_manager.dump_checkpoint``.
+
+        Raises:
+            RuntimeError: The client has no (non-None) ``storage_manager``.
+        """
+        manager = getattr(self, "storage_manager", None)
+        if manager is None:
+            raise RuntimeError(
+                f"[{self.client_id}]: Storage manager not initialized. "
+                "Call initialize_storage_manager() before checkpoint operations."
+            )
+        return await manager.dump_checkpoint(output_dir)
+
+    async def async_restore_storage_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None:
+        """Restore every storage unit from ``checkpoint_dir`` through the storage manager.
+
+        Args:
+            checkpoint_dir: Directory handed to ``storage_manager.restore_checkpoint``.
+            su_info_list: Storage unit info dicts, forwarded unchanged.
+
+        Raises:
+            RuntimeError: The client has no (non-None) ``storage_manager``.
+        """
+        manager = getattr(self, "storage_manager", None)
+        if manager is None:
+            raise RuntimeError(
+                f"[{self.client_id}]: Storage manager not initialized. "
+                "Call initialize_storage_manager() before checkpoint operations."
+            )
+        await manager.restore_checkpoint(checkpoint_dir, su_info_list)
+
 
 class TransferQueueClient(AsyncTransferQueueClient):
     """Synchronous client wrapper for TransferQueue.
@@ -1129,6 +1198,10 @@ def wrapper(*args, **kwargs):
         self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta)
         self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys)
         self._kv_list = _make_sync(self.async_kv_list)
+        self._dump_controller_checkpoint = _make_sync(self.async_dump_controller_checkpoint)
+        self._restore_controller_checkpoint = _make_sync(self.async_restore_controller_checkpoint)
+        self._dump_storage_checkpoint = _make_sync(self.async_dump_storage_checkpoint)
+        self._restore_storage_checkpoint = _make_sync(self.async_restore_storage_checkpoint)
 
     # ==================== Basic API ====================
     def get_meta(
@@ -1561,6 +1634,23 @@ def kv_list(
 
         return self._kv_list(partition_id=partition_id)
 
+    # ==================== Checkpoint API ====================
+    def dump_controller_checkpoint(self, path: str) -> None:
+        """Synchronous wrapper around async_dump_controller_checkpoint (controller dumps to ``path``)."""
+        return self._dump_controller_checkpoint(path)
+
+    def restore_controller_checkpoint(self, path: str) -> None:
+        """Synchronous wrapper around async_restore_controller_checkpoint (controller restores from ``path``)."""
+        return self._restore_controller_checkpoint(path)
+
+    def dump_storage_checkpoint(self, output_dir: str) -> list[dict]:
+        """Synchronous wrapper around async_dump_storage_checkpoint; returns its list of info dicts."""
+        return self._dump_storage_checkpoint(output_dir)
+
+    def restore_storage_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None:
+        """Synchronous wrapper around async_restore_storage_checkpoint (restores all storage units)."""
+        return self._restore_storage_checkpoint(checkpoint_dir, su_info_list)
+
     def close(self) -> None:
         """Close the client and cleanup resources including event loop and thread."""
 
diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py
index 1182a44..34304a9 100644
--- a/transfer_queue/controller.py
+++ b/transfer_queue/controller.py
@@ -15,6 +15,7 @@
 
 import copy
 import os
+import pickle
 import time
 from collections import defaultdict
 from dataclasses import
dataclass, field
@@ -2031,6 +2032,26 @@ def _process_request(self):
                     body={"partition_info": partition_info, "message": message},
                 )
 
+            elif request_msg.request_type == ZMQRequestType.CHECKPOINT_DUMP:
+                path = request_msg.body["path"]
+                success = self.dump_to_file(path)
+                response_msg = ZMQMessage.create(
+                    request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE,
+                    sender_id=self.controller_id,
+                    receiver_id=request_msg.sender_id,
+                    body={"success": success},
+                )
+
+            elif request_msg.request_type == ZMQRequestType.CHECKPOINT_RESTORE:
+                path = request_msg.body["path"]
+                success = self.restore_from_file(path)
+                response_msg = ZMQMessage.create(
+                    request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE,
+                    sender_id=self.controller_id,
+                    receiver_id=request_msg.sender_id,
+                    body={"success": success},
+                )
+
             self.request_handle_socket.send_multipart([identity, *response_msg.serialize()])
 
     def get_zmq_server_info(self) -> ZMQServerInfo:
@@ -2045,6 +2066,90 @@ def get_config(self) -> DictConfig:
         """Retrieve the global config of TransferQueue."""
         return self.tq_config
 
+    def dump_to_file(self, path: str) -> bool:
+        """Serialize controller state directly to a file.
+
+        Writes in-process to avoid transmitting the payload back over the
+        Ray object store — only a bool ACK is returned to the caller.
+
+        Args:
+            path: Absolute path for the output .pkl file.
+
+        Returns:
+            True on success, False on failure.
+        """
+        try:
+            state = {
+                "controller_id": self.controller_id,
+                "partitions": {pid: p.to_snapshot() for pid, p in self.partitions.items()},
+                "index_manager": {
+                    "partition_to_indexes": dict(copy.deepcopy(self.index_manager.partition_to_indexes)),
+                    "reusable_indexes": list(self.index_manager.reusable_indexes),
+                    "global_index_counter": self.index_manager.global_index_counter,
+                    "allocated_indexes": set(self.index_manager.allocated_indexes),
+                },
+                "sampler": self.sampler.get_state() if hasattr(self.sampler, "get_state") else None,
+                # tq_config and connected_storage_managers are intentionally excluded:
+                # tq_config may differ after reinit (e.g. different ports/addresses);
+                # connected_storage_managers holds UUIDs that are regenerated on every reinit.
+            }
+            with open(path, "wb") as f:
+                pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
+            logger.info(f"[{self.controller_id}]: dumped to {path}")
+            return True
+        except Exception as e:
+            logger.exception(f"[{self.controller_id}]: dump_to_file failed: {e}")
+            return False
+
+    def restore_from_file(self, path: str) -> bool:
+        """Restore controller state directly from a file.
+
+        Every required field is read from the checkpoint before any controller
+        attribute is assigned, so a truncated or malformed file leaves the
+        running controller untouched.
+
+        Args:
+            path: Absolute path to a .pkl file previously written by dump_to_file().
+
+        Returns:
+            True on success, False on failure.
+        """
+        try:
+            with open(path, "rb") as f:
+                # NOTE(security): pickle.load can execute arbitrary code; only load
+                # checkpoint files that come from a trusted location.
+                state = pickle.load(f)
+
+            # Stage everything first; a KeyError here must not leave a half-restored controller.
+            controller_id = state["controller_id"]
+            partitions = state["partitions"]
+            im = state["index_manager"]
+            partition_to_indexes = defaultdict(set, im["partition_to_indexes"])
+            reusable_indexes = im["reusable_indexes"]
+            global_index_counter = im["global_index_counter"]
+            allocated_indexes = im["allocated_indexes"]
+            sampler_state = state["sampler"]
+
+            # NOTE(review): controller_id is overwritten with the checkpointed value;
+            # confirm clients created after re-init still address the controller correctly.
+            self.controller_id = controller_id
+            self.partitions = partitions
+            self.index_manager.partition_to_indexes = partition_to_indexes
+            self.index_manager.reusable_indexes = reusable_indexes
+            self.index_manager.global_index_counter = global_index_counter
+            self.index_manager.allocated_indexes = allocated_indexes
+
+            if sampler_state is not None and hasattr(self.sampler, "restore_state"):
+                self.sampler.restore_state(sampler_state)
+            # tq_config and connected_storage_managers are not restored: they belong to
+            # the new init context, not the checkpoint.
+
+            logger.info(f"[{self.controller_id}]: restored from {path}")
+            return True
+        except Exception as e:
+            logger.exception(f"[{self.controller_id}]: restore_from_file failed: {e}")
+            return False
+
     def register_sampler(
         self,
         sampler: BaseSampler | type[BaseSampler] = SequentialSampler,
diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py
index 82ceaca..5fcd3a1 100644
--- a/transfer_queue/interface.py
+++ b/transfer_queue/interface.py
@@ -13,10 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
+import shutil
 import subprocess
 import time
 from importlib import resources
+from pathlib import Path
 from typing import Any, Callable
 
 import ray
@@ -1047,3 +1050,187 @@ def get_client():
     """Get a TransferQueueClient for using low-level API"""
     assert _TQ_CLIENT is not None, "Please initialize the TransferQueue first by calling `tq.init()`!"
     return _TQ_CLIENT
+
+
+# ==================== Checkpoint API ====================
+
+_METADATA_FILE = "metadata.json"
+_CONTROLLER_FILE = "controller_state.pkl"
+_STORAGE_UNITS_DIR = "storage_units"
+
+
+def save_checkpoint(
+    checkpoint_dir: str | Path,
+    *,
+    include_storage: bool = True,
+    metadata: dict[str, Any] | None = None,
+) -> None:
+    """Save a full checkpoint of the TransferQueue system state.
+
+    The checkpoint is assembled in a sibling ``<checkpoint_dir>.tmp`` directory and
+    only swapped into place once every file has been written. An existing
+    checkpoint is moved to ``<checkpoint_dir>.old`` during the swap and put back
+    if the swap fails, so a failed save never destroys the previous checkpoint.
+
+    .. note::
+        **Multi-node limitation**: checkpoint_dir must reside on a shared network
+        filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running
+        storage units. Single-node deployments have no such requirement.
+
+    Args:
+        checkpoint_dir: Directory to save the checkpoint (created if not exists).
+        include_storage: Controls whether storage contents are saved for KV backends
+            (MooncakeStore, Yuanrong, etc.) that persist data externally.
+            Ignored for SimpleStorage — its in-memory data is always saved
+            because it would be lost on restart.
+        metadata: User-defined key-value pairs written into metadata.json.
+            Example: {"timestamp": 1718123456.789012, "step": 1000}
+
+    Raises:
+        RuntimeError: TransferQueue is not initialized.
+        OSError: Failed to write checkpoint files.
+    """
+    if _TQ_CONTROLLER is None:
+        raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.")
+
+    checkpoint_dir = Path(checkpoint_dir)
+    tmp_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".tmp")
+    backup_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".old")
+
+    if tmp_dir.exists():
+        shutil.rmtree(tmp_dir)
+    tmp_dir.mkdir(parents=True)
+
+    try:
+        client = _maybe_create_tq_client()
+
+        # Controller dumps itself to file via ZMQ
+        controller_path = tmp_dir / _CONTROLLER_FILE
+        client.dump_controller_checkpoint(str(controller_path))
+        controller_size = controller_path.stat().st_size
+        logger.info(f"Controller state saved ({controller_size} bytes)")
+
+        su_info_list: list[dict[str, Any]] = []
+        if hasattr(client, "storage_manager") and client.storage_manager is not None:
+            sm = client.storage_manager
+            if sm.storage_checkpoint_required or include_storage:
+                su_dir = tmp_dir / _STORAGE_UNITS_DIR
+                su_dir.mkdir()
+                try:
+                    su_info_list = client.dump_storage_checkpoint(str(su_dir))
+                    for info in su_info_list:
+                        logger.info(f"Storage unit {info['storage_unit_id']} (pos={info['position']}) saved")
+                except NotImplementedError:
+                    logger.warning("Storage backend does not support checkpoint; storage data will not be saved.")
+
+        # Write metadata.json
+        meta_content = {
+            "storage_units": su_info_list,
+            "user_metadata": metadata or {},
+        }
+        with open(tmp_dir / _METADATA_FILE, "w", encoding="utf-8") as f:
+            json.dump(meta_content, f, indent=2)
+
+        # Swap into place. The previous checkpoint is moved aside rather than
+        # deleted first, so it can be put back if the final rename fails.
+        if backup_dir.exists():
+            shutil.rmtree(backup_dir)
+        had_previous = checkpoint_dir.exists()
+        if had_previous:
+            checkpoint_dir.rename(backup_dir)
+        try:
+            tmp_dir.rename(checkpoint_dir)
+        except OSError:
+            if had_previous:
+                backup_dir.rename(checkpoint_dir)
+            raise
+        if had_previous:
+            shutil.rmtree(backup_dir, ignore_errors=True)
+
+        logger.info(f"Checkpoint saved to {checkpoint_dir}")
+
+    except Exception:
+        # Best-effort cleanup; ignore_errors keeps a cleanup failure from masking
+        # the original exception.
+        shutil.rmtree(tmp_dir, ignore_errors=True)
+        raise
+
+
+def load_checkpoint(
+    checkpoint_dir: str | Path,
+) -> None:
+    """Restore TransferQueue system state from a checkpoint.
+
+    The ordered storage unit list of the current system must exactly match the
+    checkpoint (same count, same positions).
+
+    All checkpoint files and the storage unit list are validated before the
+    controller or any storage unit is modified.
+
+    .. note::
+        **Multi-node limitation**: checkpoint_dir must reside on a shared network
+        filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running
+        storage units. Single-node deployments have no such requirement.
+
+    Args:
+        checkpoint_dir: Path to the checkpoint directory.
+
+    Raises:
+        FileNotFoundError: Checkpoint directory or required files do not exist.
+        ValueError: Storage unit list does not match or checkpoint is corrupted.
+        RuntimeError: TransferQueue is not initialized, or restore fails.
+    """
+    if _TQ_CONTROLLER is None:
+        raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.")
+
+    checkpoint_dir = Path(checkpoint_dir)
+    if not checkpoint_dir.exists():
+        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
+
+    metadata_path = checkpoint_dir / _METADATA_FILE
+    if not metadata_path.exists():
+        raise FileNotFoundError(f"{_METADATA_FILE} not found in {checkpoint_dir}")
+
+    with open(metadata_path, encoding="utf-8") as f:
+        meta = json.load(f)
+
+    saved_su_list = meta.get("storage_units", [])
+
+    client = _maybe_create_tq_client()
+
+    # Validate the storage side completely before touching any state, so a bad
+    # checkpoint cannot leave the controller restored but the storage units not.
+    if saved_su_list:
+        for entry in saved_su_list:
+            if not isinstance(entry, dict) or "position" not in entry or "storage_unit_id" not in entry:
+                raise ValueError(f"Corrupted checkpoint: malformed storage unit entry {entry!r}.")
+        if not hasattr(client, "storage_manager") or client.storage_manager is None:
+            raise ValueError("Checkpoint contains storage unit data but current system has no storage manager.")
+        sm = client.storage_manager
+        if not sm.storage_checkpoint_required:
+            raise ValueError(
+                "Checkpoint contains storage unit data but current storage backend does not support checkpoint."
+            )
+        if hasattr(sm, "get_zmq_server_info"):
+            current_su_count = len(sm.get_zmq_server_info())
+            if current_su_count != len(saved_su_list):
+                raise ValueError(
+                    f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, "
+                    f"current system has {current_su_count}."
+                )
+        storage_units_dir = checkpoint_dir / _STORAGE_UNITS_DIR
+        if not storage_units_dir.exists():
+            raise FileNotFoundError(f"{_STORAGE_UNITS_DIR} not found in {checkpoint_dir}")
+
+    # Restore controller via ZMQ RPC
+    controller_path = checkpoint_dir / _CONTROLLER_FILE
+    if not controller_path.exists():
+        raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}")
+
+    client.restore_controller_checkpoint(str(controller_path))
+
+    # Restore storage units via StorageManager abstraction
+    if saved_su_list:
+        client.restore_storage_checkpoint(str(checkpoint_dir / _STORAGE_UNITS_DIR), saved_su_list)
+
+    logger.info(f"Checkpoint loaded from {checkpoint_dir}")
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py
index f4d545d..a746459 100644
--- a/transfer_queue/storage/managers/base.py
+++ b/transfer_queue/storage/managers/base.py
@@ -352,6 +352,42 @@ async def clear_data(self, metadata: BatchMeta) -> None:
         """
         raise NotImplementedError("Subclasses must implement clear_data")
 
+    @property
+    def storage_checkpoint_required(self) -> bool:
+        """Whether storage contents must be checkpointed for correct restore.
+
+        Returns True for in-memory backends (e.g. SimpleStorage) where data
+        is lost on restart and must be serialized. Returns False for persistent
+        KV backends (e.g. MooncakeStore, Yuanrong) where data survives restarts
+        and only controller metadata needs to be saved.
+
+        Subclasses should override this to reflect their actual persistence model.
+        """
+        return False
+
+    async def dump_checkpoint(self, output_dir: str) -> list[dict]:
+        """Dump all storage units to files in output_dir.
+
+        Returns:
+            List of dicts, each with keys: position, storage_unit_id.
+
+        Raises:
+            NotImplementedError: If this storage backend does not support checkpoint.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint")
+
+    async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None:
+        """Restore all storage units from files in checkpoint_dir.
+
+        Args:
+            checkpoint_dir: Path to the checkpoint directory containing storage unit files.
+            su_info_list: Ordered list of storage unit info dicts from metadata.json.
+
+        Raises:
+            NotImplementedError: If this storage backend does not support checkpoint.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint")
+
     def close(self) -> None:
         """Close all ZMQ sockets/contexts and stop the notify loop."""
 
diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py
index 8080352..874da7e 100644
--- a/transfer_queue/storage/managers/simple_storage_manager.py
+++ b/transfer_queue/storage/managers/simple_storage_manager.py
@@ -19,6 +19,7 @@
 from collections import defaultdict
 from collections.abc import Mapping
 from operator import itemgetter
+from pathlib import Path
 from typing import Any, Callable, NamedTuple
 
 import torch
@@ -85,6 +86,11 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig):
 
         self.storage_unit_infos = self._register_servers(server_infos)
 
+    @property
+    def storage_checkpoint_required(self) -> bool:
+        """Report that storage contents must be checkpointed for a correct restore (always True here)."""
+        return True
+
     def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"):
         """Register and validate server information.
 
async def _exchange_checkpoint_request(
    self,
    socket: zmq.Socket,
    target_storage_unit: str,
    request_type,
    expected_response,
    path: str,
) -> None:
    """Send one checkpoint request to a storage unit and validate its reply.

    Shared by the dump and restore helpers, which differ only in the request
    type they send and the response type they expect.

    Args:
        socket: Socket used to talk to ``target_storage_unit``.
        target_storage_unit: ID of the storage unit being addressed.
        request_type: ZMQRequestType of the outgoing request.
        expected_response: ZMQRequestType that signals success.
        path: File path the storage unit should write to or read from.

    Raises:
        RuntimeError: If the reply is not ``expected_response``. The message
            includes the ``message`` field of the reply body when present, so
            the storage unit's own failure reason is not lost.
    """
    request_msg = ZMQMessage.create(
        request_type=request_type,
        sender_id=self.storage_manager_id,
        receiver_id=target_storage_unit,
        body={"path": path},
    )
    await socket.send_multipart(request_msg.serialize(), copy=False)
    messages = await socket.recv_multipart(copy=False)
    response_msg = ZMQMessage.deserialize(messages)
    if response_msg.request_type != expected_response:
        reply_body = response_msg.body
        detail = reply_body.get("message") if isinstance(reply_body, dict) else reply_body
        suffix = f" ({detail})" if detail else ""
        raise RuntimeError(
            f"Unexpected response from storage unit {target_storage_unit}: {response_msg.request_type}{suffix}"
        )


@with_storage_unit_socket
async def _dump_single_storage_unit(
    self,
    path: str,
    target_storage_unit: str,
    socket: zmq.Socket = None,
):
    """Ask one storage unit to dump its contents to ``path``.

    Args:
        path: Destination file for the storage unit's dump.
        target_storage_unit: ID of the storage unit to dump.
        socket: Socket for the exchange.
            NOTE(review): presumably injected by ``with_storage_unit_socket`` — confirm.

    Raises:
        RuntimeError: If the storage unit does not answer with
            CHECKPOINT_DUMP_RESPONSE.
        Exception: Any transport or serialization error is logged and re-raised.
    """
    try:
        await self._exchange_checkpoint_request(
            socket,
            target_storage_unit,
            ZMQRequestType.CHECKPOINT_DUMP,
            ZMQRequestType.CHECKPOINT_DUMP_RESPONSE,
            path,
        )
    except Exception as e:
        logger.error(f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}")
        raise


@with_storage_unit_socket
async def _restore_single_storage_unit(
    self,
    path: str,
    target_storage_unit: str,
    socket: zmq.Socket = None,
):
    """Ask one storage unit to restore its contents from ``path``.

    Args:
        path: Checkpoint file the storage unit should load.
        target_storage_unit: ID of the storage unit to restore.
        socket: Socket for the exchange.
            NOTE(review): presumably injected by ``with_storage_unit_socket`` — confirm.

    Raises:
        RuntimeError: If the storage unit does not answer with
            CHECKPOINT_RESTORE_RESPONSE.
        Exception: Any transport or serialization error is logged and re-raised.
    """
    try:
        await self._exchange_checkpoint_request(
            socket,
            target_storage_unit,
            ZMQRequestType.CHECKPOINT_RESTORE,
            ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE,
            path,
        )
    except Exception as e:
        logger.error(
            f"[{self.storage_manager_id}]: Error restoring for storage unit {target_storage_unit}: {str(e)}"
        )
        raise


async def dump_checkpoint(self, output_dir: str) -> list[dict]:
    """Dump all storage units to files in output_dir in parallel.

    Each unit is written to ``su_<position>_<storage_unit_id>.pkl``, where
    position is the unit's index in ``self.storage_unit_infos``.

    Args:
        output_dir: Directory that receives one file per storage unit.

    Returns:
        Ordered list of dicts with keys: position, storage_unit_id.
    """
    su_ids = list(self.storage_unit_infos)
    target_dir = Path(output_dir)

    await asyncio.gather(
        *(
            self._dump_single_storage_unit(
                str(target_dir / f"su_{pos}_{su_id}.pkl"),
                target_storage_unit=su_id,
            )
            for pos, su_id in enumerate(su_ids)
        )
    )

    return [{"position": pos, "storage_unit_id": su_id} for pos, su_id in enumerate(su_ids)]


async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None:
    """Restore all storage units from files in checkpoint_dir.

    Matches by position: entry at position i is restored to the storage unit
    currently at position i in storage_unit_infos. The file name is rebuilt
    from the entry's saved position and storage_unit_id.

    Args:
        checkpoint_dir: Directory containing the storage unit files.
        su_info_list: Dicts with keys ``position`` and ``storage_unit_id``.

    Raises:
        IndexError: If an entry's position does not address a current storage
            unit (negative positions are rejected rather than wrapping around).
        ValueError: If two entries share a position.
    """
    su_ids = list(self.storage_unit_infos)
    entries = sorted(su_info_list, key=lambda e: e["position"])
    positions = [entry["position"] for entry in entries]

    # Validate before any coroutine is created, so a bad entry cannot leave
    # half-built, never-awaited restore tasks behind.
    out_of_range = [pos for pos in positions if not 0 <= pos < len(su_ids)]
    if out_of_range:
        raise IndexError(
            f"Checkpoint positions {out_of_range} are out of range for {len(su_ids)} current storage unit(s)"
        )
    if len(set(positions)) != len(positions):
        raise ValueError(f"Duplicate storage unit positions in checkpoint metadata: {positions}")

    uncovered = sorted(set(range(len(su_ids))) - set(positions))
    if uncovered:
        # These units keep whatever they currently hold; make that visible.
        logger.warning(
            f"[{self.storage_manager_id}]: no checkpoint entry for storage unit position(s) {uncovered}; "
            f"they are left untouched"
        )

    source_dir = Path(checkpoint_dir)
    await asyncio.gather(
        *(
            self._restore_single_storage_unit(
                str(source_dir / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl"),
                target_storage_unit=su_ids[entry["position"]],
            )
            for entry in entries
        )
    )
def _handle_checkpoint_dump(self, data_parts) -> ZMQMessage:
    """
    Serialize this storage unit's data directly to a file.

    The data is pickled in-process, so the payload is never sent back to
    the caller; only a small acknowledgement message is returned. The file
    is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``, so a failed dump never leaves a truncated file at the
    final path.

    Args:
        data_parts: Request ZMQMessage whose body carries ``path``, the
            destination .pkl file. The path must be writable from the
            process running this storage unit.

    Returns:
        A CHECKPOINT_DUMP_RESPONSE message on success, or a
        SAVE_LOAD_CKPT_ERROR message carrying the error text on failure
        (including a request without ``path``).
    """
    tmp_path = None
    try:
        path = data_parts.body["path"]
        state = {
            "storage_unit_id": self.storage_unit_id,
            "storage_unit_size": self.storage_unit_size,
            "field_data": self.storage_data.field_data,
            "active_keys": self.storage_data._active_keys,
        }
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Publish the finished file under its final name in a single step.
        os.replace(tmp_path, path)
        tmp_path = None
        return ZMQMessage.create(
            request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"message": f"[{self.storage_unit_id}] dumped to {path} successfully"},
        )
    except Exception as e:
        logger.error(f"[{self.storage_unit_id}]: dump_to_file failed: {e}")
        if tmp_path is not None:
            # Best-effort removal of the partial temp file; the original
            # error is what gets reported to the caller.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return ZMQMessage.create(
            request_type=ZMQRequestType.SAVE_LOAD_CKPT_ERROR,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"message": str(e)},
        )


def _handle_checkpoint_restore(self, data_parts) -> ZMQMessage:
    """
    Restore this storage unit's data directly from a file.

    The checkpoint is loaded and checked for the required entries before
    the live data is touched, so an unreadable or incomplete file leaves
    the storage unit as it was.

    Args:
        data_parts: Request ZMQMessage whose body carries ``path``, the
            .pkl file to load. The path must be readable from the process
            running this storage unit.

    Returns:
        A CHECKPOINT_RESTORE_RESPONSE message on success, or a
        SAVE_LOAD_CKPT_ERROR message carrying the error text on failure
        (including a request without ``path``).
    """
    try:
        path = data_parts.body["path"]
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only restore checkpoints from a trusted location.
        with open(path, "rb") as f:
            data = pickle.load(f)

        # Pull out everything needed up front; a missing key fails here,
        # before any live state has been modified.
        saved_size = data["storage_unit_size"]
        restored_fields = data["field_data"]
        restored_keys = data["active_keys"]

        if saved_size != self.storage_unit_size:
            # A size mismatch is only reported; the restore still proceeds.
            logger.warning(
                f"[{self.storage_unit_id}]: storage_unit_size mismatch — "
                f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}"
            )

        # Empty the old containers before rebinding, as before.
        # NOTE(review): presumably so anything still referencing the old
        # objects does not keep stale data — confirm.
        self.storage_data.field_data.clear()
        self.storage_data._active_keys.clear()
        self.storage_data.field_data = restored_fields
        self.storage_data._active_keys = restored_keys

        return ZMQMessage.create(
            request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={
                "message": f"[{self.storage_unit_id}] restored from {path} — "
                f"{len(data['active_keys'])} keys, {len(data['field_data'])} fields"
            },
        )

    except Exception as e:
        logger.error(f"[{self.storage_unit_id}]: restore_from_file failed: {e}")
        return ZMQMessage.create(
            request_type=ZMQRequestType.SAVE_LOAD_CKPT_ERROR,  # type: ignore[arg-type]
            sender_id=self.storage_unit_id,
            body={"message": str(e)},
        )