Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 86 additions & 37 deletions src/twinkle/model/multi_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora import Embedding, Linear, LoraLayer
from torch.distributed.tensor import distribute_tensor
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -42,6 +43,69 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
return _lora
return None

def _read_param_tensor(self, parameter):
    """Return a plain local tensor view of *parameter*.

    Delegates to ``torch_util.to_local_tensor`` so every weight read in
    this class goes through one materialisation point (may return ``None``
    when the helper cannot produce a local tensor — callers check for it).
    """
    local_tensor = torch_util.to_local_tensor(parameter)
    return local_tensor

@staticmethod
def _is_distributed_param(parameter):
return hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements')

def _write_param_tensor(self, parameter, value):
if value is None:
return
value = value.detach().to(dtype=parameter.dtype)
if self._is_distributed_param(parameter):
if self._is_distributed_param(value):
parameter.data.copy_(value.to(parameter.device))
return

local_param = parameter.to_local() if hasattr(parameter, 'to_local') else None
parameter_shape = tuple(parameter.shape)
value_shape = tuple(value.shape)
if local_param is not None:
local_shape = tuple(local_param.shape)
if value_shape == local_shape and value_shape != parameter_shape:
local_param.copy_(value.to(local_param.device))
return
if value_shape != parameter_shape:
raise ValueError(
f'Cannot write tensor with shape {value_shape} to distributed parameter with global shape '
f'{parameter_shape} and local shape {local_shape}')
value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
Comment thread
kevssim marked this conversation as resolved.
else:
value = value.to(parameter.device)
parameter.data.copy_(value)
Comment thread
kevssim marked this conversation as resolved.

@staticmethod
def _slice_rank_tensor(name: str, tensor, rank: int):
if tensor is None:
return None
if 'embedding_A' in name:
return tensor[:, :rank]
if 'embedding_B' in name:
return tensor[:rank, :]
if '_A' in name:
return tensor[:rank, :]
if '_B' in name:
return tensor[:, :rank]
return tensor

@staticmethod
def _copy_rank_tensor(name: str, target, value):
if target is None or value is None:
return None
if 'embedding_A' in name:
target[:, :value.shape[1]].copy_(value)
elif 'embedding_B' in name:
target[:value.shape[0], :].copy_(value)
elif '_A' in name:
target[:value.shape[0], :].copy_(value)
elif '_B' in name:
target[:, :value.shape[1]].copy_(value)
else:
target.copy_(value)
return target

def _count_available_loras(self):
return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])

Expand Down Expand Up @@ -472,7 +536,7 @@ def save_initial_weights(self):
def _store_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name):
lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu')

if isinstance(self.module, list):
for _module in self.module:
Expand Down Expand Up @@ -572,17 +636,9 @@ def save_lora_converter(self, name, parameter, adapter_name):
# patching makes the bridge skip non-target modules entirely), so we
# only check the adapter-name / weight pattern here.
if re.search(rf'\.lora_\w+\.({adapter_name}|weight)', name):
_param = torch_util.to_local_tensor(parameter)
if _param is None:
pass
elif 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r].clone()
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :].clone()
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :].clone()
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r].clone()
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is not None:
_param = _param.clone()
name = name.replace(f'.{_lora.adapter_name}.', '.')
return name, _param
Comment thread
kevssim marked this conversation as resolved.
else:
Expand All @@ -595,20 +651,14 @@ def set_state_dict(self, tenant_adapter_name, state_dict):
def _load_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
name = name.replace(f'.{_lora.adapter_name}.', '.')
src_tensor = state_dict[name]
if 'embedding_A' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
elif 'embedding_B' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_A' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_B' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
state_key = name.replace(f'.{_lora.adapter_name}.', '.')
target_tensor = self._read_param_tensor(parameter)
if target_tensor is None:
continue
target_tensor = target_tensor.clone()
src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
self._copy_rank_tensor(name, target_tensor, src_tensor)
self._write_param_tensor(parameter, target_tensor)
Comment thread
kevssim marked this conversation as resolved.

