Add new_kwargs function to fix dtype generalization pass #614
Conversation
Thanks for your contribution!
The previous code logic did not modify the dtype of blacklisted ops such as torch.layer_norm. So the reason torch.layer_norm raises "expected scalar type Float but found Half" is not whether the layer_norm node itself was rewritten, but that the inputs passed to it may have been converted to f16 or bf16 by upstream nodes.
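For concreteness, here is a minimal sketch of that failure mode (the shapes, variable names, and exact error text are illustrative assumptions, not code from this repository):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(8)            # layer_norm weight/bias kept in float32
bias = torch.randn(8)
x_half = torch.randn(4, 8).half()  # activation already rewritten to f16 upstream

try:
    # The mismatch comes from the f16 input meeting fp32 parameters,
    # not from the layer_norm node itself being rewritten.
    F.layer_norm(x_half, (8,), weight, bias)
except RuntimeError as e:
    print(e)  # e.g. "expected scalar type Float but found Half"

# Casting the incoming activation back to float32 before the blacklisted
# op avoids the error.
out = F.layer_norm(x_half.float(), (8,), weight, bias)
print(out.dtype)  # torch.float32
```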
This approach does run, but if you look closely at the transformed …
# Weights that must remain float32 for numerical stability
The logic of should_preserve_weight(attr_name) keeps values in fp32 by checking whether node.target is in FLOAT32_PRESERVED_WEIGHTS, which shows that FLOAT32_PRESERVED_WEIGHTS alone cannot handle all cases.
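A rough reconstruction of the check being described, to make the limitation concrete (only the names FLOAT32_PRESERVED_WEIGHTS and should_preserve_weight come from the comment above; the set contents and matching rule here are assumptions):

```python
# Hypothetical reconstruction of the existing check; the real set contents
# and matching rule live in the dtype generalization pass itself.
FLOAT32_PRESERVED_WEIGHTS = {
    "norm.weight",   # illustrative entries only
    "norm.bias",
}

def should_preserve_weight(attr_name: str) -> bool:
    # Decides only whether a named weight stays in float32.  It cannot stop
    # an upstream node's f16/bf16 activation from flowing into a blacklisted
    # op such as torch.layer_norm, so the list by itself is not enough.
    return attr_name in FLOAT32_PRESERVED_WEIGHTS
```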
Xreki left a comment
LGTM
PR Category
other
Description
Add the following FP32_ONLY_FUNCS set to fix the dtype generalization pass:

FP32_ONLY_FUNCS = {
    torch.nn.functional.softmax,
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.batch_norm,
    torch.nn.functional.embedding,
    torch.exp,
    torch.log,
    torch.pow,
    torch.sigmoid,
    torch.tanh,
    torch.conv_transpose2d,
}
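Below is a minimal sketch of how such a set could be consumed by an FX-based dtype pass: before any call whose target is in FP32_ONLY_FUNCS, cast f16/bf16 tensor inputs back to float32. The pass body, the helper name cast_fp32_only_inputs, and the reliance on node.meta["val"] dtype metadata are assumptions for illustration, not this PR's actual new_kwargs implementation.

```python
import torch
import torch.fx as fx
import torch.nn.functional as F

FP32_ONLY_FUNCS = {
    F.softmax, F.layer_norm, F.group_norm, F.batch_norm, F.embedding,
    torch.exp, torch.log, torch.pow, torch.sigmoid, torch.tanh,
    torch.conv_transpose2d,
}

def cast_fp32_only_inputs(gm: fx.GraphModule) -> fx.GraphModule:
    """Illustrative pass: ensure fp32-only calls never see f16/bf16 inputs."""
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target not in FP32_ONLY_FUNCS:
            continue
        with gm.graph.inserting_before(node):
            new_args = []
            for arg in node.args:
                val = arg.meta.get("val") if isinstance(arg, fx.Node) else None
                if val is not None and getattr(val, "dtype", None) in (
                    torch.float16, torch.bfloat16,
                ):
                    # Insert an explicit cast so the blacklisted op receives
                    # float32 activations even when upstream nodes were
                    # rewritten to half precision.
                    new_args.append(gm.graph.call_method("float", (arg,)))
                else:
                    new_args.append(arg)
            node.args = tuple(new_args)
            # A complete pass would rewrite node.kwargs the same way, which
            # is roughly the role a new_kwargs-style helper would play.
    gm.graph.lint()
    gm.recompile()
    return gm
```

Because only arguments whose recorded dtype is f16/bf16 are wrapped in a cast, integer inputs such as embedding indices are left untouched.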