From 707f2bddbcb6c3f3a881d6bf2cc448c65e46947e Mon Sep 17 00:00:00 2001 From: Abhishek Agrawal Date: Thu, 30 Apr 2026 02:42:06 -0700 Subject: [PATCH] Add a script for converting Safetensors to Orbax native layout. PiperOrigin-RevId: 908050113 --- .../orbax/checkpoint/_src/path/gcs_utils.py | 46 ++ .../v1/_src/layout/safetensors_layout.py | 191 +++++++- .../v1/converter_lib_safetensors.py | 429 ++++++++++++++++++ .../checkpoint/experimental/v1/run_script.sh | 48 ++ 4 files changed, 697 insertions(+), 17 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/converter_lib_safetensors.py create mode 100755 checkpoint/orbax/checkpoint/experimental/v1/run_script.sh diff --git a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py index f023dffa1..ef312dd80 100644 --- a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py @@ -28,6 +28,52 @@ def is_gcs_path(path: epath.Path) -> bool: return path.as_posix().startswith(_GCS_PATH_PREFIX) +def to_gs_path(path: epath.PathLike) -> str: + """Converts a GCS path to a gs:// path string. + + GCS paths can start with any of the prefixes in _GCS_PATH_PREFIX. This + function converts them to gs:// format. + + Args: + path: A GCS path which can be a string or epath.Path. + + Returns: + A GCS path string starting with gs://. + + Raises: + ValueError: If path is not a GCS path. + """ + path_str = str(path) + if path_str.startswith('gs://'): + return path_str + else: + raise ValueError(f'Path is not a GCS path: {path}') + + +def to_gcsfuse_path(path: epath.PathLike) -> str: + """Converts a GCS path to a gcsfuse path string. + + GCSfuse paths start with /gcs/ and are accessible via File API when gcsfuse + is enabled. + + Args: + path: A GCS path which can be a string or epath.Path. + + Returns: + A gcsfuse path string starting with /gcs/. + + Raises: + ValueError: If path is not a GCS path. + """ + path_str = str(path) + if path_str.startswith('gs://'): + return path_str.replace('gs://', '/gcs/', 1) + elif path_str.startswith('/gcs/'): + return path_str + else: + raise ValueError(f'Path is not a GCS path: {path}') + + def parse_gcs_path(path: epath.PathLike) -> tuple[str, str]: parsed = parse.urlparse(str(path)) assert parsed.scheme == 'gs', f'Unsupported scheme for GCS: {parsed.scheme}' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index cbaa5ce02..0c779b9b0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -15,8 +15,12 @@ """Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats.""" import asyncio +from concurrent import futures import dataclasses import json +import mmap +import os +import subprocess import time from typing import Any, Awaitable, cast @@ -26,6 +30,7 @@ import numpy as np from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import gcs_utils from orbax.checkpoint._src.tree import utils as tree_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout @@ -306,6 +311,55 @@ def _reshard_transient_array( )(global_transient_array) +def _get_tensor_bounds( + header: dict[str, Any], + tensor_names: list[str], +) -> tuple[list[str], int, int]: + """Calculates tensor bounds and returns tensors to load, min_start and max_end offsets.""" + min_start = float("inf") + max_end = 0 + + tensors_to_load = [] + for t_name in tensor_names: + if t_name == "__metadata__": + continue + if t_name not in header: + continue + tensors_to_load.append(t_name) + start, end = header[t_name]["data_offsets"] + if start < min_start: + min_start = start + if end > max_end: + max_end = end + if not tensors_to_load: + return [], 0, 0 + return tensors_to_load, int(min_start), int(max_end) + + +def _process_data_bytes( + data_bytes: bytes, + header: dict[str, Any], + tensor_names: list[str], + min_start_offset: int, +) -> dict[str, np.ndarray]: + """Extracts tensors from data bytes.""" + tensors = {} + data_mv = memoryview(data_bytes) + for name in tensor_names: + if name == "__metadata__": + continue + shape, dtype = _get_array_properties(header[name]) + start_offset, end_offset = header[name]["data_offsets"] + tensor_bytes = data_mv[ + start_offset - min_start_offset : end_offset - min_start_offset + ] + np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape) + if not np.isfinite(np_array).all(): + raise ValueError(f"Non-finite values found in tensor {name}.") + tensors[name] = np_array + return tensors + + @dataclasses.dataclass class _LoadContext: host_id: int @@ -438,24 +492,127 @@ async def _read_bundle( bundle_start_offset = 0 return bundle_bytes, bundle_start_offset - async def load_single_host(self) -> dict[str, np.ndarray]: + def _read_single_chunk( + self, + gcs_path_str: str, + chunk_data: tuple[int, int] + ) -> bytes: + """Reads a single chunk of data from a GCS file.""" + chunk_size, offset = chunk_data + with open(gcs_path_str, "rb") as f: + bytes_read = 0 + chunk_pieces = [] + while bytes_read < chunk_size: + piece = os.pread( + f.fileno(), chunk_size - bytes_read, offset + bytes_read + ) + if not piece: + raise EOFError( + f"Unexpected end of file at offset {offset + bytes_read} " + f"in file {gcs_path_str}. Expected {chunk_size} bytes, " + f"got {bytes_read}." + ) + chunk_pieces.append(piece) + bytes_read += len(piece) + return b"".join(chunk_pieces) + + async def load_single_host_gcs( + self, + *, + data_start_offset: int, + min_start: int, + max_end: int, + ) -> bytes: + """Downloads tensors from Google Cloud Storage using high-bandwidth parallel reads. + + This method uses `os.pread` with a thread pool to achieve high-bandwidth + parallel downloads from GCS via gcsfuse. It first calculates the bounding + box of the required tensor data and then reads chunks within that range. + + Args: + data_start_offset: The offset where the tensor data begins. + min_start: The minimum start offset of the tensors to load. + max_end: The maximum end offset of the tensors to load. + + Returns: + A bytes object containing the loaded tensor data. + + Raises: + EOFError: If the file is truncated or reading fails unexpectedly. + ValueError: If non-finite values are found in a loaded tensor. + """ + + gcs_path_str = gcs_utils.to_gcsfuse_path(self.path) + if not os.path.exists(gcs_path_str): + _, blob_name = gcs_utils.parse_gcs_path(gcs_utils.to_gs_path(self.path)) + blob_name = blob_name.rstrip("/") + safe_temp_name = blob_name.replace("/", "_") + ram_disk_path = f"/dev/shm/{safe_temp_name}_temp.bin" + subprocess.run( + [ + "gcloud", + "storage", + "cp", + gcs_utils.to_gs_path(self.path), + ram_disk_path, + ], + check=True, + ) + with open(ram_disk_path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + data_bytes = mm[ + data_start_offset + min_start : data_start_offset + max_end + ] + os.remove(ram_disk_path) + return data_bytes + chunk_size = 1024 * 1024 * 1024 + max_workers = 16 + offset = data_start_offset + min_start + length = max_end - min_start + chunks = [] + bytes_read = 0 + while bytes_read < length: + current_chunk_size = min(chunk_size, length - bytes_read) + current_offset = offset + bytes_read + chunks.append((current_chunk_size, current_offset)) + bytes_read += current_chunk_size + + # 2. Execute the parallel reads + with futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + read_chunks = executor.map( + lambda chunk: self._read_single_chunk(gcs_path_str, chunk), chunks + ) + + data_bytes = b"".join(read_chunks) + return data_bytes + + async def load_single_host( + self, abstract_pytree: dict[str, Any] | None + ) -> dict[str, np.ndarray]: """Loads tensors from a safetensors file into host NumPy arrays.""" header, data_start_offset = await self.read_header() - tensors = {} - async with async_path.open_file(self.path, mode="rb") as f: - await f.seek(data_start_offset) - data_bytes = await f.read() - for name, info in header.items(): - if name == "__metadata__": - continue - shape, dtype = _get_array_properties(info) - start_offset, end_offset = info["data_offsets"] - tensor_bytes = data_bytes[start_offset:end_offset] - np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape) - if not np.isfinite(np_array).all(): - raise ValueError(f"Non-finite values found in tensor {name}.") - tensors[name] = np_array - return tensors + if abstract_pytree is None: + tensor_names = list(header.keys()) + else: + tensor_names = list(abstract_pytree.keys()) + tensors_to_load, min_start, max_end = _get_tensor_bounds( + header, tensor_names + ) + if not tensors_to_load: + return {} + if gcs_utils.is_gcs_path(self.path): + data_bytes = await self.load_single_host_gcs( + data_start_offset=data_start_offset, + min_start=min_start, + max_end=max_end, + ) + else: + async with async_path.open_file(self.path, mode="rb") as f: + await f.seek(data_start_offset + min_start) + data_bytes = await f.read(max_end - min_start) + return _process_data_bytes( + data_bytes, header, tensors_to_load, min_start + ) async def load_multi_host( self, abstract_pytree: dict[str, Any] @@ -585,7 +742,7 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any: start = time.time() load_ops = [] for loader in await self._get_loaders(): - load_ops.append(loader.load_single_host()) + load_ops.append(loader.load_single_host(abstract_pytree)) restored_pytree = {} for file_tensors in await asyncio.gather(*load_ops): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/converter_lib_safetensors.py b/checkpoint/orbax/checkpoint/experimental/v1/converter_lib_safetensors.py new file mode 100644 index 000000000..d0837c186 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/converter_lib_safetensors.py @@ -0,0 +1,429 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Converter for Safetensors checkpoint format to native Orbax format.""" + +import asyncio +import collections +from concurrent import futures +import json +import time +from typing import Any, Dict, Sequence, Tuple, cast + +from absl import app +from absl import flags +from absl import logging +from etils import epath +import jax +import numpy as np +from orbax.checkpoint.experimental import v1 as ocp +from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout +from tensorflow.io import gfile + + +Path = ocp.path.Path +ThreadPoolExecutor = futures.ThreadPoolExecutor + + +_INPUT_DIR = flags.DEFINE_string( + 'input_dir', None, 'Directory containing Safetensors files.' +) +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', None, 'Directory to save the converted Orbax checkpoint.' +) +_MAX_BATCH_SIZE_MB = flags.DEFINE_integer( + 'max_batch_size_mb', 10240, 'Maximum size of a single batch in MB.' +) +_WORKERS = flags.DEFINE_integer( + 'workers', + 1, + 'Number of worker threads to use for concurrent processing of batches.', +) +_THREADS = flags.DEFINE_integer( + 'threads', + 1, + 'Maximum number of threads to use for concurrent processing of batches.', +) +_USE_FILE_LOGGER_ONLY = flags.DEFINE_boolean( + 'use_file_logger_only', + True, + 'If True, only logs from this file are printed to avoid excessive logging ' + 'from other modules. If False, enables global INFO logging.', +) +_SAVING_ENABLED = flags.DEFINE_boolean( + 'saving_enabled', + False, + 'If True, saves the converted Orbax checkpoint to the output directory.', +) + + +def _log_info(msg: str, *args): + """Designated method for internal logging with volume control.""" + if _USE_FILE_LOGGER_ONLY.value: + print('INFO: ' + (msg % args if args else msg)) + else: + logging.info(msg, *args) + + +def _log_error(msg: str, *args): + """Designated method for internal error logging with volume control.""" + if _USE_FILE_LOGGER_ONLY.value: + print('ERROR: ' + str(msg % args if args else msg)) + else: + logging.error(msg, *args) + + +async def create_loading_plan( + paths: Sequence[Path], + abstract_pytree: dict[str, Any] | None = None, +) -> Tuple[list[dict[str, Any]], list[float], dict[str, Any]]: + """Creates a plan for loading tensors in batches based on memory limits.""" + + async def _fetch_header(p): + h, _ = await safetensors_layout._SingleFileLoader(p).read_header() # pylint: disable=protected-access + return p, h + + if abstract_pytree is None: + abstract_pytree = {} + total_tensors = 0 + for file_path in paths: + _, header = await _fetch_header(file_path) + for tensor in header: + if tensor != '__metadata__': + total_tensors += 1 + shape, dtype = safetensors_layout._get_array_properties(header[tensor]) # pylint: disable=protected-access + abstract_pytree[tensor] = jax.ShapeDtypeStruct( + shape=shape, + dtype=dtype, + ) + + all_headers = await asyncio.gather(*[_fetch_header(p) for p in paths]) + batches = [] + batches_size = [] + current_batch = {} + current_batch_size = 0 + max_bytes = _MAX_BATCH_SIZE_MB.value * 1024 * 1024 + total_tensors = 0 + for _, header in all_headers: + for tensor_name, leaf_meta in header.items(): + if tensor_name in abstract_pytree: + total_tensors += 1 + if 'shape' in leaf_meta and 'dtype' in leaf_meta: + shape, dtype = safetensors_layout._get_array_properties(leaf_meta) # pylint: disable=protected-access + dtype_size = np.dtype(dtype).itemsize + size = np.prod(shape) * dtype_size + if current_batch_size + size > max_bytes: + logging.info( + 'Batch size: %.2f MB', current_batch_size / (1024 * 1024) + ) + total_tensors = 0 + batches.append(current_batch) + batches_size.append(current_batch_size / (1024 * 1024)) + current_batch = {} + current_batch_size = 0 + current_batch[tensor_name] = jax.ShapeDtypeStruct( + shape=shape, + dtype=dtype, + ) + current_batch_size += size + if current_batch: + logging.info('Batch size: %.2f MB', current_batch_size / (1024 * 1024)) + batches.append(current_batch) + batches_size.append(current_batch_size / (1024 * 1024)) + return batches, batches_size, abstract_pytree + + +async def get_tensor_to_path_indexing(paths: Sequence[Path]) -> dict[str, Path]: + """Returns a mapping from tensor name to safetensors file.""" + + file_to_path = {} + for file_ in paths: + file_to_path[str(Path(file_).name)] = file_ + + path_ = Path(str(paths[0].parent) + '/model.safetensors.index.json') + + tensor_to_path = {} + if not await asyncio.to_thread(path_.exists): + for path in paths: + header, _ = await safetensors_layout._SingleFileLoader(path).read_header() # pylint: disable=protected-access + for name in header: + if name == '__metadata__': + continue + if name in tensor_to_path: + raise ValueError(f'Duplicate tensor {name} found in multiple files.') + tensor_to_path[name] = path + return tensor_to_path + raw_data = await asyncio.to_thread(path_.read_bytes) + index_data = json.loads(raw_data) + + for name, path in index_data['weight_map'].items(): + if name in tensor_to_path: + raise ValueError(f'Duplicate tensor {name} found in multiple files.') + tensor_to_path[name] = file_to_path[str(path)] + return tensor_to_path + + +def analyze_model_structure(metadata_tree: Any) -> None: + """Logs detailed statistics about the model structure and expected size.""" + flat_metadata, _ = jax.tree_util.tree_flatten_with_path(metadata_tree) + + total_params = 0 + total_bytes = 0 + dtype_counts = {} + total_tensors = 0 + + for _, leaf in flat_metadata: + # Get path string (e.g., "model/layers/0/self_attn/q_proj") + if hasattr(leaf, 'shape') and hasattr(leaf, 'dtype'): + # Calculate size for this specific tensor + shape = leaf.shape + dtype = leaf.dtype + param_count = np.prod(shape) + byte_size = param_count * np.dtype(dtype).itemsize + + # Update totals + total_params += param_count + total_bytes += byte_size + total_tensors += 1 + + # Track dtype distribution + dtype_str = str(dtype) + dtype_counts[dtype_str] = dtype_counts.get(dtype_str, 0) + 1 + + total_gb = total_bytes / (1024**3) + total_mb = total_bytes / (1024**2) + + _log_info('Total Tensors count: %d', total_tensors) + _log_info('Total Parameters: %s', f'{total_params:,}') + _log_info( + 'Expected Raw Size (Uncompressed): %.2f MB (%.4f GB)', + total_mb, + total_gb, + ) + _log_info('Dtype Distribution: %s', dtype_counts) + + +async def _execute_batch_async( + file_abstract_trees: dict[str, Any], + batch_index: int, + num_batches: int, + batch_size: float, +) -> Tuple[dict[str, Any], int]: + """Loads and saves a single batch.""" + load_start_time = time.time() + _log_info( + '\033[1m[VM %d] Processing batch %d/%d...\033[0m', + jax.process_index(), + batch_index + 1, + num_batches, + ) + _log_info( + '[VM %d] [Batch %d] Loading into Host RAM from Safetensors, starting at' + ' %.2f seconds, total files: %d | Batch Size: %.2f MB', + jax.process_index(), + batch_index + 1, + load_start_time, + len(file_abstract_trees), + batch_size, + ) + loaded_tensors = {} + for path, abstract_pytree in file_abstract_trees.items(): + _log_info( + '[VM %d] [Batch %d] Loading path %s with %d tensors', + jax.process_index(), + batch_index + 1, + path, + len(abstract_pytree), + ) + current_load_start_time = time.time() + with ocp.Context( + checkpoint_layout=ocp.options.CheckpointLayout.SAFETENSORS + ): + loaded = ocp.load_pytree_async( + path=Path(path), + abstract_pytree=abstract_pytree, + checkpointable_name=None, + ) + current_loaded_tensors = loaded.result() + current_loaded_tensors = cast(Dict[str, Any], current_loaded_tensors) + loaded_tensors.update(current_loaded_tensors) + current_load_size = sum(t.nbytes for t in current_loaded_tensors.values()) + _log_info( + '[VM %d] [Batch %d] Loaded path %s with %d tensors | Load Time: %.2f' + ' seconds | Load Size: %.2f MB' + ' | Load Throughput: %.2f MB/s', + jax.process_index(), + batch_index + 1, + path, + len(current_loaded_tensors), + time.time() - current_load_start_time, + current_load_size / (1024 * 1024), + current_load_size + / (1024 * 1024) + / (time.time() - current_load_start_time), + ) + del current_loaded_tensors + load_end_time = time.time() + total_loaded_size = sum(t.nbytes for t in loaded_tensors.values()) + _log_info( + '[VM %d] [Batch %d] Loaded Batch | Load Time: %.2f seconds | Load Size:' + ' %.2f MB | Load Throughput: %.2f MB/s', + jax.process_index(), + batch_index + 1, + load_end_time - load_start_time, + total_loaded_size / (1024 * 1024), + total_loaded_size / (1024 * 1024) / (time.time() - load_start_time), + ) + return loaded_tensors, batch_index + + +def _execute_batch( + file_abstract_trees: dict[str, Any], + batch_index: int, + num_batches: int, + batch_size: float, +) -> Tuple[dict[str, Any], int]: + """Sync wrapper for _execute_batch_async to run in ThreadPoolExecutor.""" + return asyncio.run( + _execute_batch_async( + file_abstract_trees, + batch_index, + num_batches, + batch_size, + ) + ) + + +async def _load_safetensors(path: Path) -> Dict[str, Any]: + """Calls the correct safetensors loading function.""" + paths = list(path.glob('*.safetensors')) + tensor_to_path = await get_tensor_to_path_indexing(paths) + batches, batch_sizes, abstract_pytree = await create_loading_plan(paths) + batches_to_load = [] + for batch_abstract_pytree in batches: + file_abstract_trees = collections.defaultdict(dict) + for tensor_name in batch_abstract_pytree: + path = tensor_to_path[tensor_name] + file_abstract_trees[path][tensor_name] = abstract_pytree[tensor_name] + print(f'file_abstract_trees: {len(file_abstract_trees)} files') + batches_to_load.append(file_abstract_trees) + + total_load_size = 0 + all_loaded_tensors = {} + num_batches = len(batches_to_load) + with ThreadPoolExecutor(max_workers=_THREADS.value) as pool: + results = pool.map( + _execute_batch, + batches_to_load, + range(0, num_batches), + [num_batches] * num_batches, + batch_sizes, + ) + for r in results: + loaded_tensors, batch_index = r + batch_size = batch_sizes[batch_index] + all_loaded_tensors.update(loaded_tensors) + del loaded_tensors + total_load_size += batch_size + return all_loaded_tensors + + +async def run_cpu_batching(input_dir: str, output_dir: str): + """Orchestrates the metadata sizing, planning, and batch execution loop.""" + # Only VM 0 cleans up the directory before the distributed run starts + + input_path = epath.Path(input_dir) + + _log_info('=' * 60) + _log_info( + f'🔎 STARTING CONVERSION FROM SAFETENSORS TO ORBAX for {input_path}' + ) + _log_info('Output directory: %s', output_dir) + _log_info('=' * 60) + _log_info('Conversion will be done in following steps.') + _log_info('step 1: Reading Safetensors Metadata') + _log_info('step 2: Analyzing Model Structure') + _log_info('step 3: Executing Loading Loop') + _log_info('step 4: Finalizing Native Orbax Checkpoint') + + _log_info('---------- Step 1: Reading Safetensors Metadata -----------') + metadata_start_time = time.time() + with ocp.Context(checkpoint_layout=ocp.options.CheckpointLayout.SAFETENSORS): + metadata_ckpt = ocp.pytree_metadata(input_path) + metadata_end_time = time.time() + _log_info( + 'Safetensors Metadata Time: %.2f seconds', + metadata_end_time - metadata_start_time, + ) + + if 'pytree' in metadata_ckpt.metadata: + model_metadata = metadata_ckpt.metadata['pytree'] + else: + model_metadata = metadata_ckpt.metadata + + _log_info('------------ Step 2: Analyzing Model Structure --------------') + analyze_start_time = time.time() + analyze_model_structure(model_metadata) + analyze_end_time = time.time() + _log_info( + 'Model Structure Analysis Time: %.2f seconds', + analyze_end_time - analyze_start_time, + ) + + _log_info('------------- Step 3: Executing Loading --------------------') + load_start_time = time.time() + loaded_tensors = await _load_safetensors(input_path) + total_load_size = sum(t.nbytes for t in loaded_tensors.values()) + total_load_size = total_load_size / (1024 * 1024) + _log_info( + '\033[1mFinal load size: %.2f MB | Final load time: %.2f seconds |' + ' Final load throughput: %.2f MB/s \033[0m', + total_load_size, + time.time() - load_start_time, + (total_load_size / (time.time() - load_start_time)), + ) + if _SAVING_ENABLED.value: + _log_info( + '------------ Step 4: Saving Native Orbax Checkpoint --------------' + ) + if gfile.exists(output_dir): + _log_info('Removing existing checkpoint directory: %s', output_dir) + gfile.rmtree(output_dir) + save_start_time = time.time() + with ocp.Context(checkpoint_layout=ocp.options.CheckpointLayout.ORBAX): + ocp.save_checkpointables( + path=output_dir, + checkpointables={'pytree': loaded_tensors}, + ) + total_save_time = time.time() - save_start_time + _log_info( + '\033[1mSaved Native Orbax Checkpoint Time: %.2f seconds\033[0m', + total_save_time, + ) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + if not _INPUT_DIR.value or not _OUTPUT_DIR.value: + raise app.UsageError('--input_dir and --output_dir must be provided.') + + if not _USE_FILE_LOGGER_ONLY.value: + logging.set_stderrthreshold('INFO') + + asyncio.run(run_cpu_batching(_INPUT_DIR.value, _OUTPUT_DIR.value)) + + +if __name__ == '__main__': + app.run(main) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/run_script.sh b/checkpoint/orbax/checkpoint/experimental/v1/run_script.sh new file mode 100755 index 000000000..637a6526d --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/run_script.sh @@ -0,0 +1,48 @@ + +#!/bin/bash + +# 1. Unmount and clean up (Always runs to ensure a clean state) +echo "Cleaning up /gcs..." +sudo fusermount -u /gcs 2>/dev/null || echo "/gcs not mounted" +sudo rm -rf /gcs + +# 2. Drop Caches +echo "Dropping system caches..." +sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches' + +# 3. GCS Mount +echo "Installing gcsfuse dependencies..." +sudo apt-get update -qq > /dev/null 2>&1 +sudo apt-get install -qq -y curl gnupg lsb-release > /dev/null 2>&1 + +echo "Installing gcsfuse..." +export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` +echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list > /dev/null +curl -s https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - > /dev/null 2>&1 +sudo apt-get update -qq > /dev/null 2>&1 +sudo apt-get install -qq -y gcsfuse > /dev/null 2>&1 + +echo "Recreating /gcs mount point..." +sudo mkdir -p /gcs +sudo chown $USER:$USER /gcs +sudo chmod 755 /gcs + +echo "Mounting GCS bucket..." +gcsfuse --implicit-dirs /gcs +ls -la /gcs/safetensor-kimi-central/ + +echo "Installing git..." +sudo apt-get update -qq > /dev/null 2>&1 +sudo apt-get install -qq -y git > /dev/null 2>&1 +echo "Git installed." + +echo "Resetting Orbax repository..." +cd ~ +rm -rf orbax +git clone https://github.com/google/orbax.git > /dev/null 2>&1 +cd orbax > /dev/null 2>&1 +git fetch origin pull/3147/head:pr-3147 > /dev/null 2>&1 +git checkout pr-3147 > /dev/null 2>&1 +free -h + +echo "Setup complete. we are at $(pwd)"