if isinstance(self.module, list):
for _module in self.module:
Expand All @@ -625,15 +675,9 @@ def _get_weights(_module):
state_dict = {}
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
_param = torch_util.to_local_tensor(parameter)
if 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r]
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r]
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
Comment thread
kevssim marked this conversation as resolved.
if _param is None:
continue
name = name.replace(f'.{_lora.adapter_name}.', '.')
state_dict[name] = _param
Comment thread
kevssim marked this conversation as resolved.
return state_dict
Expand All @@ -653,9 +697,14 @@ def _load_initial_weights(self, origin_adapter_name):
def _load_initial_weights(_module):
for name, parameter in _module.named_parameters():
if pattern_A.search(name):
parameter.data.copy_(_lora.lora_A_weights[name])
local_param = self._read_param_tensor(parameter)
if local_param is not None:
value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device)
self._write_param_tensor(parameter, value)
if pattern_B.search(name):
parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
local_param = self._read_param_tensor(parameter)
if local_param is not None:
self._write_param_tensor(parameter, torch.zeros_like(local_param))

if isinstance(self.module, list):
for _module in self.module:
Expand Down
42 changes: 27 additions & 15 deletions src/twinkle/model/transformers/multi_lora_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import torch.distributed as dist
import transformers
from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights
from torch.optim import Optimizer
Expand All @@ -15,7 +16,6 @@
from twinkle.metric import Metric
from twinkle.processor import InputProcessor
from ..multi_lora import MultiLora
from .strategy import AccelerateStrategy
from .transformers import OptimizerGroup, TransformersModel


Expand All @@ -29,17 +29,28 @@ def __init__(
config: Optional[PretrainedConfig] = None,
device_mesh: Optional[DeviceMesh] = None,
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
Comment thread
kevssim marked this conversation as resolved.
max_loras: int = 5,
max_r: int = 32,
max_length: int = 8192,
target_modules: Union[List[str], str] = 'all-linear',
**kwargs):
assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self._try_init_process_group()
super(PreTrainedModel, self).__init__()
Comment thread
kevssim marked this conversation as resolved.
model_id = HubOperation.download_model(model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._fsdp_config = dict(fsdp_config or {})
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
self.grad_scaler_config = grad_scaler_config
if model_id is not None:
model_id = HubOperation.download_model(model_id)
self.model_id = model_id
if config is None:
from transformers import AutoConfig
Expand All @@ -52,24 +63,20 @@ def __init__(
model_cls = AutoModelForCausalLM
if isinstance(model_cls, str):
model_cls = getattr(transformers, model_cls)
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.model_id = model_id
if model_id is None:
self.model = model_cls.from_config(self.hf_config, **kwargs)
else:
with self.strategy.pretrained_load_context():
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self.grad_scaler_config = grad_scaler_config
self._default_tokenizer = None
self._model_wrapped = False
self.sp_strategy = None
# Initialize expert parallel attributes (required by set_optimizer in TransformersModel)
self._expert_parallel_config = None
self._enable_expert_parallel = False
self._expert_parallel_applied = False
self.optimizer_group: Dict[str, OptimizerGroup] = {}
self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
self.model.gradient_checkpointing_enable()
self.model = self.multi_adapter.patch(self.model, target_modules=target_modules)
self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
self.model = self.strategy.wrap_model(self.model)
self.multi_adapter.save_initial_weights()
# Active group for compatibility with single adapter
self.active_group = None
Expand Down Expand Up @@ -100,7 +107,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup):
pass

def _lazy_wrap_model(self):
pass
return super()._lazy_wrap_model()

@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
Expand Down Expand Up @@ -232,7 +239,10 @@ def get_state_dict(self, **kwargs):
def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
with self.multi_adapter.save_context(kwargs.get('adapter_name')):
return super().save(name, output_dir, interval, **kwargs)
checkpoint_dir = super().save(name, output_dir, interval, **kwargs)
if dist.is_initialized():
dist.barrier()
return checkpoint_dir

@remote_function()
def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
Expand All @@ -256,6 +266,8 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):

if load_optimizer:
self._restore_training_state(checkpoint_dir, adapter_name=adapter_name)
if dist.is_initialized():
dist.barrier()

@remote_function()
def set_grad_scaler(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ def _restore_training_state(self, checkpoint_dir, *, adapter_name=''):

return trainer_state

@remote_function()
@remote_function(dispatch='all', collect='first', sync=True)
def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs):
adapter_name = kwargs.get('adapter_name', '')

Expand Down Expand Up @@ -1286,7 +1286,7 @@ def _get_trainable_parameters_example(self, adapter_name, model):
trainable_param_names = '\n'.join(trainable_param_names)
return trainable_param_names

@remote_function(execute='first', lazy_collect=False)
@remote_function(dispatch='all', collect='first', lazy_collect=False)
def get_train_configs(self, **kwargs) -> str:
expr = ''
adapter_name = kwargs.pop('adapter_name', self._get_default_group())
Expand Down
Loading