diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 0d53a172e64e..51ca2b7cb3d7 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -3,10 +3,11 @@ # DeepSpeed Team +import argparse import sys import types import json -from typing import Optional, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler @@ -29,6 +30,8 @@ from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER +from .runtime.base_optimizer import DeepSpeedOptimizer +from .runtime.dataloader import DeepSpeedDataLoader from .runtime.hybrid_engine import DeepSpeedHybridEngine from .runtime.pipe.engine import PipelineEngine from .inference.engine import InferenceEngine @@ -68,7 +71,7 @@ def _parse_version(version_str): dist = None -def set_optimizer_flags(config_class, model): +def set_optimizer_flags(config_class: DeepSpeedConfig, model: torch.nn.Module) -> None: if config_class.optimizer_name == MUON_OPTIMIZER: for name, p in model.named_parameters(): if p.ndim >= 2 and not any(keyword in name.lower() for keyword in ("embed", "lm_head")): @@ -77,19 +80,21 @@ def set_optimizer_flags(config_class, model): setattr(p, "use_muon", False) -def initialize(args=None, - model: torch.nn.Module = None, - optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None, - model_parameters: Optional[torch.nn.Module] = None, - training_data: Optional[torch.utils.data.Dataset] = None, - lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, - distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, - mpu=None, - dist_init_required: Optional[bool] = None, - collate_fn=None, - config=None, - mesh_param=None, - config_params=None): +def initialize( + args: Any = None, + model: torch.nn.Module = None, + optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None, + model_parameters: Optional[torch.nn.Module] = None, + training_data: Optional[torch.utils.data.Dataset] = None, + lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None, + distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT, + mpu: Any = None, + dist_init_required: Optional[bool] = None, + collate_fn: Optional[Callable] = None, + config: Optional[Union[str, Dict[str, Any]]] = None, + mesh_param: Any = None, + config_params: Optional[Union[str, Dict[str, Any]]] = None +) -> Tuple[DeepSpeedEngine, Optional[Union[Optimizer, DeepSpeedOptimizer]], Optional[DeepSpeedDataLoader], Any]: """Initialize the DeepSpeed Engine. Arguments: @@ -287,7 +292,7 @@ def _add_core_arguments(parser): return parser -def add_config_arguments(parser): +def add_config_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments. The set of DeepSpeed arguments include the following: 1) --deepspeed: boolean flag to enable DeepSpeed @@ -303,14 +308,16 @@ def add_config_arguments(parser): return parser -def default_inference_config(): +def default_inference_config() -> Dict[str, Any]: """ Return a default DeepSpeed inference configuration dictionary. """ return DeepSpeedInferenceConfig().dict() -def init_inference(model, config=None, **kwargs): +def init_inference(model: torch.nn.Module, + config: Optional[Union[str, Dict[str, Any]]] = None, + **kwargs: Any) -> InferenceEngine: """Initialize the DeepSpeed InferenceEngine. Description: all four cases are valid and supported in DS init_inference() API. @@ -388,7 +395,11 @@ def init_inference(model, config=None, **kwargs): return engine -def tp_model_init(model, tp_size, dtype, config=None, **kwargs): +def tp_model_init(model: torch.nn.Module, + tp_size: int, + dtype: torch.dtype, + config: Optional[Union[str, Dict[str, Any]]] = None, + **kwargs: Any) -> torch.nn.Module: """ Record tensor-parallel initialization arguments for training.