
Conversation

@WHoutstanding
Contributor

PR Category

other

Description

Add the following FP32_ONLY_FUNCS set to fix the dtype generalization pass:

FP32_ONLY_FUNCS = {
    torch.nn.functional.softmax,
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.batch_norm,
    torch.nn.functional.embedding,
    torch.exp,
    torch.log,
    torch.pow,
    torch.sigmoid,
    torch.tanh,
    torch.conv_transpose2d,
}
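
For context, a minimal sketch of how such a set could be consumed by a dtype generalization fx pass: tensor arguments of these functions are cast back to float32 right before the call. The force_fp32_inputs helper below is an illustrative assumption, not the repository's actual pass.

import torch
import torch.fx as fx

FP32_ONLY_FUNCS = {
    torch.nn.functional.softmax,
    torch.nn.functional.layer_norm,
    torch.exp,
}

def force_fp32_inputs(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in FP32_ONLY_FUNCS:
            new_args = []
            for arg in node.args:
                if isinstance(arg, fx.Node):
                    with gm.graph.inserting_before(node):
                        # Insert arg.to(torch.float32) right before the sensitive call.
                        arg = gm.graph.call_method("to", (arg, torch.float32))
                new_args.append(arg)
            node.args = tuple(new_args)
    gm.graph.lint()
    gm.recompile()
    return gm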

@paddle-bot

paddle-bot bot commented Jan 27, 2026

Thanks for your contribution!

@paddle-bot added the "contributor (External developers)" label on Jan 27, 2026
@WHoutstanding
Contributor Author

In the previous code logic, the dtype of blacklisted operators such as torch.layer_norm was never modified. So the reason torch.layer_norm fails with "expected scalar type Float but found Half" is not whether the layer_norm node itself was rewritten, but that the input arguments passed to it may have been cast to fp16 or bf16 by upstream nodes.
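
For illustration, a minimal standalone sketch of the kind of mismatch being described (example code, not from this PR); whether an error is raised and its exact text depend on the backend and PyTorch version:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)   # fp32 activation
w = torch.ones(8)       # fp32 weight
b = torch.zeros(8)      # fp32 bias

# Matching dtypes work fine.
F.layer_norm(x, (8,), w, b)

# Mixing dtypes between the activation and the parameters (here fp32 input vs.
# fp16 weight/bias) typically fails with a RuntimeError such as
# "expected scalar type Float but found Half".
try:
    F.layer_norm(x, (8,), w.half(), b.half())
except RuntimeError as e:
    print(e)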

@Xreki
Collaborator

Xreki commented Jan 28, 2026

This approach does run, but if you inspect the transformed fx.Graph carefully you will find that some casts are redundant. For layer_norm's weight, for example, an fp32 -> fp16 cast is inserted for the placeholder and then an fp16 -> fp32 cast is inserted again before layer_norm. In fact, layer_norm's input x may be fp16; only the weight has to stay fp32. The torch operator converts x to fp32 internally for the computation, so keeping the weight in fp32 avoids some cast kernels.
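
A small sketch of the layout described here (hypothetical example code, not from the PR): keep the layer_norm weight and bias in float32 and feed the activation in float16, so no fp32 -> fp16 -> fp32 cast pair is needed around the parameters. Whether this mixed-dtype call is accepted depends on the device and PyTorch version (recent CUDA builds support it):

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 8, device=device, dtype=torch.float16)  # fp16 activation
w = torch.ones(8, device=device)                           # weight stays fp32
b = torch.zeros(8, device=device)                          # bias stays fp32

try:
    y = F.layer_norm(x, (8,), w, b)   # the op upcasts x internally
    print(y.dtype)                    # typically torch.float16
except RuntimeError as e:
    print("mixed-dtype layer_norm not supported on this build:", e)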

@WHoutstanding
Contributor Author

WHoutstanding commented Jan 28, 2026

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
    "running_mean",
    "running_var",
    "num_batches_tracked",
    "bn_parameters_weight",
    "bn_parameters_bias",
    "ln_parameters_weight",
    "ln_parameters_bias",
}
Take torch.layer_norm() as an example:
graph_net/torch/sample_pass/dtype_generalizer.py specifies that the weight and bias dtype of torch.layer_norm() is preserved as float32.
So when the new get_attr nodes are created, torch.layer_norm()'s weight and bias are float32, which is why I don't quite understand why torch.layer_norm() would error if the inputs are not dtype-converted.

@Xreki
Collaborator

Xreki commented Jan 28, 2026

[screenshot: error message stack]

Look at the error message stack: the 3rd and 4th arguments of layer_norm are to_4 and to_5, not the original parameters. Print the fx.Graph and trace to_4 and to_5 back, and you will find they originate as follows:
[screenshot: fx.Graph dump showing the origin of to_4 and to_5]

to_4 and to_5 come from the earlier fp32 -> fp16 cast ops, so FLOAT32_PRESERVED_WEIGHTS did not take effect.
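
A sketch of how such a trace can be done (gm stands for the transformed fx.GraphModule; the trace_inputs helper name is made up for illustration):

import torch.fx as fx

def trace_inputs(gm: fx.GraphModule, node_name: str) -> None:
    # Dump every node so names like to_4 / to_5 can be located.
    for n in gm.graph.nodes:
        print(n.op, n.name, n.target, n.args)
    # Follow the inputs of the node in question one step upstream.
    node = next(n for n in gm.graph.nodes if n.name == node_name)
    for inp in node.all_input_nodes:
        print(f"{node.name} <- {inp.name} ({inp.op}: {inp.target})")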

@WHoutstanding
Contributor Author

1. Added a create_new_kwargs function to fix the error:
   RuntimeError: invalid dtype for bias - should match query's dtype
2. Fixed the dtype conversion for the layer_norm operator:
   Changed the preserved set to
   FLOAT32_PRESERVED_WEIGHTS = {
       "running_mean",
       "running_var",
       "num_batches_tracked",
       "batch_norm",
       "layer_norm",
       "group_norm",
       "norm",
       "weight",
       "bias",
       "eps",
       "pos_embed",
       "embedding",
   }
   and, in the create_placeholder function, kept the dtype as fp32 whenever self.should_preserve_weight(attr_name) returns True. After testing, to_4 and to_5 are handled correctly, but to_384 and to_385 still show up as:
%to_384 : [num_users=1] = call_method[target=to](args = (%L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_parameters_weight_, torch.float16), kwargs = {})
%L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_parameters_bias_ : torch.Tensor [num_users=1] = placeholder[target=L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_parameters_bias_]
%to_385 : [num_users=1] = call_method[target=to](args = (%L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_parameters_bias_, torch.float16), kwargs = {})
%L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_eps : torch.Tensor [num_users=1] = placeholder[target=L_self_modules_mask_decoder_modules_transformer_modules_layers_modules_0_modules_layer_norm4_eps]

The logic of should_preserve_weight(attr_name) keeps a tensor in fp32 by checking whether node.target is in FLOAT32_PRESERVED_WEIGHTS, which shows that FLOAT32_PRESERVED_WEIGHTS cannot cover all cases.
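
To make the failure mode concrete, a small standalone illustration (the long target string is taken from the graph dump above; the two matching strategies shown are examples, not necessarily the repository's exact logic):

FLOAT32_PRESERVED_WEIGHTS = {"layer_norm", "weight", "bias", "running_mean"}

target = ("L_self_modules_mask_decoder_modules_transformer_modules_layers"
          "_modules_0_modules_layer_norm4_parameters_weight_")

# Exact membership never fires for mangled dynamo placeholder names.
print(target in FLOAT32_PRESERVED_WEIGHTS)                      # False

# A substring check does fire here, but it would also fire for every other
# parameter whose name merely contains "weight" or "bias".
print(any(key in target for key in FLOAT32_PRESERVED_WEIGHTS))  # True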

Collaborator

@Xreki Xreki left a comment


LGTM

@WHoutstanding WHoutstanding changed the title Add FP32_ONLY_FUNCS to fix dtype generalization pass Add new_kwargs function to fix dtype generalization pass Jan 28, 2026
@Xreki
Collaborator

Xreki commented Jan 28, 2026

The current FLOAT32_PRESERVED_WEIGHTS implementation is based on name matching, which is very likely to go wrong. The scheme needs to be changed to resolve the real parameters from the fx.Graph instead, e.g. for layer_norm the 3rd and 4th inputs are the weight and bias.
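
A sketch of the direction being suggested (an assumption about how it could look, not the merged implementation): walk the fx.Graph, find layer_norm calls, and mark their 3rd and 4th inputs, the weight and bias, as nodes whose placeholders must stay float32.

import torch
import torch.fx as fx

def collect_fp32_only_inputs(gm: fx.GraphModule) -> set:
    keep_fp32 = set()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.nn.functional.layer_norm:
            # F.layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
            # positions 2 and 3 are the weight and bias.
            for arg in node.args[2:4]:
                if isinstance(arg, fx.Node):
                    keep_fp32.add(arg)
            # Weight/bias passed as keyword arguments would need the same treatment.
    return keep_fp32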

@Xreki Xreki merged commit 2233f1a into PaddlePaddle:develop Jan 28, 2026
3 of 4 checks passed