support train kwargs and configs#267

Open
BernhardAhrens wants to merge 9 commits into main from ba/train_signatures
Conversation

@BernhardAhrens
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the train function by extracting its core logic into a private _train function and updating the public train method to support both explicit configuration objects and flat keyword arguments. Feedback points out a logic error where explicit configuration objects are overwritten if keyword arguments are provided, potentially leading to lost settings. It is also recommended to align the implementation with the documentation regarding the deprecation of flat keyword arguments by adding formal warnings.

Comment thread src/training/train.jl
Comment on lines +83 to +94
function train(
    model, data;
    train_cfg::TrainConfig = TrainConfig(),
    data_cfg::DataConfig = DataConfig(),
    kwargs...,
)
    if !isempty(kwargs)
        train_cfg, data_cfg = kwargs_to_configs((), kwargs)
    end

    return _train(model, data, train_cfg, data_cfg)
end

high

The current implementation has a logic error when mixing explicit configuration objects (train_cfg, data_cfg) with flat keyword arguments. If kwargs is not empty, the provided train_cfg and data_cfg are completely overwritten by new instances created from kwargs and defaults, leading to silent loss of user configuration.

If mixing is not intended to be supported, it is safer to use nothing as default values to detect explicit usage and prevent accidental overwriting. If mixing is intended, kwargs_to_configs should be updated to merge with the existing configuration objects.

function train(
    model, data;
    train_cfg::Union{TrainConfig, Nothing} = nothing,
    data_cfg::Union{DataConfig, Nothing} = nothing,
    kwargs...,
)
    if !isempty(kwargs)
        if !isnothing(train_cfg) || !isnothing(data_cfg)
            throw(ArgumentError("Cannot mix explicit `train_cfg`/`data_cfg` with flat keyword arguments."))
        end
        train_cfg, data_cfg = kwargs_to_configs((), kwargs)
    else
        train_cfg = isnothing(train_cfg) ? TrainConfig() : train_cfg
        data_cfg = isnothing(data_cfg) ? DataConfig() : data_cfg
    end

    return _train(model, data, train_cfg, data_cfg)
end
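
For the other option named in this comment (merging flat keyword arguments into the explicit config objects instead of rejecting the mix), a minimal sketch could look like the following. It is not the PR's implementation: the struct fields (`nepochs`, `batchsize`, `shuffle`, `split_ratio`) and the helpers `merge_config` / `merge_configs` are hypothetical stand-ins, since the real `TrainConfig` and `DataConfig` definitions are not shown in this thread.

```julia
# Hypothetical stand-ins for the package's config types (fields are assumptions).
Base.@kwdef struct TrainConfig
    nepochs::Int = 100
    batchsize::Int = 32
end

Base.@kwdef struct DataConfig
    shuffle::Bool = true
    split_ratio::Float64 = 0.8
end

# Rebuild `cfg`, replacing only the fields that appear in `overrides`.
# Keys in `overrides` that are not fields of `T` are left for the other config.
function merge_config(cfg::T, overrides) where {T}
    vals = map(fieldnames(T)) do name
        name => get(overrides, name, getfield(cfg, name))
    end
    return T(; vals...)
end

# Apply flat kwargs on top of explicitly passed configs instead of replacing them.
function merge_configs(train_cfg::TrainConfig, data_cfg::DataConfig; kwargs...)
    return merge_config(train_cfg, kwargs), merge_config(data_cfg, kwargs)
end

train_cfg, data_cfg = merge_configs(
    TrainConfig(nepochs = 5), DataConfig();
    batchsize = 64, shuffle = false,
)
```

With this approach an explicitly set `nepochs = 5` survives, while `batchsize` and `shuffle` are taken from the flat keywords. Whether flat keywords should win over explicit configs is a design choice the PR would still need to settle.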

Comment thread src/training/train.jl Outdated
# Keyword Arguments
- `train_cfg`: Training configuration. See [`TrainConfig`](@ref) for all options.
- `data_cfg`: Data preparation configuration. See [`DataConfig`](@ref) for all options.
- Any other kwargs (deprecated) are forwarded as fields to `TrainConfig` / `DataConfig`.

medium

The docstring labels flat keyword arguments as "deprecated" in the context of the new train method. However, they are being explicitly supported in the method signature. If the intention is to phase them out, consider adding a formal deprecation warning when they are used via the new API, or clarify if they are intended as a permanent convenience feature.
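
If the flat keyword arguments are meant to be phased out, one way to make that formal is `Base.depwarn`. The sketch below is an assumption about how it could be wired in, not code from the PR; the helper name `warn_flat_kwargs` is hypothetical. Note that Julia hides `Base.depwarn` messages by default unless it is started with `--depwarn=yes`.

```julia
# Hypothetical helper: emit a deprecation warning when flat kwargs are used.
# Returns `true` if any flat kwargs were passed, `false` otherwise.
function warn_flat_kwargs(kwargs; caller::Symbol = :train)
    isempty(kwargs) && return false
    names = join(keys(kwargs), ", ")
    Base.depwarn(
        "Flat keyword arguments ($names) are deprecated; " *
        "pass `train_cfg` / `data_cfg` instead.",
        caller,
    )
    return true
end

# Example call with a NamedTuple standing in for `kwargs...`.
used_flat = warn_flat_kwargs((; nepochs = 3))
```

A call such as `warn_flat_kwargs(kwargs)` at the top of `train` would keep the old syntax working while signalling that it is on its way out.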

Comment thread src/training/train.jl Outdated

unknown = [k for k in keys(kwargs) if k ∉ train_keys && k ∉ data_keys]
if !isempty(unknown)
    @warn "Unknown kwargs will be ignored: $(join(unknown, ", "))"
Member


Maybe it is better to throw an error instead?
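
A sketch of the stricter behaviour suggested here, under the assumption that valid keys are exactly the field names of the two config types. The struct fields and the helper `check_kwargs` are hypothetical; the real `train_keys` / `data_keys` in the PR may be built differently.

```julia
# Hypothetical stand-ins for the package's config types (fields are assumptions).
Base.@kwdef struct TrainConfig
    nepochs::Int = 100
    batchsize::Int = 32
end

Base.@kwdef struct DataConfig
    shuffle::Bool = true
    split_ratio::Float64 = 0.8
end

# Throw instead of warning when a keyword matches neither config type,
# so that typos such as `nepoch` are not silently ignored.
function check_kwargs(kwargs)
    known = (fieldnames(TrainConfig)..., fieldnames(DataConfig)...)
    unknown = [k for k in keys(kwargs) if k ∉ known]
    if !isempty(unknown)
        throw(ArgumentError("Unknown keyword arguments: $(join(unknown, ", "))"))
    end
    return nothing
end

check_kwargs((; nepochs = 3, shuffle = false))  # passes: both keys are known
```

Failing fast here trades some backward tolerance for earlier detection of misspelled options.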

Comment thread src/io/checkpoints.jl
target_names = model.targets
save_epoch = stopper.best_epoch == 0 ? 0 : stopper.best_epoch
- save_ps_st!(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)
+ save_ps_st(paths.best_model, model, cfg.cdev(ps), cfg.cdev(st), cfg.tracked_params, save_epoch)
Member


It should be the `!` variant. At this stage, the initial file has already been created. I think this call will write the whole file again, won't it?

Comment thread src/io/save.jl
export get_all_groups
export load_group
- function save_ps_st(file_name, hm, ps, st, save_ps)
+ function save_ps_st(file_name, hm, ps, st, save_ps, epoch = 0)
Member


Oh, I see. It is the same function in both places, so there is no need for the `!` variant?

Comment thread src/training/train.jl
when non-empty it is forwarded as `tracked_params` on the resulting `TrainConfig`.
"""
- function kwargs_to_configs(kwargs)
+ function kwargs_to_configs(save_ps, kwargs)
Member


So this makes the interface compatible with the old syntax? Good.

Development

Successfully merging this pull request may close these issues.

nn weights regularisation Custom L2 and L1 regularization - only on NN parameters

2 participants