From d33fe7ff58b4272d9481389644ceb40f6ce7df4f Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 19 Jun 2026 18:34:09 -0500 Subject: [PATCH 1/2] Add type hints to top-level public API functions Annotate the public API entry points in deepspeed/__init__.py (initialize, init_inference, default_inference_config, add_config_arguments, tp_model_init, set_optimizer_flags) with precise parameter and return types, addressing #8074. Scoped first increment for the top-level deepspeed.* public API surface; no runtime behavior changes. Signed-off-by: Arun Sharma --- deepspeed/__init__.py | 48 ++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 0d53a172e64e..7730aa05afaf 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -3,13 +3,15 @@ # DeepSpeed Team +import argparse import sys import types import json -from typing import Optional, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader from packaging import version as pkg_version # Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed @@ -68,7 +70,7 @@ def _parse_version(version_str): dist = None -def set_optimizer_flags(config_class, model): +def set_optimizer_flags(config_class: DeepSpeedConfig, model: torch.nn.Module) -> None: if config_class.optimizer_name == MUON_OPTIMIZER: for name, p in model.named_parameters(): if p.ndim >= 2 and not any(keyword in name.lower() for keyword in ("embed", "lm_head")): @@ -77,19 +79,21 @@ def set_optimizer_flags(config_class, model): setattr(p, "use_muon", False) -def initialize(args=None, - model: torch.nn.Module = None, - optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None, - model_parameters: Optional[torch.nn.Module] = None, - training_data: Optional[torch.utils.data.Dataset] = None, - lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, - distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, - mpu=None, - dist_init_required: Optional[bool] = None, - collate_fn=None, - config=None, - mesh_param=None, - config_params=None): +def initialize( + args: Any = None, + model: torch.nn.Module = None, + optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None, + model_parameters: Optional[torch.nn.Module] = None, + training_data: Optional[torch.utils.data.Dataset] = None, + lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, + distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, + mpu: Any = None, + dist_init_required: Optional[bool] = None, + collate_fn: Optional[Callable] = None, + config: Optional[Union[str, Dict[str, Any]]] = None, + mesh_param: Any = None, + config_params: Optional[Union[str, Dict[str, Any]]] = None +) -> Tuple[DeepSpeedEngine, Optional[Optimizer], Optional[DataLoader], Optional[_LRScheduler]]: """Initialize the DeepSpeed Engine. Arguments: @@ -287,7 +291,7 @@ def _add_core_arguments(parser): return parser -def add_config_arguments(parser): +def add_config_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments. The set of DeepSpeed arguments include the following: 1) --deepspeed: boolean flag to enable DeepSpeed @@ -303,14 +307,16 @@ def add_config_arguments(parser): return parser -def default_inference_config(): +def default_inference_config() -> Dict[str, Any]: """ Return a default DeepSpeed inference configuration dictionary. """ return DeepSpeedInferenceConfig().dict() -def init_inference(model, config=None, **kwargs): +def init_inference(model: torch.nn.Module, + config: Optional[Union[str, Dict[str, Any]]] = None, + **kwargs: Any) -> InferenceEngine: """Initialize the DeepSpeed InferenceEngine. Description: all four cases are valid and supported in DS init_inference() API. @@ -388,7 +394,11 @@ def init_inference(model, config=None, **kwargs): return engine -def tp_model_init(model, tp_size, dtype, config=None, **kwargs): +def tp_model_init(model: torch.nn.Module, + tp_size: int, + dtype: torch.dtype, + config: Optional[Union[str, Dict[str, Any]]] = None, + **kwargs: Any) -> torch.nn.Module: """ Record tensor-parallel initialization arguments for training. From 68e25c1dccf207d7e13989cb633c53085a3f23c2 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 19 Jun 2026 19:15:27 -0500 Subject: [PATCH 2/2] Address review: widen initialize() return annotation to DeepSpeed wrapper types The Codex reviewer correctly noted the return tuple over-claimed PyTorch-only types. deepspeed.initialize() can return DeepSpeed wrappers, not the torch base classes: ZeRO/fp16/bf16 paths set engine.optimizer to a DeepSpeedOptimizer (which subclasses object, not torch.optim.Optimizer); engine.training_dataloader is a DeepSpeedDataLoader (not a torch DataLoader); and JSON-configured schedulers are DeepSpeed objects (not _LRScheduler). Widen the return to Tuple[DeepSpeedEngine, Optional[Union[Optimizer, DeepSpeedOptimizer]], Optional[DeepSpeedDataLoader], Any] so typed callers do not see false errors for valid DeepSpeed configurations. Signed-off-by: Arun Sharma --- deepspeed/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 7730aa05afaf..51ca2b7cb3d7 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -11,7 +11,6 @@ import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader from packaging import version as pkg_version # Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed @@ -31,6 +30,8 @@ from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER +from .runtime.base_optimizer import DeepSpeedOptimizer +from .runtime.dataloader import DeepSpeedDataLoader from .runtime.hybrid_engine import DeepSpeedHybridEngine from .runtime.pipe.engine import PipelineEngine from .inference.engine import InferenceEngine @@ -93,7 +94,7 @@ def initialize( config: Optional[Union[str, Dict[str, Any]]] = None, mesh_param: Any = None, config_params: Optional[Union[str, Dict[str, Any]]] = None -) -> Tuple[DeepSpeedEngine, Optional[Optimizer], Optional[DataLoader], Optional[_LRScheduler]]: +) -> Tuple[DeepSpeedEngine, Optional[Union[Optimizer, DeepSpeedOptimizer]], Optional[DeepSpeedDataLoader], Any]: """Initialize the DeepSpeed Engine. Arguments: