From 5631b56ee071cc1064a2e1bd2773d1760dbbf988 Mon Sep 17 00:00:00 2001 From: yxstev Date: Fri, 12 Jun 2026 15:24:34 +0800 Subject: [PATCH 1/6] provide save/load checkpoint interfaces Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 403 +++++++++++++++++++++++ transfer_queue/__init__.py | 7 + transfer_queue/controller.py | 68 ++++ transfer_queue/interface.py | 190 +++++++++++ transfer_queue/storage/simple_storage.py | 63 ++++ 5 files changed, 731 insertions(+) create mode 100644 tests/e2e/test_checkpoint_e2e.py diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py new file mode 100644 index 00000000..fdb6112b --- /dev/null +++ b/tests/e2e/test_checkpoint_e2e.py @@ -0,0 +1,403 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end tests for save_checkpoint and load_checkpoint. + +Run with: + pytest tests/e2e/test_checkpoint_e2e.py -v +""" + +import json +import os +from pathlib import Path + +import pytest +import ray +import torch +from omegaconf import OmegaConf +from tensordict import TensorDict + +import transfer_queue as tq + +os.environ["RAY_DEDUP_LOGS"] = "0" + +_TQ_CONFIG = OmegaConf.create( + { + "controller": {"polling_mode": True}, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + } +) + + +@pytest.fixture(scope="module") +def ray_init(): + if not ray.is_initialized(): + ray.init(namespace="TestCheckpointE2E") + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def tq_system(ray_init): + tq.init(_TQ_CONFIG) + yield + tq.close() + + +@pytest.fixture +def controller(tq_system): + return ray.get_actor("TransferQueueController", namespace="transfer_queue") + + +@pytest.fixture(autouse=True) +def cleanup_partitions(controller): + yield + try: + for pid in ray.get(controller.list_partitions.remote()): + ray.get(controller.clear_partition.remote(pid)) + except Exception: + pass + + +@pytest.fixture +def checkpoint_dir(tmp_path): + return tmp_path / "checkpoint" + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _put_batch(keys, partition_id, input_ids, attention_mask, tags=None): + fields = TensorDict( + {"input_ids": input_ids, "attention_mask": attention_mask}, + batch_size=len(keys), + ) + if tags is None: + tags = [{} for _ in keys] + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + +def _get_batch(keys, partition_id): + return tq.kv_batch_get(keys=keys, partition_id=partition_id) + + +def assert_tensor_equal(tensor_a, tensor_b, msg=""): + """Assert two tensors are equal, handling nested vs dense comparisons.""" + if (isinstance(tensor_a, torch.Tensor) and tensor_a.is_nested) or ( + isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested + ): + seq_a = list(tensor_a) + seq_b = list(tensor_b) + assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}" + for t1, t2 in zip(seq_a, seq_b, strict=True): + assert torch.equal(t1, t2), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + else: + assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}" + + +# --------------------------------------------------------------------------- +# basic save / load roundtrip +# --------------------------------------------------------------------------- + + +class TestCheckpointRoundtrip: + def test_save_creates_expected_files(self, tq_system, checkpoint_dir): + keys = ["k0", "k1"] + partition_id = "p0" + _put_batch(keys, partition_id, torch.tensor([[1, 2], [3, 4]]), torch.ones(2, 2)) + + info = tq.save_checkpoint(checkpoint_dir) + + assert Path(info["checkpoint_dir"]) == checkpoint_dir + assert (checkpoint_dir / "metadata.json").exists() + assert (checkpoint_dir / "controller_state.pkl").exists() + assert info["controller_state_size"] > 0 + assert info["total_size"] > 0 + + # two storage units configured + assert len(info["storage_units"]) == 2 + su_dir = checkpoint_dir / "storage_units" + for entry in info["storage_units"]: + assert (su_dir / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl").exists() + + def test_metadata_json_content(self, tq_system, checkpoint_dir): + keys = ["m0"] + _put_batch(keys, "p_meta", torch.tensor([[10, 20]]), torch.ones(1, 2)) + + tq.save_checkpoint(checkpoint_dir, metadata={"iteration": 42, "loss": 0.5}) + + with open(checkpoint_dir / "metadata.json") as f: + meta = json.load(f) + + assert meta["user_metadata"]["iteration"] == 42 + assert meta["user_metadata"]["loss"] == pytest.approx(0.5) + assert "timestamp" in meta + assert "storage_units" in meta + + def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, controller): + keys = ["a0", "a1", "a2"] + partition_id = "p_ctrl" + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tags = [{"idx": i} for i in range(3)] + _put_batch(keys, partition_id, input_ids, torch.ones(3, 3), tags) + + tq.save_checkpoint(checkpoint_dir) + + # wipe controller state + ray.get(controller.clear_partition.remote(partition_id)) + assert ray.get(controller.list_partitions.remote()) == [] + + ok = tq.load_checkpoint(checkpoint_dir) + assert ok is True + + # partition must be back + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + # key-to-global-index mapping must be intact + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + + # tags must be intact + for i, key in enumerate(keys): + gidx = snapshot.keys_mapping[key] + assert snapshot.custom_meta[gidx]["idx"] == i + + def test_load_restores_storage_data(self, tq_system, checkpoint_dir, controller): + keys = ["s0", "s1"] + partition_id = "p_storage" + input_ids = torch.tensor([[10, 20], [30, 40]]) + attention_mask = torch.ones(2, 2) + _put_batch(keys, partition_id, input_ids, attention_mask) + + tq.save_checkpoint(checkpoint_dir) + + # clear both controller and storage state so load has to restore from scratch + ray.get(controller.clear_partition.remote(partition_id)) + + ok = tq.load_checkpoint(checkpoint_dir) + assert ok is True + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) + assert_tensor_equal(retrieved["attention_mask"], attention_mask) + + def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, controller): + for i in range(3): + _put_batch( + [f"p{i}_k0", f"p{i}_k1"], + f"part_{i}", + torch.full((2, 4), i, dtype=torch.long), + torch.ones(2, 4), + ) + + tq.save_checkpoint(checkpoint_dir) + + for i in range(3): + ray.get(controller.clear_partition.remote(f"part_{i}")) + + ok = tq.load_checkpoint(checkpoint_dir) + assert ok is True + + for i in range(3): + retrieved = tq.kv_batch_get( + keys=[f"p{i}_k0", f"p{i}_k1"], + partition_id=f"part_{i}", + select_fields=["input_ids"], + ) + assert_tensor_equal(retrieved["input_ids"], torch.full((2, 4), i, dtype=torch.long)) + + +# --------------------------------------------------------------------------- +# include_storage=False +# --------------------------------------------------------------------------- + + +class TestCheckpointMetadataOnly: + def test_save_metadata_only_no_storage_files(self, tq_system, checkpoint_dir): + _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + info = tq.save_checkpoint(checkpoint_dir, include_storage=False) + + assert info["storage_units"] == [] + assert not (checkpoint_dir / "storage_units").exists() + + def test_load_after_metadata_only_save(self, tq_system, checkpoint_dir, controller): + keys = ["n0", "n1"] + partition_id = "p_nometa2" + input_ids = torch.tensor([[5, 6], [7, 8]]) + _put_batch(keys, partition_id, input_ids, torch.ones(2, 2)) + + # save without storage + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + ray.get(controller.clear_partition.remote(partition_id)) + + ok = tq.load_checkpoint(checkpoint_dir) + assert ok is True + + # controller state (partition metadata) must be restored + partitions = ray.get(controller.list_partitions.remote()) + assert partition_id in partitions + + snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) + for key in keys: + assert key in snapshot.keys_mapping + _get_batch(keys, partition_id) + + +# --------------------------------------------------------------------------- +# error handling +# --------------------------------------------------------------------------- + + +class TestCheckpointErrors: + def test_save_raises_if_not_initialized(self, tmp_path): + # call save_checkpoint before tq.init() in a fresh module state + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.save_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_not_initialized(self, tmp_path): + import transfer_queue.interface as iface + + original = iface._TQ_CONTROLLER + try: + iface._TQ_CONTROLLER = None + with pytest.raises(RuntimeError, match="not initialized"): + tq.load_checkpoint(tmp_path / "ck") + finally: + iface._TQ_CONTROLLER = original + + def test_load_raises_if_dir_missing(self, tq_system, tmp_path): + with pytest.raises(FileNotFoundError): + tq.load_checkpoint(tmp_path / "nonexistent") + + def test_load_raises_if_metadata_missing(self, tq_system, tmp_path): + ck = tmp_path / "ck" + ck.mkdir() + with pytest.raises(FileNotFoundError, match="metadata.json"): + tq.load_checkpoint(ck) + + def test_load_raises_on_storage_unit_count_mismatch(self, tq_system, tmp_path, checkpoint_dir): + _put_batch(["e0"], "p_err", torch.tensor([[1, 2]]), torch.ones(1, 2)) + tq.save_checkpoint(checkpoint_dir) + + # tamper: add a fake extra entry so count differs + meta_path = checkpoint_dir / "metadata.json" + with open(meta_path) as f: + meta = json.load(f) + meta["storage_units"].append({"position": 99, "storage_unit_id": "fake", "file_size": 0}) + with open(meta_path, "w") as f: + json.dump(meta, f) + + with pytest.raises(ValueError, match="count mismatch"): + tq.load_checkpoint(checkpoint_dir) + + def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): + """A failed save must not leave a partial directory.""" + _put_batch(["f0"], "p_fail", torch.tensor([[1, 2]]), torch.ones(1, 2)) + + ck = tmp_path / "ck" + # force failure by making the parent read-only on a subpath + # We simulate by patching ray.get to raise mid-save + original_ray_get = ray.get + + call_count = [0] + + def failing_ray_get(futures, *args, **kwargs): + call_count[0] += 1 + if call_count[0] == 2: # fail on storage unit dump + raise RuntimeError("simulated dump failure") + return original_ray_get(futures, *args, **kwargs) + + import unittest.mock as mock + + with mock.patch("transfer_queue.interface.ray.get", side_effect=failing_ray_get): + with pytest.raises(RuntimeError, match="simulated dump failure"): + tq.save_checkpoint(ck) + + assert not ck.exists(), "Partial checkpoint directory should have been cleaned up" + assert not (tmp_path / "ck.tmp").exists(), "Temp directory should have been cleaned up" + + +# --------------------------------------------------------------------------- +# data variety +# --------------------------------------------------------------------------- + + +class TestCheckpointDataVariety: + def test_non_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): + """String fields should survive save/load.""" + from tensordict import NonTensorStack + + keys = ["t0", "t1"] + partition_id = "p_str" + fields = TensorDict( + { + "input_ids": torch.tensor([[1, 2], [3, 4]]), + "text": NonTensorStack("hello", "world"), + }, + batch_size=2, + ) + tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + tq.save_checkpoint(checkpoint_dir) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["input_ids"]) + assert_tensor_equal(retrieved["input_ids"], torch.tensor([[1, 2], [3, 4]])) + + def test_nested_tensor_fields_roundtrip(self, tq_system, checkpoint_dir, controller): + """Variable-length (jagged) tensor fields should survive save/load.""" + keys = ["j0", "j1", "j2"] + partition_id = "p_jagged" + for i, key in enumerate(keys): + seq = torch.arange(i + 1, dtype=torch.float).unsqueeze(0) + tq.kv_put( + key=key, + partition_id=partition_id, + fields=TensorDict({"seq": seq}, batch_size=1), + tag=None, + ) + + tq.save_checkpoint(checkpoint_dir) + + ray.get(controller.clear_partition.remote(partition_id)) + + tq.load_checkpoint(checkpoint_dir) + + retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id, select_fields=["seq"]) + for i, component in enumerate(retrieved["seq"].unbind()): + assert_tensor_equal(component, torch.arange(i + 1, dtype=torch.float)) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 06754278..ad8e8ab9 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -34,6 +34,8 @@ kv_clear, kv_list, kv_put, + load_checkpoint, + save_checkpoint, ) from .metadata import BatchMeta, KVBatchMeta from .sampler import BaseSampler @@ -62,6 +64,11 @@ "async_kv_clear", "KVBatchMeta", ] + + [ + # Checkpoint Interface + "save_checkpoint", + "load_checkpoint", + ] + [ # High-Level StreamingDataLoader Interface "StreamingDataset", diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 1182a44c..e2753643 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -15,6 +15,7 @@ import copy import os +import pickle import time from collections import defaultdict from dataclasses import dataclass, field @@ -2045,6 +2046,73 @@ def get_config(self) -> DictConfig: """Retrieve the global config of TransferQueue.""" return self.tq_config + def dump_to_file(self, path: str) -> bool: + """Serialize controller state directly to a file. + + Writes in-process to avoid transmitting the payload back over the + Ray object store — only a bool ACK is returned to the caller. + + Args: + path: Absolute path for the output .pkl file. + + Returns: + True on success, False on failure. + """ + try: + state = { + "controller_id": self.controller_id, + "partitions": {pid: p.to_snapshot() for pid, p in self.partitions.items()}, + "index_manager": { + "partition_to_indexes": dict(copy.deepcopy(self.index_manager.partition_to_indexes)), + "reusable_indexes": list(self.index_manager.reusable_indexes), + "global_index_counter": self.index_manager.global_index_counter, + "allocated_indexes": set(self.index_manager.allocated_indexes), + }, + "sampler": self.sampler.get_state() if hasattr(self.sampler, "get_state") else None, + "tq_config": self.tq_config, + "connected_storage_managers": set(self._connected_storage_managers), + } + with open(path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"[{self.controller_id}]: dumped to {path}") + return True + except Exception as e: + logger.error(f"[{self.controller_id}]: dump_to_file failed: {e}") + return False + + def restore_from_file(self, path: str) -> bool: + """Restore controller state directly from a file. + + Args: + path: Absolute path to a .pkl file previously written by dump_to_file(). + + Returns: + True on success, False on failure. + """ + try: + with open(path, "rb") as f: + state = pickle.load(f) + + self.controller_id = state["controller_id"] + self.partitions = state["partitions"] + + im = state["index_manager"] + self.index_manager.partition_to_indexes = defaultdict(set, im["partition_to_indexes"]) + self.index_manager.reusable_indexes = im["reusable_indexes"] + self.index_manager.global_index_counter = im["global_index_counter"] + self.index_manager.allocated_indexes = im["allocated_indexes"] + + if state["sampler"] is not None and hasattr(self.sampler, "restore_state"): + self.sampler.restore_state(state["sampler"]) + self.tq_config = state["tq_config"] + self._connected_storage_managers = state["connected_storage_managers"] + + logger.info(f"[{self.controller_id}]: restored from {path}") + return True + except Exception as e: + logger.error(f"[{self.controller_id}]: restore_from_file failed: {e}") + return False + def register_sampler( self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 82ceacaf..996624c6 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import shutil import subprocess import time from importlib import resources +from pathlib import Path from typing import Any, Callable import ray @@ -1047,3 +1050,190 @@ def get_client(): """Get a TransferQueueClient for using low-level API""" assert _TQ_CLIENT is not None, "Please initialize the TransferQueue first by calling `tq.init()`!" return _TQ_CLIENT + + +# ==================== Checkpoint API ==================== + +_METADATA_FILE = "metadata.json" +_CONTROLLER_FILE = "controller_state.pkl" +_STORAGE_UNITS_DIR = "storage_units" + + +def save_checkpoint( + checkpoint_dir: str | Path, + *, + include_storage: bool = True, + metadata: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Save a full checkpoint of the TransferQueue system state. + + Args: + checkpoint_dir: Directory to save the checkpoint (created if not exists). + include_storage: Whether to include storage unit data. + If False, only controller metadata is saved. + metadata: User-defined key-value pairs written into metadata.json. + + Returns: + A dict containing checkpoint_dir, timestamp, version, + controller_state_size, storage_units, and total_size. + + Raises: + RuntimeError: TransferQueue is not initialized. + OSError: Failed to write checkpoint files. + """ + if _TQ_CONTROLLER is None: + raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.") + + checkpoint_dir = Path(checkpoint_dir) + tmp_dir = checkpoint_dir.parent / (checkpoint_dir.name + ".tmp") + + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + tmp_dir.mkdir(parents=True) + + try: + # Step 1: controller dumps itself to file + controller_path = tmp_dir / _CONTROLLER_FILE + success = ray.get(_TQ_CONTROLLER.dump_to_file.remote(str(controller_path))) + if not success: + raise RuntimeError("Controller failed to dump state to file") + controller_size = controller_path.stat().st_size + logger.info(f"Controller state saved ({controller_size} bytes)") + + # Step 2: storage units dump themselves to files in parallel + su_info_list: list[dict[str, Any]] = [] + if include_storage and _TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE: + su_dir = tmp_dir / _STORAGE_UNITS_DIR + su_dir.mkdir() + + su_handles: dict[str, Any] = _TQ_STORAGE["SimpleStorage"] + + futures = { + su_id: ( + pos, + su_dir / f"su_{pos}_{su_id}.pkl", + su_handles[su_id].dump_to_file.remote(str(su_dir / f"su_{pos}_{su_id}.pkl")), + ) + for pos, su_id in enumerate(su_handles) + } + + all_success = ray.get([f for _, _, f in futures.values()]) + for (su_id, (pos, path, _)), success in zip(futures.items(), all_success, strict=False): + if not success: + raise RuntimeError(f"Storage unit {su_id} failed to dump to {path}") + su_info_list.append({"position": pos, "storage_unit_id": su_id, "file_size": path.stat().st_size}) + logger.info(f"Storage unit {su_id} (pos={pos}) saved ({su_info_list[-1]['file_size']} bytes)") + + # Step 3: write metadata.json + timestamp = time.time() + total_size = controller_size + sum(s["file_size"] for s in su_info_list) + meta_content = { + "timestamp": timestamp, + "controller_state_size": controller_size, + "storage_units": su_info_list, + "total_size": total_size, + "user_metadata": metadata or {}, + } + with open(tmp_dir / _METADATA_FILE, "w") as f: + json.dump(meta_content, f, indent=2) + + # Step 4: atomic rename into final location + if checkpoint_dir.exists(): + shutil.rmtree(checkpoint_dir) + tmp_dir.rename(checkpoint_dir) + + result = { + "checkpoint_dir": str(checkpoint_dir), + "timestamp": timestamp, + "controller_state_size": controller_size, + "storage_units": su_info_list, + "total_size": total_size, + } + logger.info(f"Checkpoint saved to {checkpoint_dir} (total {total_size} bytes)") + return result + + except Exception: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) + raise + + +def load_checkpoint( + checkpoint_dir: str | Path, +) -> bool: + """Restore TransferQueue system state from a checkpoint. + + The ordered storage unit list of the current system must exactly match the + checkpoint (same count, same IDs, same positions). This is required because + data routing is position-based (global_idx % num_units). + + Args: + checkpoint_dir: Path to the checkpoint directory. + + Returns: + True on success, False on failure. + + Raises: + FileNotFoundError: Checkpoint directory or required files do not exist. + ValueError: Storage unit list does not match or checkpoint is corrupted. + RuntimeError: TransferQueue is not initialized. + """ + if _TQ_CONTROLLER is None: + raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.") + + checkpoint_dir = Path(checkpoint_dir) + if not checkpoint_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") + + metadata_path = checkpoint_dir / _METADATA_FILE + if not metadata_path.exists(): + raise FileNotFoundError(f"{_METADATA_FILE} not found in {checkpoint_dir}") + + with open(metadata_path) as f: + meta = json.load(f) + + # Validate storage unit count before touching any state + saved_su_list = meta.get("storage_units", []) + if saved_su_list: + if not (_TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE): + raise ValueError("Checkpoint contains storage unit data but current system has no SimpleStorage backend.") + + current_su_handles = list(_TQ_STORAGE["SimpleStorage"].values()) + if len(current_su_handles) != len(saved_su_list): + raise ValueError( + f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, " + f"current system has {len(current_su_handles)}." + ) + + # Restore controller + controller_path = checkpoint_dir / _CONTROLLER_FILE + if not controller_path.exists(): + raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}") + + success = ray.get(_TQ_CONTROLLER.restore_from_file.remote(str(controller_path))) + if not success: + logger.error("Controller restore_from_file returned False") + return False + + # Restore storage units in parallel, matched by position + if saved_su_list: + current_su_handles = list(_TQ_STORAGE["SimpleStorage"].values()) + su_dir = checkpoint_dir / _STORAGE_UNITS_DIR + + entries_by_pos = sorted(saved_su_list, key=lambda e: e["position"]) + futures = [] + for entry in entries_by_pos: + pos = entry["position"] + path = su_dir / f"su_{pos}_{entry['storage_unit_id']}.pkl" + if not path.exists(): + raise FileNotFoundError(f"Storage unit file not found: {path}") + futures.append(current_su_handles[pos].restore_from_file.remote(str(path))) + + results = ray.get(futures) + if not all(results): + failed_positions = [i for i, r in enumerate(results) if not r] + logger.error(f"Storage units at positions {failed_positions} failed to restore") + return False + + logger.info(f"Checkpoint loaded from {checkpoint_dir}") + return True diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index e70648ea..d615c7ba 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import pickle import time import weakref from threading import Event, Thread @@ -630,3 +631,65 @@ def get_zmq_server_info(self) -> ZMQServerInfo: ZMQServerInfo containing connection details for this storage unit. """ return self.zmq_server_info + + def dump_to_file(self, path: str) -> bool: + """Serialize storage unit data directly to a file. + + Writes data in-process to avoid transmitting the payload back over the + Ray object store — only a bool ACK is returned to the caller. + + Args: + path: Absolute path for the output .pkl file. The caller must ensure + this path is reachable from the node running this actor + (shared filesystem required for multi-node setups). + + Returns: + True on success, False on failure. + """ + try: + state = { + "storage_unit_id": self.storage_unit_id, + "storage_unit_size": self.storage_unit_size, + "field_data": self.storage_data.field_data, + "active_keys": self.storage_data._active_keys, + } + with open(path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"[{self.storage_unit_id}]: dumped to {path}") + return True + except Exception as e: + logger.error(f"[{self.storage_unit_id}]: dump_to_file failed: {e}") + return False + + def restore_from_file(self, path: str) -> bool: + """Restore storage unit data directly from a file. + + Args: + path: Absolute path to a .pkl file previously written by dump_to_file(). + + Returns: + True on success, False on failure. + """ + try: + with open(path, "rb") as f: + data = pickle.load(f) + + if data["storage_unit_size"] != self.storage_unit_size: + logger.warning( + f"[{self.storage_unit_id}]: storage_unit_size mismatch — " + f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}" + ) + + self.storage_data.field_data.clear() + self.storage_data._active_keys.clear() + self.storage_data.field_data = data["field_data"] + self.storage_data._active_keys = data["active_keys"] + + logger.info( + f"[{self.storage_unit_id}]: restored from {path} — " + f"{len(data['active_keys'])} keys, {len(data['field_data'])} fields" + ) + return True + except Exception as e: + logger.error(f"[{self.storage_unit_id}]: restore_from_file failed: {e}") + return False From b479a68409c0934b3db962793467d21bcce4d3ad Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 15 Jun 2026 15:36:20 +0800 Subject: [PATCH 2/6] resolve comments Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 27 +++++++++---------- transfer_queue/interface.py | 45 ++++++++------------------------ 2 files changed, 23 insertions(+), 49 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index fdb6112b..a7f94dfe 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -21,7 +21,6 @@ import json import os -from pathlib import Path import pytest import ray @@ -127,13 +126,13 @@ def test_save_creates_expected_files(self, tq_system, checkpoint_dir): partition_id = "p0" _put_batch(keys, partition_id, torch.tensor([[1, 2], [3, 4]]), torch.ones(2, 2)) - info = tq.save_checkpoint(checkpoint_dir) + tq.save_checkpoint(checkpoint_dir) - assert Path(info["checkpoint_dir"]) == checkpoint_dir assert (checkpoint_dir / "metadata.json").exists() assert (checkpoint_dir / "controller_state.pkl").exists() - assert info["controller_state_size"] > 0 - assert info["total_size"] > 0 + + with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) # two storage units configured assert len(info["storage_units"]) == 2 @@ -152,7 +151,6 @@ def test_metadata_json_content(self, tq_system, checkpoint_dir): assert meta["user_metadata"]["iteration"] == 42 assert meta["user_metadata"]["loss"] == pytest.approx(0.5) - assert "timestamp" in meta assert "storage_units" in meta def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, controller): @@ -168,8 +166,7 @@ def test_load_restores_controller_partitions(self, tq_system, checkpoint_dir, co ray.get(controller.clear_partition.remote(partition_id)) assert ray.get(controller.list_partitions.remote()) == [] - ok = tq.load_checkpoint(checkpoint_dir) - assert ok is True + tq.load_checkpoint(checkpoint_dir) # partition must be back partitions = ray.get(controller.list_partitions.remote()) @@ -197,8 +194,7 @@ def test_load_restores_storage_data(self, tq_system, checkpoint_dir, controller) # clear both controller and storage state so load has to restore from scratch ray.get(controller.clear_partition.remote(partition_id)) - ok = tq.load_checkpoint(checkpoint_dir) - assert ok is True + tq.load_checkpoint(checkpoint_dir) retrieved = _get_batch(keys, partition_id) assert_tensor_equal(retrieved["input_ids"], input_ids) @@ -218,8 +214,7 @@ def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, cont for i in range(3): ray.get(controller.clear_partition.remote(f"part_{i}")) - ok = tq.load_checkpoint(checkpoint_dir) - assert ok is True + tq.load_checkpoint(checkpoint_dir) for i in range(3): retrieved = tq.kv_batch_get( @@ -239,7 +234,10 @@ class TestCheckpointMetadataOnly: def test_save_metadata_only_no_storage_files(self, tq_system, checkpoint_dir): _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) - info = tq.save_checkpoint(checkpoint_dir, include_storage=False) + tq.save_checkpoint(checkpoint_dir, include_storage=False) + + with open(checkpoint_dir / "metadata.json") as f: + info = json.load(f) assert info["storage_units"] == [] assert not (checkpoint_dir / "storage_units").exists() @@ -255,8 +253,7 @@ def test_load_after_metadata_only_save(self, tq_system, checkpoint_dir, controll ray.get(controller.clear_partition.remote(partition_id)) - ok = tq.load_checkpoint(checkpoint_dir) - assert ok is True + tq.load_checkpoint(checkpoint_dir) # controller state (partition metadata) must be restored partitions = ray.get(controller.list_partitions.remote()) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 996624c6..7f19a7ae 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1064,18 +1064,15 @@ def save_checkpoint( *, include_storage: bool = True, metadata: dict[str, Any] | None = None, -) -> dict[str, Any]: +) -> None: """Save a full checkpoint of the TransferQueue system state. Args: checkpoint_dir: Directory to save the checkpoint (created if not exists). include_storage: Whether to include storage unit data. - If False, only controller metadata is saved. + If False, only controller state is saved. metadata: User-defined key-value pairs written into metadata.json. - - Returns: - A dict containing checkpoint_dir, timestamp, version, - controller_state_size, storage_units, and total_size. + Example: {"timestamp": 1718123456.789012, "step": 1000} Raises: RuntimeError: TransferQueue is not initialized. @@ -1125,13 +1122,8 @@ def save_checkpoint( logger.info(f"Storage unit {su_id} (pos={pos}) saved ({su_info_list[-1]['file_size']} bytes)") # Step 3: write metadata.json - timestamp = time.time() - total_size = controller_size + sum(s["file_size"] for s in su_info_list) meta_content = { - "timestamp": timestamp, - "controller_state_size": controller_size, "storage_units": su_info_list, - "total_size": total_size, "user_metadata": metadata or {}, } with open(tmp_dir / _METADATA_FILE, "w") as f: @@ -1142,15 +1134,7 @@ def save_checkpoint( shutil.rmtree(checkpoint_dir) tmp_dir.rename(checkpoint_dir) - result = { - "checkpoint_dir": str(checkpoint_dir), - "timestamp": timestamp, - "controller_state_size": controller_size, - "storage_units": su_info_list, - "total_size": total_size, - } - logger.info(f"Checkpoint saved to {checkpoint_dir} (total {total_size} bytes)") - return result + logger.info(f"Checkpoint saved to {checkpoint_dir}") except Exception: if tmp_dir.exists(): @@ -1160,23 +1144,20 @@ def save_checkpoint( def load_checkpoint( checkpoint_dir: str | Path, -) -> bool: +) -> None: """Restore TransferQueue system state from a checkpoint. The ordered storage unit list of the current system must exactly match the - checkpoint (same count, same IDs, same positions). This is required because - data routing is position-based (global_idx % num_units). + checkpoint (same count, same positions). This is required because data + routing is position-based (global_idx % num_units). Args: checkpoint_dir: Path to the checkpoint directory. - Returns: - True on success, False on failure. - Raises: FileNotFoundError: Checkpoint directory or required files do not exist. ValueError: Storage unit list does not match or checkpoint is corrupted. - RuntimeError: TransferQueue is not initialized. + RuntimeError: TransferQueue is not initialized, or restore fails. """ if _TQ_CONTROLLER is None: raise RuntimeError("TransferQueue is not initialized. Call tq.init() first.") @@ -1210,10 +1191,8 @@ def load_checkpoint( if not controller_path.exists(): raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}") - success = ray.get(_TQ_CONTROLLER.restore_from_file.remote(str(controller_path))) - if not success: - logger.error("Controller restore_from_file returned False") - return False + if not ray.get(_TQ_CONTROLLER.restore_from_file.remote(str(controller_path))): + raise RuntimeError("Controller failed to restore from checkpoint.") # Restore storage units in parallel, matched by position if saved_su_list: @@ -1232,8 +1211,6 @@ def load_checkpoint( results = ray.get(futures) if not all(results): failed_positions = [i for i, r in enumerate(results) if not r] - logger.error(f"Storage units at positions {failed_positions} failed to restore") - return False + raise RuntimeError(f"Storage units at positions {failed_positions} failed to restore.") logger.info(f"Checkpoint loaded from {checkpoint_dir}") - return True From 62dd616f4695275f467376a361da9d0cd9147a55 Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 15 Jun 2026 17:00:59 +0800 Subject: [PATCH 3/6] refactor checkpoint to use ZMQ RPC instead of ray.remote Controller and storage unit checkpoint operations are now triggered via ZMQ messages (CHECKPOINT_DUMP/RESTORE) rather than direct ray.get() calls. StorageManager base class exposes dump_checkpoint/restore_checkpoint, with AsyncSimpleStorageManager implementing them via the storage unit ZMQ channel. interface.py no longer accesses _TQ_STORAGE or _TQ_CONTROLLER ray handles for checkpoint; all calls go through TransferQueueClient. Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 16 +- transfer_queue/client.py | 90 ++++++++++ transfer_queue/controller.py | 20 +++ transfer_queue/interface.py | 87 ++++------ transfer_queue/storage/managers/base.py | 23 +++ .../managers/simple_storage_manager.py | 105 ++++++++++++ transfer_queue/storage/simple_storage.py | 155 +++++++++++------- transfer_queue/utils/zmq_utils.py | 6 + 8 files changed, 376 insertions(+), 126 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index a7f94dfe..ac2f1729 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -324,21 +324,13 @@ def test_no_partial_state_on_failed_save(self, tq_system, tmp_path): _put_batch(["f0"], "p_fail", torch.tensor([[1, 2]]), torch.ones(1, 2)) ck = tmp_path / "ck" - # force failure by making the parent read-only on a subpath - # We simulate by patching ray.get to raise mid-save - original_ray_get = ray.get - - call_count = [0] - - def failing_ray_get(futures, *args, **kwargs): - call_count[0] += 1 - if call_count[0] == 2: # fail on storage unit dump - raise RuntimeError("simulated dump failure") - return original_ray_get(futures, *args, **kwargs) import unittest.mock as mock - with mock.patch("transfer_queue.interface.ray.get", side_effect=failing_ray_get): + with mock.patch( + "transfer_queue.client.TransferQueueClient.dump_storage_checkpoint", + side_effect=RuntimeError("simulated dump failure"), + ): with pytest.raises(RuntimeError, match="simulated dump failure"): tq.save_checkpoint(ck) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 45e6303a..9ba558b3 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -1061,6 +1061,75 @@ def close(self) -> None: except Exception as e: logger.warning(f"Error closing storage manager: {e}") + # ==================== Checkpoint API ==================== + @with_controller_socket + async def async_dump_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Send CHECKPOINT_DUMP to controller and wait for response.""" + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_DUMP, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.CHECKPOINT_DUMP_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint dump" + ) + if not response_msg.body.get("success"): + raise RuntimeError(f"[{self.client_id}]: Controller failed to dump checkpoint to {path}") + + @with_controller_socket + async def async_restore_controller_checkpoint( + self, + path: str, + socket: zmq.asyncio.Socket | None = None, + ) -> None: + """Send CHECKPOINT_RESTORE to controller and wait for response.""" + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_RESTORE, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + if response_msg.request_type != ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE: + raise RuntimeError( + f"[{self.client_id}]: Unexpected response type {response_msg.request_type} " + f"from controller during checkpoint restore" + ) + if not response_msg.body.get("success"): + raise RuntimeError(f"[{self.client_id}]: Controller failed to restore checkpoint from {path}") + + async def async_dump_storage_checkpoint(self, output_dir: str) -> list[dict]: + """Dump all storage units to files in output_dir via StorageManager.""" + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + return await self.storage_manager.dump_checkpoint(output_dir) + + async def async_restore_storage_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None: + """Restore all storage units from files in checkpoint_dir via StorageManager.""" + if not hasattr(self, "storage_manager") or self.storage_manager is None: + raise RuntimeError( + f"[{self.client_id}]: Storage manager not initialized. " + "Call initialize_storage_manager() before checkpoint operations." + ) + await self.storage_manager.restore_checkpoint(checkpoint_dir, su_info_list) + class TransferQueueClient(AsyncTransferQueueClient): """Synchronous client wrapper for TransferQueue. @@ -1129,6 +1198,10 @@ def wrapper(*args, **kwargs): self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta) self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys) self._kv_list = _make_sync(self.async_kv_list) + self._dump_controller_checkpoint = _make_sync(self.async_dump_controller_checkpoint) + self._restore_controller_checkpoint = _make_sync(self.async_restore_controller_checkpoint) + self._dump_storage_checkpoint = _make_sync(self.async_dump_storage_checkpoint) + self._restore_storage_checkpoint = _make_sync(self.async_restore_storage_checkpoint) # ==================== Basic API ==================== def get_meta( @@ -1561,6 +1634,23 @@ def kv_list( return self._kv_list(partition_id=partition_id) + # ==================== Checkpoint API ==================== + def dump_controller_checkpoint(self, path: str) -> None: + """Synchronously dump controller state to a file via ZMQ RPC.""" + return self._dump_controller_checkpoint(path) + + def restore_controller_checkpoint(self, path: str) -> None: + """Synchronously restore controller state from a file via ZMQ RPC.""" + return self._restore_controller_checkpoint(path) + + def dump_storage_checkpoint(self, output_dir: str) -> list[dict]: + """Synchronously dump all storage units to files in output_dir.""" + return self._dump_storage_checkpoint(output_dir) + + def restore_storage_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None: + """Synchronously restore all storage units from files in checkpoint_dir.""" + return self._restore_storage_checkpoint(checkpoint_dir, su_info_list) + def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index e2753643..d25d94fb 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -2032,6 +2032,26 @@ def _process_request(self): body={"partition_info": partition_info, "message": message}, ) + elif request_msg.request_type == ZMQRequestType.CHECKPOINT_DUMP: + path = request_msg.body["path"] + success = self.dump_to_file(path) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"success": success}, + ) + + elif request_msg.request_type == ZMQRequestType.CHECKPOINT_RESTORE: + path = request_msg.body["path"] + success = self.restore_from_file(path) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"success": success}, + ) + self.request_handle_socket.send_multipart([identity, *response_msg.serialize()]) def get_zmq_server_info(self) -> ZMQServerInfo: diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 7f19a7ae..54de5e85 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1089,39 +1089,27 @@ def save_checkpoint( tmp_dir.mkdir(parents=True) try: - # Step 1: controller dumps itself to file + client = _maybe_create_tq_client() + + # Controller dumps itself to file via ZMQ controller_path = tmp_dir / _CONTROLLER_FILE - success = ray.get(_TQ_CONTROLLER.dump_to_file.remote(str(controller_path))) - if not success: - raise RuntimeError("Controller failed to dump state to file") + client.dump_controller_checkpoint(str(controller_path)) controller_size = controller_path.stat().st_size logger.info(f"Controller state saved ({controller_size} bytes)") - # Step 2: storage units dump themselves to files in parallel + # Storage units dump themselves to files via StorageManager su_info_list: list[dict[str, Any]] = [] - if include_storage and _TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE: + if include_storage and hasattr(client, "storage_manager") and client.storage_manager is not None: su_dir = tmp_dir / _STORAGE_UNITS_DIR su_dir.mkdir() - - su_handles: dict[str, Any] = _TQ_STORAGE["SimpleStorage"] - - futures = { - su_id: ( - pos, - su_dir / f"su_{pos}_{su_id}.pkl", - su_handles[su_id].dump_to_file.remote(str(su_dir / f"su_{pos}_{su_id}.pkl")), - ) - for pos, su_id in enumerate(su_handles) - } - - all_success = ray.get([f for _, _, f in futures.values()]) - for (su_id, (pos, path, _)), success in zip(futures.items(), all_success, strict=False): - if not success: - raise RuntimeError(f"Storage unit {su_id} failed to dump to {path}") - su_info_list.append({"position": pos, "storage_unit_id": su_id, "file_size": path.stat().st_size}) - logger.info(f"Storage unit {su_id} (pos={pos}) saved ({su_info_list[-1]['file_size']} bytes)") - - # Step 3: write metadata.json + try: + su_info_list = client.dump_storage_checkpoint(str(su_dir)) + for info in su_info_list: + logger.info(f"Storage unit {info['storage_unit_id']} (pos={info['position']}) ") + except NotImplementedError: + logger.warning("Storage backend does not support checkpoint; storage data will not be saved.") + + # Write metadata.json meta_content = { "storage_units": su_info_list, "user_metadata": metadata or {}, @@ -1129,7 +1117,7 @@ def save_checkpoint( with open(tmp_dir / _METADATA_FILE, "w") as f: json.dump(meta_content, f, indent=2) - # Step 4: atomic rename into final location + # Atomic rename into final location if checkpoint_dir.exists(): shutil.rmtree(checkpoint_dir) tmp_dir.rename(checkpoint_dir) @@ -1173,44 +1161,41 @@ def load_checkpoint( with open(metadata_path) as f: meta = json.load(f) - # Validate storage unit count before touching any state saved_su_list = meta.get("storage_units", []) - if saved_su_list: - if not (_TQ_STORAGE and "SimpleStorage" in _TQ_STORAGE): - raise ValueError("Checkpoint contains storage unit data but current system has no SimpleStorage backend.") - current_su_handles = list(_TQ_STORAGE["SimpleStorage"].values()) - if len(current_su_handles) != len(saved_su_list): + client = _maybe_create_tq_client() + + # Validate storage unit count before touching any state + if saved_su_list: + if not hasattr(client, "storage_manager") or client.storage_manager is None: + raise ValueError("Checkpoint contains storage unit data but current system has no storage manager.") + try: + current_su_count = len(client.storage_manager.get_zmq_server_info()) + if current_su_count != len(saved_su_list): + raise ValueError( + f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, " + f"current system has {current_su_count}." + ) + except NotImplementedError as e: raise ValueError( - f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, " - f"current system has {len(current_su_handles)}." - ) + "Checkpoint contains storage unit data but current storage backend does not support checkpoint." + ) from e - # Restore controller + # Restore controller via ZMQ RPC controller_path = checkpoint_dir / _CONTROLLER_FILE if not controller_path.exists(): raise FileNotFoundError(f"{_CONTROLLER_FILE} not found in {checkpoint_dir}") - if not ray.get(_TQ_CONTROLLER.restore_from_file.remote(str(controller_path))): - raise RuntimeError("Controller failed to restore from checkpoint.") + client.restore_controller_checkpoint(str(controller_path)) - # Restore storage units in parallel, matched by position + # Restore storage units via StorageManager abstraction if saved_su_list: - current_su_handles = list(_TQ_STORAGE["SimpleStorage"].values()) su_dir = checkpoint_dir / _STORAGE_UNITS_DIR - - entries_by_pos = sorted(saved_su_list, key=lambda e: e["position"]) - futures = [] - for entry in entries_by_pos: + for entry in saved_su_list: pos = entry["position"] path = su_dir / f"su_{pos}_{entry['storage_unit_id']}.pkl" if not path.exists(): raise FileNotFoundError(f"Storage unit file not found: {path}") - futures.append(current_su_handles[pos].restore_from_file.remote(str(path))) - - results = ray.get(futures) - if not all(results): - failed_positions = [i for i, r in enumerate(results) if not r] - raise RuntimeError(f"Storage units at positions {failed_positions} failed to restore.") + client.restore_storage_checkpoint(str(su_dir), saved_su_list) logger.info(f"Checkpoint loaded from {checkpoint_dir}") diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index f4d545da..6299efff 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -352,6 +352,29 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") + async def dump_checkpoint(self, output_dir: str) -> list[dict]: + """Dump all storage units to files in output_dir. + + Returns: + List of dicts, each with keys: position, storage_unit_id. + + Raises: + NotImplementedError: If this storage backend does not support checkpoint. + """ + raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + + async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None: + """Restore all storage units from files in checkpoint_dir. + + Args: + checkpoint_dir: Path to the checkpoint directory containing storage unit files. + su_info_list: Ordered list of storage unit info dicts from metadata.json. + + Raises: + NotImplementedError: If this storage backend does not support checkpoint. + """ + raise NotImplementedError(f"{self.__class__.__name__} does not support checkpoint") + def close(self) -> None: """Close all ZMQ sockets/contexts and stop the notify loop.""" diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 80803522..83fc442e 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -524,6 +524,111 @@ def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]: """ return self.storage_unit_infos + @with_storage_unit_socket + async def _dump_single_storage_unit( + self, + path: str, + target_storage_unit: str, + socket: zmq.Socket = None, + ): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_DUMP, # type: ignore[arg-type] + sender_id=self.storage_manager_id, + receiver_id=target_storage_unit, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize(), copy=False) + messages = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(messages) + if response_msg.request_type != ZMQRequestType.CHECKPOINT_DUMP_RESPONSE: + raise RuntimeError( + f"Unexpected response from storage unit {target_storage_unit}: {response_msg.request_type}" + ) + except Exception as e: + logger.error(f"[{self.storage_manager_id}]: Error dumping storage unit {target_storage_unit}: {str(e)}") + raise + + @with_storage_unit_socket + async def _restore_single_storage_unit( + self, + path: str, + target_storage_unit: str, + socket: zmq.Socket = None, + ): + try: + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_RESTORE, # type: ignore[arg-type] + sender_id=self.storage_manager_id, + receiver_id=target_storage_unit, + body={"path": path}, + ) + await socket.send_multipart(request_msg.serialize(), copy=False) + messages = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(messages) + if response_msg.request_type != ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE: + raise RuntimeError( + f"Unexpected response from storage unit {target_storage_unit}: {response_msg.request_type}" + ) + except Exception as e: + logger.error( + f"[{self.storage_manager_id}]: Error restoring for storage unit {target_storage_unit}: {str(e)}" + ) + raise + + async def dump_checkpoint(self, output_dir: str) -> list[dict]: + """Dump all storage units to files in output_dir in parallel. + + Returns: + Ordered list of dicts with keys: position, storage_unit_id. + """ + from pathlib import Path + + su_ids = list(self.storage_unit_infos.keys()) + tasks = [] + paths = [] + for pos, su_id in enumerate(su_ids): + path = str(Path(output_dir) / f"su_{pos}_{su_id}.pkl") + paths.append(path) + tasks.append(self._dump_single_storage_unit(path, target_storage_unit=su_id)) + + await asyncio.gather(*tasks) + + su_info_list = [] + for pos, (su_id, path) in enumerate(zip(su_ids, paths, strict=False)): + su_info_list.append( + { + "position": pos, + "storage_unit_id": su_id, + } + ) + return su_info_list + + async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None: + """Restore all storage units from files in checkpoint_dir. + + Matches by position: entry at position i is restored to the storage unit + currently at position i in storage_unit_infos. + """ + from pathlib import Path + + su_ids = list(self.storage_unit_infos.keys()) + entries = sorted(su_info_list, key=lambda e: e["position"]) + + if len(entries) != len(su_ids): + raise ValueError( + f"Storage unit count mismatch: checkpoint has {len(entries)}, current system has {len(su_ids)}." + ) + + tasks = [] + for entry in entries: + pos = entry["position"] + su_id = su_ids[pos] + path = str(Path(checkpoint_dir) / f"su_{pos}_{entry['storage_unit_id']}.pkl") + tasks.append(self._restore_single_storage_unit(path, target_storage_unit=su_id)) + + await asyncio.gather(*tasks) + def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" super().close() diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index d615c7ba..aa417cf2 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -309,6 +309,10 @@ def _worker_routine(self) -> None: response_msg = self._handle_clear(request_msg) elif operation == ZMQRequestType.GET_METRICS: # type: ignore[arg-type] response_msg = self._handle_get_metrics() + elif operation == ZMQRequestType.CHECKPOINT_DUMP: # type: ignore[arg-type] + response_msg = self._handle_checkpoint_dump(request_msg) + elif operation == ZMQRequestType.CHECKPOINT_RESTORE: # type: ignore[arg-type] + response_msg = self._handle_checkpoint_restore(request_msg) else: response_msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR, # type: ignore[arg-type] @@ -536,11 +540,98 @@ def _handle_get_metrics(self) -> ZMQMessage: metrics["op_stats"] = op_stats return ZMQMessage.create( - request_type=ZMQRequestType.METRICS_RESPONSE, + request_type=ZMQRequestType.METRICS_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body=metrics, ) + def _handle_checkpoint_dump(self, data_parts) -> ZMQMessage: + """ + Serialize storage unit data directly to a file. + + Writes data in-process to avoid transmitting the payload back over the + Ray object store — only a bool ACK is returned to the caller. + + Args: + path: data_parts: ZMQMessage from client, including + absolute path for the output .pkl file. + The caller must ensure this path is reachable from the node + running this actor (shared filesystem required for multi-node setups). + + Returns: + Checkpoint dump success response ZMQMessage. + """ + path = data_parts.body["path"] + try: + state = { + "storage_unit_id": self.storage_unit_id, + "storage_unit_size": self.storage_unit_size, + "field_data": self.storage_data.field_data, + "active_keys": self.storage_data._active_keys, + } + with open(path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"message": f"[{self.storage_unit_id}] dumped to {path} successfully"}, + ) + except Exception: + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"message": f"[{self.storage_unit_id}]: dumped to {path} failed"}, + ) + + return response_msg + + def _handle_checkpoint_restore(self, data_parts) -> ZMQMessage: + """ + Restore storage unit data directly from a file. + + Args: + path: data_parts: ZMQMessage from client, including + absolute path for the output .pkl file. + The caller must ensure this path is reachable from the node + running this actor (shared filesystem required for multi-node setups). + + Returns: + Checkpoint restore success response ZMQMessage. + """ + path = data_parts.body["path"] + try: + with open(path, "rb") as f: + data = pickle.load(f) + + if data["storage_unit_size"] != self.storage_unit_size: + logger.warning( + f"[{self.storage_unit_id}]: storage_unit_size mismatch — " + f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}" + ) + + self.storage_data.field_data.clear() + self.storage_data._active_keys.clear() + self.storage_data.field_data = data["field_data"] + self.storage_data._active_keys = data["active_keys"] + + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={ + "message": f"[{self.storage_unit_id}] restored from {path} — " + f"{len(data['active_keys'])} keys, {len(data['field_data'])} fields" + }, + ) + + except Exception as e: + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE, # type: ignore[arg-type] + sender_id=self.storage_unit_id, + body={"message": f"[{self.storage_unit_id}] restored from {path} failed: {e}"}, + ) + + return response_msg + @staticmethod def _cumulative_bucket_counts(hist) -> list[float]: """Build cumulative counts from a prometheus_client Histogram's non-cumulative buckets.""" @@ -631,65 +722,3 @@ def get_zmq_server_info(self) -> ZMQServerInfo: ZMQServerInfo containing connection details for this storage unit. """ return self.zmq_server_info - - def dump_to_file(self, path: str) -> bool: - """Serialize storage unit data directly to a file. - - Writes data in-process to avoid transmitting the payload back over the - Ray object store — only a bool ACK is returned to the caller. - - Args: - path: Absolute path for the output .pkl file. The caller must ensure - this path is reachable from the node running this actor - (shared filesystem required for multi-node setups). - - Returns: - True on success, False on failure. - """ - try: - state = { - "storage_unit_id": self.storage_unit_id, - "storage_unit_size": self.storage_unit_size, - "field_data": self.storage_data.field_data, - "active_keys": self.storage_data._active_keys, - } - with open(path, "wb") as f: - pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) - logger.info(f"[{self.storage_unit_id}]: dumped to {path}") - return True - except Exception as e: - logger.error(f"[{self.storage_unit_id}]: dump_to_file failed: {e}") - return False - - def restore_from_file(self, path: str) -> bool: - """Restore storage unit data directly from a file. - - Args: - path: Absolute path to a .pkl file previously written by dump_to_file(). - - Returns: - True on success, False on failure. - """ - try: - with open(path, "rb") as f: - data = pickle.load(f) - - if data["storage_unit_size"] != self.storage_unit_size: - logger.warning( - f"[{self.storage_unit_id}]: storage_unit_size mismatch — " - f"checkpoint={data['storage_unit_size']}, current={self.storage_unit_size}" - ) - - self.storage_data.field_data.clear() - self.storage_data._active_keys.clear() - self.storage_data.field_data = data["field_data"] - self.storage_data._active_keys = data["active_keys"] - - logger.info( - f"[{self.storage_unit_id}]: restored from {path} — " - f"{len(data['active_keys'])} keys, {len(data['field_data'])} fields" - ) - return True - except Exception as e: - logger.error(f"[{self.storage_unit_id}]: restore_from_file failed: {e}") - return False diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 4fe32f0a..3a260a14 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -101,6 +101,12 @@ class ZMQRequestType(ExplicitEnum): GET_METRICS = "GET_METRICS" METRICS_RESPONSE = "METRICS_RESPONSE" + # CHECKPOINT + CHECKPOINT_DUMP = "CHECKPOINT_DUMP" + CHECKPOINT_DUMP_RESPONSE = "CHECKPOINT_DUMP_RESPONSE" + CHECKPOINT_RESTORE = "CHECKPOINT_RESTORE" + CHECKPOINT_RESTORE_RESPONSE = "CHECKPOINT_RESTORE_RESPONSE" + class ZMQServerInfo: """ From 3498b61072cde9adb7b53b41a6d46294c35c5a4b Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 15 Jun 2026 17:24:53 +0800 Subject: [PATCH 4/6] resolve comments Signed-off-by: yxstev --- tests/e2e/test_checkpoint_e2e.py | 17 ++++--- transfer_queue/controller.py | 9 ++-- transfer_queue/interface.py | 49 +++++++++++++------ transfer_queue/storage/managers/base.py | 13 +++++ .../managers/simple_storage_manager.py | 4 ++ 5 files changed, 67 insertions(+), 25 deletions(-) diff --git a/tests/e2e/test_checkpoint_e2e.py b/tests/e2e/test_checkpoint_e2e.py index ac2f1729..c41432a9 100644 --- a/tests/e2e/test_checkpoint_e2e.py +++ b/tests/e2e/test_checkpoint_e2e.py @@ -231,7 +231,8 @@ def test_load_restores_multiple_partitions(self, tq_system, checkpoint_dir, cont class TestCheckpointMetadataOnly: - def test_save_metadata_only_no_storage_files(self, tq_system, checkpoint_dir): + def test_save_simplestorage_always_includes_storage(self, tq_system, checkpoint_dir): + """SimpleStorage is in-memory, so storage data is always saved regardless of include_storage.""" _put_batch(["n0"], "p_nometa", torch.tensor([[1, 2]]), torch.ones(1, 2)) tq.save_checkpoint(checkpoint_dir, include_storage=False) @@ -239,30 +240,32 @@ def test_save_metadata_only_no_storage_files(self, tq_system, checkpoint_dir): with open(checkpoint_dir / "metadata.json") as f: info = json.load(f) - assert info["storage_units"] == [] - assert not (checkpoint_dir / "storage_units").exists() + # SimpleStorage backend: storage_checkpoint_required=True, so data is always saved + assert len(info["storage_units"]) == 2 + assert (checkpoint_dir / "storage_units").exists() - def test_load_after_metadata_only_save(self, tq_system, checkpoint_dir, controller): + def test_load_after_include_storage_false_simplestorage(self, tq_system, checkpoint_dir, controller): + """Even with include_storage=False, SimpleStorage data is saved and can be restored.""" keys = ["n0", "n1"] partition_id = "p_nometa2" input_ids = torch.tensor([[5, 6], [7, 8]]) _put_batch(keys, partition_id, input_ids, torch.ones(2, 2)) - # save without storage tq.save_checkpoint(checkpoint_dir, include_storage=False) ray.get(controller.clear_partition.remote(partition_id)) tq.load_checkpoint(checkpoint_dir) - # controller state (partition metadata) must be restored partitions = ray.get(controller.list_partitions.remote()) assert partition_id in partitions snapshot = ray.get(controller.get_partition_snapshot.remote(partition_id)) for key in keys: assert key in snapshot.keys_mapping - _get_batch(keys, partition_id) + + retrieved = _get_batch(keys, partition_id) + assert_tensor_equal(retrieved["input_ids"], input_ids) # --------------------------------------------------------------------------- diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index d25d94fb..34304a9b 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -2089,8 +2089,9 @@ def dump_to_file(self, path: str) -> bool: "allocated_indexes": set(self.index_manager.allocated_indexes), }, "sampler": self.sampler.get_state() if hasattr(self.sampler, "get_state") else None, - "tq_config": self.tq_config, - "connected_storage_managers": set(self._connected_storage_managers), + # tq_config and connected_storage_managers are intentionally excluded: + # tq_config may differ after reinit (e.g. different ports/addresses); + # connected_storage_managers holds UUIDs that are regenerated on every reinit. } with open(path, "wb") as f: pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) @@ -2124,8 +2125,8 @@ def restore_from_file(self, path: str) -> bool: if state["sampler"] is not None and hasattr(self.sampler, "restore_state"): self.sampler.restore_state(state["sampler"]) - self.tq_config = state["tq_config"] - self._connected_storage_managers = state["connected_storage_managers"] + # tq_config and connected_storage_managers are not restored: they belong to + # the new init context, not the checkpoint. logger.info(f"[{self.controller_id}]: restored from {path}") return True diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 54de5e85..b108c15f 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1067,10 +1067,17 @@ def save_checkpoint( ) -> None: """Save a full checkpoint of the TransferQueue system state. + .. note:: + **Multi-node limitation**: checkpoint_dir must reside on a shared network + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running + storage units. Single-node deployments have no such requirement. + Args: checkpoint_dir: Directory to save the checkpoint (created if not exists). - include_storage: Whether to include storage unit data. - If False, only controller state is saved. + include_storage: Controls whether storage contents are saved for KV backends + (MooncakeStore, Yuanrong, etc.) that persist data externally. + Ignored for SimpleStorage — its in-memory data is always saved + because it would be lost on restart. metadata: User-defined key-value pairs written into metadata.json. Example: {"timestamp": 1718123456.789012, "step": 1000} @@ -1097,17 +1104,27 @@ def save_checkpoint( controller_size = controller_path.stat().st_size logger.info(f"Controller state saved ({controller_size} bytes)") - # Storage units dump themselves to files via StorageManager + # Storage units dump themselves to files via StorageManager. + # For in-memory backends (e.g. SimpleStorage), storage_checkpoint_required is True + # and we always save regardless of include_storage. + # For persistent KV backends (e.g. MooncakeStore, Yuanrong), include_storage controls + # whether to save storage contents (they survive restarts so it's optional). su_info_list: list[dict[str, Any]] = [] - if include_storage and hasattr(client, "storage_manager") and client.storage_manager is not None: - su_dir = tmp_dir / _STORAGE_UNITS_DIR - su_dir.mkdir() - try: - su_info_list = client.dump_storage_checkpoint(str(su_dir)) - for info in su_info_list: - logger.info(f"Storage unit {info['storage_unit_id']} (pos={info['position']}) ") - except NotImplementedError: - logger.warning("Storage backend does not support checkpoint; storage data will not be saved.") + if hasattr(client, "storage_manager") and client.storage_manager is not None: + sm = client.storage_manager + should_save = sm.storage_checkpoint_required or include_storage + if should_save: + su_dir = tmp_dir / _STORAGE_UNITS_DIR + su_dir.mkdir() + try: + su_info_list = client.dump_storage_checkpoint(str(su_dir)) + for info in su_info_list: + logger.info( + f"Storage unit {info['storage_unit_id']} (pos={info['position']}) " + f"saved ({info['file_size']} bytes)" + ) + except NotImplementedError: + logger.warning("Storage backend does not support checkpoint; storage data will not be saved.") # Write metadata.json meta_content = { @@ -1136,8 +1153,12 @@ def load_checkpoint( """Restore TransferQueue system state from a checkpoint. The ordered storage unit list of the current system must exactly match the - checkpoint (same count, same positions). This is required because data - routing is position-based (global_idx % num_units). + checkpoint (same count, same positions). + + .. note:: + **Multi-node limitation**: checkpoint_dir must reside on a shared network + filesystem (e.g. NFS, GPFS, Lustre) accessible from all nodes running + storage units. Single-node deployments have no such requirement. Args: checkpoint_dir: Path to the checkpoint directory. diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 6299efff..a7464590 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -352,6 +352,19 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") + @property + def storage_checkpoint_required(self) -> bool: + """Whether storage contents must be checkpointed for correct restore. + + Returns True for in-memory backends (e.g. SimpleStorage) where data + is lost on restart and must be serialized. Returns False for persistent + KV backends (e.g. MooncakeStore, Yuanrong) where data survives restarts + and only controller metadata needs to be saved. + + Subclasses should override this to reflect their actual persistence model. + """ + return False + async def dump_checkpoint(self, output_dir: str) -> list[dict]: """Dump all storage units to files in output_dir. diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 83fc442e..14609a91 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -85,6 +85,10 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): self.storage_unit_infos = self._register_servers(server_infos) + @property + def storage_checkpoint_required(self) -> bool: + return True + def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"): """Register and validate server information. From 4fb941c58d756c6cf89ff4fa832978bf45127162 Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 15 Jun 2026 18:00:37 +0800 Subject: [PATCH 5/6] resolve comments Signed-off-by: yxstev --- transfer_queue/interface.py | 21 ++------ .../managers/simple_storage_manager.py | 50 +++++++------------ transfer_queue/storage/simple_storage.py | 24 ++++----- transfer_queue/utils/zmq_utils.py | 1 + 4 files changed, 34 insertions(+), 62 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index b108c15f..d6999315 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1104,25 +1104,16 @@ def save_checkpoint( controller_size = controller_path.stat().st_size logger.info(f"Controller state saved ({controller_size} bytes)") - # Storage units dump themselves to files via StorageManager. - # For in-memory backends (e.g. SimpleStorage), storage_checkpoint_required is True - # and we always save regardless of include_storage. - # For persistent KV backends (e.g. MooncakeStore, Yuanrong), include_storage controls - # whether to save storage contents (they survive restarts so it's optional). su_info_list: list[dict[str, Any]] = [] if hasattr(client, "storage_manager") and client.storage_manager is not None: sm = client.storage_manager - should_save = sm.storage_checkpoint_required or include_storage - if should_save: + if sm.storage_checkpoint_required or include_storage: su_dir = tmp_dir / _STORAGE_UNITS_DIR su_dir.mkdir() try: su_info_list = client.dump_storage_checkpoint(str(su_dir)) for info in su_info_list: - logger.info( - f"Storage unit {info['storage_unit_id']} (pos={info['position']}) " - f"saved ({info['file_size']} bytes)" - ) + logger.info(f"Storage unit {info['storage_unit_id']} (pos={info['position']}) ") except NotImplementedError: logger.warning("Storage backend does not support checkpoint; storage data will not be saved.") @@ -1211,12 +1202,6 @@ def load_checkpoint( # Restore storage units via StorageManager abstraction if saved_su_list: - su_dir = checkpoint_dir / _STORAGE_UNITS_DIR - for entry in saved_su_list: - pos = entry["position"] - path = su_dir / f"su_{pos}_{entry['storage_unit_id']}.pkl" - if not path.exists(): - raise FileNotFoundError(f"Storage unit file not found: {path}") - client.restore_storage_checkpoint(str(su_dir), saved_su_list) + client.restore_storage_checkpoint(str(checkpoint_dir / _STORAGE_UNITS_DIR), saved_su_list) logger.info(f"Checkpoint loaded from {checkpoint_dir}") diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index 14609a91..a49afd0b 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -19,6 +19,7 @@ from collections import defaultdict from collections.abc import Mapping from operator import itemgetter +from pathlib import Path from typing import Any, Callable, NamedTuple import torch @@ -586,27 +587,22 @@ async def dump_checkpoint(self, output_dir: str) -> list[dict]: Returns: Ordered list of dicts with keys: position, storage_unit_id. """ - from pathlib import Path - su_ids = list(self.storage_unit_infos.keys()) - tasks = [] - paths = [] - for pos, su_id in enumerate(su_ids): - path = str(Path(output_dir) / f"su_{pos}_{su_id}.pkl") - paths.append(path) - tasks.append(self._dump_single_storage_unit(path, target_storage_unit=su_id)) + paths = [str(Path(output_dir) / f"su_{pos}_{su_id}.pkl") for pos, su_id in enumerate(su_ids)] + tasks = [ + self._dump_single_storage_unit(path, target_storage_unit=su_id) + for su_id, path in zip(su_ids, paths, strict=False) + ] await asyncio.gather(*tasks) - su_info_list = [] - for pos, (su_id, path) in enumerate(zip(su_ids, paths, strict=False)): - su_info_list.append( - { - "position": pos, - "storage_unit_id": su_id, - } - ) - return su_info_list + return [ + { + "position": pos, + "storage_unit_id": su_id, + } + for pos, (su_id, path) in enumerate(zip(su_ids, paths, strict=False)) + ] async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict]) -> None: """Restore all storage units from files in checkpoint_dir. @@ -614,23 +610,15 @@ async def restore_checkpoint(self, checkpoint_dir: str, su_info_list: list[dict] Matches by position: entry at position i is restored to the storage unit currently at position i in storage_unit_infos. """ - from pathlib import Path - su_ids = list(self.storage_unit_infos.keys()) entries = sorted(su_info_list, key=lambda e: e["position"]) - - if len(entries) != len(su_ids): - raise ValueError( - f"Storage unit count mismatch: checkpoint has {len(entries)}, current system has {len(su_ids)}." + tasks = [ + self._restore_single_storage_unit( + str(Path(checkpoint_dir) / f"su_{entry['position']}_{entry['storage_unit_id']}.pkl"), + target_storage_unit=su_ids[entry["position"]], ) - - tasks = [] - for entry in entries: - pos = entry["position"] - su_id = su_ids[pos] - path = str(Path(checkpoint_dir) / f"su_{pos}_{entry['storage_unit_id']}.pkl") - tasks.append(self._restore_single_storage_unit(path, target_storage_unit=su_id)) - + for entry in entries + ] await asyncio.gather(*tasks) def close(self) -> None: diff --git a/transfer_queue/storage/simple_storage.py b/transfer_queue/storage/simple_storage.py index aa417cf2..1c91e419 100644 --- a/transfer_queue/storage/simple_storage.py +++ b/transfer_queue/storage/simple_storage.py @@ -571,20 +571,19 @@ def _handle_checkpoint_dump(self, data_parts) -> ZMQMessage: } with open(path, "wb") as f: pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) - response_msg = ZMQMessage.create( + return ZMQMessage.create( request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={"message": f"[{self.storage_unit_id}] dumped to {path} successfully"}, ) - except Exception: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CHECKPOINT_DUMP_RESPONSE, # type: ignore[arg-type] + except Exception as e: + logger.error(f"[{self.storage_unit_id}]: dump_to_file failed: {e}") + return ZMQMessage.create( + request_type=ZMQRequestType.SAVE_LOAD_CKPT_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, - body={"message": f"[{self.storage_unit_id}]: dumped to {path} failed"}, + body={"message": str(e)}, ) - return response_msg - def _handle_checkpoint_restore(self, data_parts) -> ZMQMessage: """ Restore storage unit data directly from a file. @@ -614,7 +613,7 @@ def _handle_checkpoint_restore(self, data_parts) -> ZMQMessage: self.storage_data.field_data = data["field_data"] self.storage_data._active_keys = data["active_keys"] - response_msg = ZMQMessage.create( + return ZMQMessage.create( request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE, # type: ignore[arg-type] sender_id=self.storage_unit_id, body={ @@ -624,14 +623,13 @@ def _handle_checkpoint_restore(self, data_parts) -> ZMQMessage: ) except Exception as e: - response_msg = ZMQMessage.create( - request_type=ZMQRequestType.CHECKPOINT_RESTORE_RESPONSE, # type: ignore[arg-type] + logger.error(f"[{self.storage_unit_id}]: restore_from_file failed: {e}") + return ZMQMessage.create( + request_type=ZMQRequestType.SAVE_LOAD_CKPT_ERROR, # type: ignore[arg-type] sender_id=self.storage_unit_id, - body={"message": f"[{self.storage_unit_id}] restored from {path} failed: {e}"}, + body={"message": str(e)}, ) - return response_msg - @staticmethod def _cumulative_bucket_counts(hist) -> list[float]: """Build cumulative counts from a prometheus_client Histogram's non-cumulative buckets.""" diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 3a260a14..a58617b2 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -106,6 +106,7 @@ class ZMQRequestType(ExplicitEnum): CHECKPOINT_DUMP_RESPONSE = "CHECKPOINT_DUMP_RESPONSE" CHECKPOINT_RESTORE = "CHECKPOINT_RESTORE" CHECKPOINT_RESTORE_RESPONSE = "CHECKPOINT_RESTORE_RESPONSE" + SAVE_LOAD_CKPT_ERROR = "SAVE_LOAD_CKPT_ERROR" class ZMQServerInfo: From 41171e5d5527046b941038136d1af1c7e10ea2e4 Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 15 Jun 2026 18:49:43 +0800 Subject: [PATCH 6/6] resolve checks Signed-off-by: yxstev --- transfer_queue/interface.py | 17 ++++++++--------- .../storage/managers/simple_storage_manager.py | 1 + 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index d6999315..5fcd3a15 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -1181,17 +1181,16 @@ def load_checkpoint( if saved_su_list: if not hasattr(client, "storage_manager") or client.storage_manager is None: raise ValueError("Checkpoint contains storage unit data but current system has no storage manager.") - try: - current_su_count = len(client.storage_manager.get_zmq_server_info()) - if current_su_count != len(saved_su_list): - raise ValueError( - f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, " - f"current system has {current_su_count}." - ) - except NotImplementedError as e: + sm = client.storage_manager + if not sm.storage_checkpoint_required: raise ValueError( "Checkpoint contains storage unit data but current storage backend does not support checkpoint." - ) from e + ) + if hasattr(sm, "get_zmq_server_info") and len(sm.get_zmq_server_info()) != len(saved_su_list): + raise ValueError( + f"Storage unit count mismatch: checkpoint has {len(saved_su_list)}, " + f"current system has {len(sm.get_zmq_server_info())}." + ) # Restore controller via ZMQ RPC controller_path = checkpoint_dir / _CONTROLLER_FILE diff --git a/transfer_queue/storage/managers/simple_storage_manager.py b/transfer_queue/storage/managers/simple_storage_manager.py index a49afd0b..874da7ee 100644 --- a/transfer_queue/storage/managers/simple_storage_manager.py +++ b/transfer_queue/storage/managers/simple_storage_manager.py @@ -88,6 +88,7 @@ def __init__(self, controller_info: ZMQServerInfo, config: DictConfig): @property def storage_checkpoint_required(self) -> bool: + """Whether storage contents must be checkpointed for correct restore.""" return True def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"):