Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

# DeepSpeed Team

import argparse
import sys
import types
import json
from typing import Optional, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
Expand All @@ -29,6 +30,8 @@
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER
from .runtime.base_optimizer import DeepSpeedOptimizer
from .runtime.dataloader import DeepSpeedDataLoader
from .runtime.hybrid_engine import DeepSpeedHybridEngine
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
Expand Down Expand Up @@ -68,7 +71,7 @@ def _parse_version(version_str):
dist = None


def set_optimizer_flags(config_class, model):
def set_optimizer_flags(config_class: DeepSpeedConfig, model: torch.nn.Module) -> None:
if config_class.optimizer_name == MUON_OPTIMIZER:
for name, p in model.named_parameters():
if p.ndim >= 2 and not any(keyword in name.lower() for keyword in ("embed", "lm_head")):
Expand All @@ -77,19 +80,21 @@ def set_optimizer_flags(config_class, model):
setattr(p, "use_muon", False)


def initialize(args=None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
mpu=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
mesh_param=None,
config_params=None):
def initialize(
args: Any = None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
mpu: Any = None,
dist_init_required: Optional[bool] = None,
collate_fn: Optional[Callable] = None,
config: Optional[Union[str, Dict[str, Any]]] = None,
mesh_param: Any = None,
config_params: Optional[Union[str, Dict[str, Any]]] = None
) -> Tuple[DeepSpeedEngine, Optional[Union[Optimizer, DeepSpeedOptimizer]], Optional[DeepSpeedDataLoader], Any]:
"""Initialize the DeepSpeed Engine.

Arguments:
Expand Down Expand Up @@ -287,7 +292,7 @@ def _add_core_arguments(parser):
return parser


def add_config_arguments(parser):
def add_config_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
r"""Update the argument parser to enable parsing of DeepSpeed command line arguments.
The set of DeepSpeed arguments include the following:
1) --deepspeed: boolean flag to enable DeepSpeed
Expand All @@ -303,14 +308,16 @@ def add_config_arguments(parser):
return parser


def default_inference_config() -> Dict[str, Any]:
    """Build the default DeepSpeed inference configuration as a dictionary.

    Returns:
        The value produced by calling ``.dict()`` on a
        ``DeepSpeedInferenceConfig`` constructed with no arguments.
    """
    # Construct the config with no overrides, then export it as a dictionary.
    default_config = DeepSpeedInferenceConfig()
    return default_config.dict()


def init_inference(model, config=None, **kwargs):
def init_inference(model: torch.nn.Module,
config: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs: Any) -> InferenceEngine:
"""Initialize the DeepSpeed InferenceEngine.

Description: all four cases are valid and supported in DS init_inference() API.
Expand Down Expand Up @@ -388,7 +395,11 @@ def init_inference(model, config=None, **kwargs):
return engine


def tp_model_init(model, tp_size, dtype, config=None, **kwargs):
def tp_model_init(model: torch.nn.Module,
tp_size: int,
dtype: torch.dtype,
config: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs: Any) -> torch.nn.Module:
"""
Record tensor-parallel initialization arguments for training.

Expand Down
Loading