Multi-LoRA SFT support FSDP2 #155

Open

kevssim wants to merge 19 commits into modelscope:main from kevssim:multilora_fsdp

Conversation

@kevssim
Collaborator

kevssim commented Apr 14, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Multi-LoRA support FSDP2

Contributor

gemini-code-assist (Bot) left a comment

Code Review

This pull request implements FSDP2 support for MultiLoraTransformersModel by integrating it into the shared strategy and lazy-wrap lifecycle and by introducing sharding-aware parameter-access helpers. Review feedback identifies critical bugs in the distributed tensor handling: _write_param_tensor may incorrectly double-shard local data, set_state_dict risks shape mismatches when applying a global state dict to local shards, and get_state_dict returns sharded tensors that could lead to corrupt checkpoints. Furthermore, the model's initialization should be refactored to properly use the parent class, and internal imports should be moved to the module level.

Comment threads:
  • src/twinkle/model/multi_lora.py
  • src/twinkle/model/multi_lora.py (Outdated)
  • src/twinkle/model/multi_lora.py
  • src/twinkle/model/transformers/multi_lora_transformers.py
  • src/twinkle/model/multi_lora.py (Outdated)
kevssim changed the title from Multi-LoRA SFT support FSDP2 to [WIP] Multi-LoRA SFT support FSDP2 on Apr 16, 2026
@xichengpro
Contributor

I'd love to have this feature! Just curious — why was this PR changed to draft? Any other plans in the works?

@kevssim
Collaborator Author

kevssim commented Apr 21, 2026

I'd love to have this feature! Just curious — why was this PR changed to draft? Any other plans in the works?

Because it's still a work in progress; once it's finished, it will be merged into the main branch.

kevssim marked this pull request as ready for review on May 9, 2026 06:22
kevssim changed the title from [WIP] Multi-LoRA SFT support FSDP2 to Multi-LoRA SFT support FSDP2 on May 9, 2026
@kevssim
Collaborator Author

kevssim commented May 9, 2026

/gemini review

Contributor

gemini-code-assist (Bot) left a comment

Code Review

This pull request implements distributed training support for MultiLora, introducing helper methods for tensor sharding and updating model lifecycle methods to handle distributed contexts. The review identified critical issues where local shards are incorrectly processed as global tensors, potentially leading to corrupted weights during sharding and incomplete state dicts. Feedback emphasizes the need to gather tensors before saving or returning them to ensure compatibility with standard loaders and correct distributed behavior.

    return
value = value.detach().to(dtype=parameter.dtype)
if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
    value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)

Severity: high

The use of distribute_tensor here assumes that value is a global tensor that needs to be sharded according to the device_mesh and placements. However, in several call sites (like set_state_dict and _load_initial_weights), the value passed to _write_param_tensor is derived from _read_param_tensor, which returns a local shard. Calling distribute_tensor on a local shard will incorrectly attempt to shard the shard again, leading to incorrect parameter values in distributed training.
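
One way to avoid the double-sharding, sketched below assuming FSDP2 parameters are DTensors and a recent PyTorch where torch.distributed.tensor is public: make the caller state whether value is a global tensor or an already-local shard, and only call distribute_tensor in the global case. The value_is_global flag is hypothetical, not part of the PR.

```python
import torch
from torch.distributed.tensor import DTensor, distribute_tensor

def _write_param_tensor(parameter: torch.nn.Parameter, value: torch.Tensor,
                        value_is_global: bool = True) -> None:
    # Hypothetical variant: the caller states whether `value` is a global
    # tensor (which must be sharded) or a rank-local shard (which must not).
    value = value.detach().to(dtype=parameter.dtype)
    if isinstance(parameter.data, DTensor):
        mesh = parameter.data.device_mesh
        placements = parameter.data.placements
        if value_is_global:
            # distribute_tensor slices a *global* tensor across the mesh.
            value = distribute_tensor(value, mesh, placements)
        else:
            # A local shard is only wrapped; re-sharding it would shard
            # the shard a second time and corrupt the parameter.
            value = DTensor.from_local(value, mesh, placements)
    with torch.no_grad():
        parameter.data.copy_(value)
```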

Comment on lines +635 to +641
target_tensor = self._read_param_tensor(parameter)
if target_tensor is None:
    continue
target_tensor = target_tensor.clone()
src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
self._copy_rank_tensor(name, target_tensor, src_tensor)
self._write_param_tensor(parameter, target_tensor)

Severity: high

There is a mismatch between local and global tensors here. target_tensor is a local shard (from _read_param_tensor), while src_tensor is a global tensor (from state_dict).

  1. _copy_rank_tensor (line 640) will fail with a shape mismatch if the dimension being copied is sharded by FSDP (e.g., num_embeddings for embeddings or out_features for linear layers).
  2. _write_param_tensor (line 641) will then attempt to shard this local shard again as discussed in the previous comment.

To fix this, you should either shard src_tensor to match the local shard's placements before copying, or perform the copy on global tensors (on CPU) and then use _write_param_tensor to shard the result.
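
A sketch of the second option: gather the shard to a full tensor, apply the copy globally, then let _write_param_tensor re-shard exactly once. This is a fragment of the set_state_dict loop body, assuming _read_param_tensor is adjusted to return the parameter's DTensor (or a plain tensor when unsharded); _copy_rank_tensor and _write_param_tensor stand in for the PR's helpers.

```python
from torch.distributed.tensor import DTensor

target_tensor = self._read_param_tensor(parameter)
if target_tensor is None:
    continue
if isinstance(target_tensor, DTensor):
    # All-gather the shard so its shape matches the global src tensor.
    target_tensor = target_tensor.full_tensor()
target_tensor = target_tensor.clone()
src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype,
                                      device=target_tensor.device)
self._copy_rank_tensor(name, target_tensor, src_tensor)  # global vs. global
# Re-shard the updated *global* tensor back onto the mesh, exactly once.
self._write_param_tensor(parameter, target_tensor)
```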

Comment on lines +658 to 662
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is None:
    continue
name = name.replace(f'.{_lora.adapter_name}.', '.')
state_dict[name] = _param

Severity: high

In a distributed setting with FSDP2, _read_param_tensor returns a local shard. Consequently, get_state_dict returns a dictionary of sharded tensors. Since this method is decorated with @remote_function(collect='first') in the model class, only rank 0's local shards will be returned to the caller. This results in an incomplete and unusable state dict for the LoRA adapter. You should gather the shards into a global tensor before slicing and returning them.
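
A sketch of the gather, assuming the shard is exposed as a DTensor; _to_full_tensor is a hypothetical helper, not from the PR.

```python
from torch.distributed.tensor import DTensor

def _to_full_tensor(t):
    # Hypothetical helper: all-gather a DTensor into a plain global tensor;
    # plain tensors pass through unchanged.
    return t.full_tensor() if isinstance(t, DTensor) else t

# In get_state_dict, gather before slicing so rank 0 (the only rank whose
# result survives collect='first') holds the complete adapter weights:
_param = self._slice_rank_tensor(
    name,
    _to_full_tensor(self._read_param_tensor(parameter)),
    _lora.tenant_config.r)
```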

Comment on lines +619 to 623
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is not None:
    _param = _param.clone()
    name = name.replace(f'.{_lora.adapter_name}.', '.')
    return name, _param

Severity: high

Similar to get_state_dict, save_lora_converter is returning a local shard of the parameter. When saving the model, this will result in a checkpoint containing sharded LoRA weights, which is incompatible with standard PEFT loaders. LoRA weights are typically small enough to be gathered and saved as full tensors even in FSDP environments.
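
The same gather applies in save_lora_converter; since LoRA matrices are small (r × hidden), materializing them on CPU is cheap even under FSDP. A hedged fragment along these lines, again assuming the parameter is exposed as a DTensor:

```python
from torch.distributed.tensor import DTensor

tensor = self._read_param_tensor(parameter)
if isinstance(tensor, DTensor):
    tensor = tensor.full_tensor()          # all-gather the LoRA shards
_param = self._slice_rank_tensor(name, tensor, _lora.tenant_config.r)
if _param is not None:
    # Detach onto CPU so the checkpoint holds a plain, loader-friendly tensor.
    _param = _param.clone().cpu()
    name = name.replace(f'.{_lora.adapter_name}.', '.')
    return name, _param
```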
