diff --git a/src/quantem/core/ml/__init__.py b/src/quantem/core/ml/__init__.py index ce1f6de0..af2d5b64 100644 --- a/src/quantem/core/ml/__init__.py +++ b/src/quantem/core/ml/__init__.py @@ -1,4 +1,8 @@ from quantem.core.ml.cnn import CNN2d as CNN2d, CNN3d as CNN3d +from quantem.core.ml.optimizer_mixin import ( + OptimizerParams as OptimizerParams, + SchedulerParams as SchedulerParams, +) from quantem.core.ml.inr import HSiren as HSiren from quantem.core.ml.dense_nn import DenseNN as DenseNN from quantem.core.ml.cnn_dense import CNNDense as CNNDense diff --git a/src/quantem/core/ml/optimizer_mixin.py b/src/quantem/core/ml/optimizer_mixin.py index e2d2e89a..e9dc9599 100644 --- a/src/quantem/core/ml/optimizer_mixin.py +++ b/src/quantem/core/ml/optimizer_mixin.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Generator, Iterator, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generator, Iterator, Literal, Sequence from quantem.core import config @@ -10,6 +11,472 @@ import torch +class OptimizerParams: + """ + Container for optimizer parameter dataclasses. + + Each nested class configures a specific PyTorch optimizer and can be passed + directly to ``OptimizerMixin.set_optimizer``, or constructed from a dict via + ``OptimizerParams.parse_dict``. + + Supported optimizers + -------------------- + Adam + ``torch.optim.Adam`` — adaptive moment estimation. + AdamW + ``torch.optim.AdamW`` — Adam with decoupled weight decay. + SGD + ``torch.optim.SGD`` — stochastic gradient descent with optional momentum and Nesterov. + NoneOptimizer + Sentinel that disables / removes the optimizer. + + Examples + -------- + >>> obj.set_optimizer(OptimizerParams.Adam(lr=1e-4)) + >>> obj.set_optimizer({"name": "adam", "lr": 1e-4}) # equivalent dict form + """ + + @dataclass + class Adam: + """ + Adam optimizer (``torch.optim.Adam``). + + Parameters + ---------- + lr : float + Learning rate. Default: 1e-3. + betas : tuple[float, float] + Coefficients for computing running averages of the gradient and its square. + Default: (0.9, 0.999). + eps : float + Term added to the denominator for numerical stability. Default: 1e-8. + weight_decay : float + L2 regularisation penalty. Default: 0. + """ + + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + _name: str = "adam" + + def params(self) -> dict: + return { + "lr": self.lr, + "betas": self.betas, + "eps": self.eps, + "weight_decay": self.weight_decay, + } + + @dataclass + class AdamW: + """ + AdamW optimizer (``torch.optim.AdamW``). + + Identical to Adam but applies weight decay directly to the parameters + rather than folding it into the gradient update (decoupled weight decay). + + Parameters + ---------- + lr : float + Learning rate. Default: 1e-3. + betas : tuple[float, float] + Coefficients for computing running averages of the gradient and its square. + Default: (0.9, 0.999). + eps : float + Term added to the denominator for numerical stability. Default: 1e-8. + weight_decay : float + Decoupled L2 regularisation penalty. Default: 0. + """ + + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + _name: str = "adamw" + + def params(self) -> dict: + return { + "lr": self.lr, + "betas": self.betas, + "eps": self.eps, + "weight_decay": self.weight_decay, + } + + @dataclass + class SGD: + """ + SGD optimizer (``torch.optim.SGD``). + + Parameters + ---------- + lr : float + Learning rate. 
Default: 1e-3.
+        momentum : float
+            Momentum factor. Default: 0.
+        dampening : float
+            Dampening for momentum. Default: 0.
+        weight_decay : float
+            L2 regularisation penalty. Default: 0.
+        nesterov : bool
+            Enables Nesterov momentum. Default: False.
+        """
+
+        lr: float = 1e-3
+        momentum: float = 0
+        dampening: float = 0
+        weight_decay: float = 0
+        nesterov: bool = False
+        _name: str = "sgd"
+
+        def params(self) -> dict:
+            return {
+                "lr": self.lr,
+                "momentum": self.momentum,
+                "dampening": self.dampening,
+                "weight_decay": self.weight_decay,
+                "nesterov": self.nesterov,
+            }
+
+    @dataclass
+    class NoneOptimizer:
+        """
+        Sentinel optimizer that disables optimization.
+
+        Passing this to ``set_optimizer`` will call ``remove_optimizer``,
+        clearing both the optimizer and scheduler.
+        """
+
+        _name: str = "none"
+
+        def params(self) -> dict:
+            return {}
+
+    @classmethod
+    def parse_dict(cls, d: dict):
+        """
+        Parse a dictionary into an optimizer params object.
+        Accepts either ``"name"`` or ``"type"`` as the optimizer key.
+        """
+        d = dict(d)  # avoid mutating caller's dict
+        name = d.pop("name", None) or d.pop("type", "none")
+        if isinstance(name, type):
+            name = name.__name__.lower()
+        elif isinstance(name, str):
+            name = name.lower()
+        else:
+            raise ValueError(f"Unknown optimizer type: {name}")
+        if name == "adam":
+            return OptimizerParams.Adam(**d)
+        elif name == "adamw":
+            return OptimizerParams.AdamW(**d)
+        elif name == "sgd":
+            return OptimizerParams.SGD(**d)
+        elif name == "none":
+            return OptimizerParams.NoneOptimizer()
+        else:
+            raise ValueError(f"Unknown optimizer type: {name}")
+
+
+OptimizerType = (
+    OptimizerParams.Adam
+    | OptimizerParams.AdamW
+    | OptimizerParams.SGD
+    | OptimizerParams.NoneOptimizer
+)
+
+
+class SchedulerParams:
+    """
+    Container for learning-rate scheduler parameter dataclasses.
+
+    Each nested class configures a specific PyTorch LR scheduler and can be passed
+    directly to ``OptimizerMixin.set_scheduler``, or constructed from a dict via
+    ``SchedulerParams.parse_dict``.
+
+    Supported schedulers
+    --------------------
+    Plateau
+        ``torch.optim.lr_scheduler.ReduceLROnPlateau`` — reduce LR when a metric stops improving.
+    Exponential
+        ``torch.optim.lr_scheduler.ExponentialLR`` — multiply LR by ``gamma`` each step.
+    Cyclic
+        ``torch.optim.lr_scheduler.CyclicLR`` — cycle LR between ``base_lr`` and ``max_lr``.
+    Linear
+        ``torch.optim.lr_scheduler.LinearLR`` — linearly interpolate LR over ``total_iters`` steps.
+    CosineAnnealing
+        ``torch.optim.lr_scheduler.CosineAnnealingLR`` — cosine-annealing LR schedule.
+    NoneScheduler
+        Sentinel that disables / removes the scheduler.
+
+    Examples
+    --------
+    >>> obj.set_scheduler(SchedulerParams.Plateau(factor=0.5, patience=10, cooldown=20))
+    >>> obj.set_scheduler({"name": "plateau", "factor": 0.5})  # equivalent dict form
+    """
+
+    @dataclass
+    class Plateau:
+        """
+        ReduceLROnPlateau scheduler (``torch.optim.lr_scheduler.ReduceLROnPlateau``).
+
+        Reduces the learning rate when a monitored metric stops improving.
+
+        Parameters
+        ----------
+        mode : {'min', 'max'}
+            Whether the monitored metric should be minimised or maximised. Default: 'min'.
+        min_lr_factor : float
+            Sets ``min_lr = min_lr_factor * base_lr`` when ``min_lr`` is not provided.
+            Default: 1/20.
+        min_lr : float or None
+            Absolute lower bound on the learning rate. Overrides ``min_lr_factor`` when set.
+            Default: None.
+        factor : float
+            Factor by which the LR is reduced: ``new_lr = lr * factor``. Default: 0.5.
+        patience : int
+            Number of epochs with no improvement before reducing LR. Default: 10.
+        threshold : float
+            Minimum change in the monitored metric to qualify as an improvement. Default: 1e-5.
+        cooldown : int
+            Number of epochs to wait after a LR reduction before resuming normal operation.
+            Default: 50.
+        """
+
+        mode: Literal["min", "max"] = "min"
+        min_lr_factor: float = 1 / 20
+        min_lr: float | None = None
+        factor: float = 0.5
+        patience: int = 10
+        threshold: float = 1e-5
+        cooldown: int = 50
+        _name: str = "plateau"
+
+        def params(self, base_LR: float, num_iter: int | None = None) -> dict:
+            if self.min_lr is None:
+                self.min_lr = self.min_lr_factor * base_LR
+            return {
+                "mode": self.mode,
+                "factor": self.factor,
+                "patience": self.patience,
+                "threshold": self.threshold,
+                "min_lr": self.min_lr,
+                "cooldown": self.cooldown,
+            }
+
+    @dataclass
+    class Exponential:
+        """
+        ExponentialLR scheduler (``torch.optim.lr_scheduler.ExponentialLR``).
+
+        Multiplies the learning rate by ``gamma`` after each step.
+
+        Parameters
+        ----------
+        gamma : float
+            Multiplicative decay factor applied each step. Default: 0.9.
+        factor : float or None
+            Overall decay target; when set, ``gamma`` is recomputed as
+            ``factor ** (1 / num_iter)``. Default: None.
+        num_iter : int or None
+            Number of steps used to derive ``gamma`` from ``factor``; falls back to the
+            ``num_iter`` argument of ``params``. Default: None.
+        """
+
+        gamma: float = 0.9
+        factor: float | None = None
+        num_iter: int | None = None
+        _name: str = "exponential"
+
+        def params(self, base_LR: float, num_iter: int | None = None) -> dict:
+            effective_num_iter = self.num_iter if self.num_iter is not None else num_iter
+            if effective_num_iter is None:
+                raise ValueError("num_iter must be set on the scheduler or passed to params()")
+
+            self.num_iter = effective_num_iter
+
+            if self.factor is not None:
+                self.gamma = self.factor ** (1.0 / effective_num_iter)
+
+            return {
+                "gamma": self.gamma,
+            }
+
+    @dataclass
+    class Cyclic:
+        """
+        CyclicLR scheduler (``torch.optim.lr_scheduler.CyclicLR``).
+
+        Cycles the learning rate between a lower bound (``base_lr``) and an upper
+        bound (``max_lr``). Bounds can be set directly or derived from the optimizer's
+        base LR via the factor parameters.
+
+        Parameters
+        ----------
+        base_lr_factor : float
+            Sets ``base_lr = base_lr_factor * optimizer_lr`` when ``base_lr`` is not provided.
+            Default: 1/4.
+        max_lr_factor : float
+            Sets ``max_lr = max_lr_factor * optimizer_lr`` when ``max_lr`` is not provided.
+            Default: 4.
+        base_lr : float or None
+            Absolute lower bound of the LR cycle. Overrides ``base_lr_factor`` when set.
+            Default: None.
+        max_lr : float or None
+            Absolute upper bound of the LR cycle. Overrides ``max_lr_factor`` when set.
+            Default: None.
+        step_size_up : int
+            Number of steps in the increasing half of each cycle. Default: 100.
+        step_size_down : int
+            Number of steps in the decreasing half of each cycle. Default: 100.
+        mode : {'triangular2', 'triangular', 'exp_range'}
+            Cycling policy. Default: 'triangular2'.
+        cycle_momentum : bool
+            If True, cycles momentum inversely to the learning rate. Default: False.
+ """ + + base_lr_factor: float = 1 / 4 + max_lr_factor: float = 4 + base_lr: float | None = None + max_lr: float | None = None + step_size_up: int = 100 + step_size_down: int = 100 + mode: Literal["triangular2", "triangular", "exp_range"] = "triangular2" + cycle_momentum: bool = False + _name: str = "cyclic" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if self.base_lr is None: + self.base_lr = self.base_lr_factor * base_LR + if self.max_lr is None: + self.max_lr = self.max_lr_factor * base_LR + return { + "base_lr": self.base_lr, + "max_lr": self.max_lr, + "step_size_up": self.step_size_up, + "step_size_down": self.step_size_down, + "mode": self.mode, + "cycle_momentum": self.cycle_momentum, + } + + @dataclass + class Linear: + """ + LinearLR scheduler (``torch.optim.lr_scheduler.LinearLR``). + + Linearly interpolates the learning rate from ``start_factor * base_lr`` to + ``end_factor * base_lr`` over ``total_iters`` steps. + + Parameters + ---------- + total_iters : int + Number of steps over which to interpolate the LR. Required. + start_factor : float + Multiplicative factor applied to the LR at the first step. Default: 0.1. + end_factor : float + Multiplicative factor applied to the LR at the final step. Default: 1.0. + """ + + total_iters: int | None = None + start_factor: float = 0.1 + end_factor: float = 1.0 + _name: str = "linear" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if num_iter is None and self.total_iters is None: + raise ValueError( + "total_iters must be set if num_iter is not provided" + ) # Should never be reached + if self.total_iters is None: + self.total_iters = num_iter + return { + "start_factor": self.start_factor, + "end_factor": self.end_factor, + "total_iters": self.total_iters, + } + + @dataclass + class CosineAnnealing: + """ + CosineAnnealingLR scheduler (``torch.optim.lr_scheduler.CosineAnnealingLR``). + + Anneals the learning rate following a cosine curve from the base LR down to + ``eta_min`` over ``T_max`` steps, then restarts. + + Parameters + ---------- + T_max : int + Maximum number of iterations (half-period of the cosine cycle). Required. + eta_min : float + Minimum learning rate at the trough of the cosine curve. Default: 1e-7. + """ + + eta_min: float = 1e-7 + T_max: int | None = None + _name: str = "cosine_annealing" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if num_iter is None and self.T_max is None: + raise ValueError( + "T_max must be set if num_iter is not provided" + ) # Should never be reached + if self.T_max is None: + self.T_max = num_iter + return { + "T_max": self.T_max, + "eta_min": self.eta_min, + } + + @dataclass + class NoneScheduler: + """ + Sentinel scheduler that disables LR scheduling. + + Passing this to ``set_scheduler`` clears the active scheduler without + affecting the optimizer. + """ + + _name: str = "none" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + return {} + + @classmethod + def parse_dict(cls, d: dict): + """ + Parse dictionary to a scheduler params object. + Accepts either ``"name"`` or ``"type"`` as the scheduler key. 
+ """ + d = dict(d) # avoid mutating caller's dict + name = d.pop("name", "none") or d.pop("type", "none") + if isinstance(name, type): + name = name.__name__.lower() + elif isinstance(name, str): + name = name.lower() + else: + raise ValueError(f"Unknown scheduler type: {name}") + if name == "plateau": + return SchedulerParams.Plateau(**d) + elif name == "exponential": + return SchedulerParams.Exponential(**d) + elif name == "cyclic": + return SchedulerParams.Cyclic(**d) + elif name == "linear": + return SchedulerParams.Linear(**d) + elif name == "cosine_annealing": + return SchedulerParams.CosineAnnealing(**d) + elif name == "none": + return SchedulerParams.NoneScheduler() + else: + raise ValueError(f"Unknown scheduler type: {name}") + + +SchedulerType = ( + SchedulerParams.Plateau + | SchedulerParams.Exponential + | SchedulerParams.Cyclic + | SchedulerParams.Linear + | SchedulerParams.CosineAnnealing + | SchedulerParams.NoneScheduler +) + + class OptimizerMixin: """ Mixin class for handling optimizer and scheduler management. @@ -22,8 +489,8 @@ def __init__(self): """Initialize the optimizer mixin.""" self._optimizer = None self._scheduler = None - self._optimizer_params = {} - self._scheduler_params = {} + self._optimizer_params: OptimizerType = OptimizerParams.NoneOptimizer() + self._scheduler_params: SchedulerType = SchedulerParams.NoneScheduler() # Don't call super().__init__() in mixin classes to avoid MRO issues @property @@ -37,38 +504,32 @@ def scheduler(self) -> "torch.optim.lr_scheduler.LRScheduler | None": return self._scheduler @property - def optimizer_params(self) -> dict: + def optimizer_params(self) -> OptimizerType: """Get the optimizer parameters.""" return self._optimizer_params @optimizer_params.setter - def optimizer_params(self, params: dict): + def optimizer_params(self, params: OptimizerType | dict): """Set the optimizer parameters.""" - self._optimizer_params = params.copy() if params else {} + if isinstance(params, dict): + params = OptimizerParams.parse_dict(d=params) + if not isinstance(params, OptimizerType): + raise TypeError(f"optimizer parameters must be a OptimizerType, got {type(params)}") + self._optimizer_params = params @property - def scheduler_params(self) -> dict: + def scheduler_params(self) -> SchedulerType: """Get the scheduler parameters.""" return self._scheduler_params @scheduler_params.setter - def scheduler_params(self, params: dict): + def scheduler_params(self, params: SchedulerType | dict): """Set the scheduler parameters.""" - if params: - if params["type"].lower() not in [ - "cyclic", - "plateau", - "exp", - "gamma", - "linear", - "none", - ]: - raise ValueError( - f"Unknown scheduler type: {params['type']}, expected one of ['cyclic', 'plateau', 'exp', 'gamma', 'none']" - ) - self._scheduler_params = params.copy() - else: - self._scheduler_params = {} + if isinstance(params, dict): + params = SchedulerParams.parse_dict(d=params) + if not isinstance(params, SchedulerType): + raise TypeError(f"scheduler parameters must be a SchedulerType, got {type(params)}") + self._scheduler_params = params @abstractmethod def get_optimization_parameters( @@ -81,7 +542,7 @@ def get_optimization_parameters( """ raise NotImplementedError("Subclasses must implement get_optimization_parameters") - def set_optimizer(self, opt_params: dict | None = None) -> None: + def set_optimizer(self, opt_params: OptimizerType | dict | None = None) -> None: """ Set the optimizer for this model. 
Currently supports single LR for all parameters, TODO allow for per parameter LRs by @@ -94,10 +555,7 @@ def set_optimizer(self, opt_params: dict | None = None) -> None: self._optimizer = None return - opt_params = self._optimizer_params.copy() - opt_type = opt_params.pop("type", self.DEFAULT_OPTIMIZER_TYPE) - - if opt_type == "none": + if isinstance(self._optimizer_params, OptimizerParams.NoneOptimizer): self.remove_optimizer() return @@ -111,22 +569,18 @@ def set_optimizer(self, opt_params: dict | None = None) -> None: for p in params: p.requires_grad_(True) - if isinstance(opt_type, type): - self._optimizer = opt_type(params, **opt_params) - elif isinstance(opt_type, str): - if opt_type.lower() == "adam": - self._optimizer = torch.optim.Adam(params, **opt_params) - elif opt_type.lower() == "adamw": - self._optimizer = torch.optim.AdamW(params, **opt_params) - elif opt_type.lower() == "sgd": - self._optimizer = torch.optim.SGD(params, **opt_params) - else: - raise NotImplementedError(f"Unknown optimizer type: {opt_type}") - else: - raise TypeError(f"optimizer type must be string or type, got {type(opt_type)}") + match self._optimizer_params: + case OptimizerParams.Adam(): + self._optimizer = torch.optim.Adam(params, **self._optimizer_params.params()) + case OptimizerParams.AdamW(): + self._optimizer = torch.optim.AdamW(params, **self._optimizer_params.params()) + case OptimizerParams.SGD(): + self._optimizer = torch.optim.SGD(params, **self._optimizer_params.params()) + case _: + raise NotImplementedError(f"Unknown optimizer type: {self._optimizer_params}") def set_scheduler( - self, scheduler_params: dict | None = None, num_iter: int | None = None + self, scheduler_params: SchedulerType | dict | None = None, num_iter: int | None = None ) -> None: """Set the scheduler for this model.""" if scheduler_params is not None: @@ -136,51 +590,39 @@ def set_scheduler( self._scheduler = None return - params = self._scheduler_params - sched_type = params.get("type", "none").lower() optimizer = self._optimizer base_LR = optimizer.param_groups[0]["lr"] - - if sched_type == "none": - self._scheduler = None - elif sched_type == "cyclic": - self._scheduler = torch.optim.lr_scheduler.CyclicLR( - optimizer, - base_lr=params.get("base_lr", base_LR / 4), - max_lr=params.get("max_lr", base_LR * 4), - step_size_up=params.get("step_size_up", 100), - step_size_down=params.get("step_size_down", params.get("step_size_up", 100)), - mode=params.get("mode", "triangular2"), - cycle_momentum=params.get("momentum", False), - ) - elif sched_type.startswith(("plat", "reducelronplat")): - self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="min", - factor=params.get("factor", 0.5), - patience=params.get("patience", 10), - threshold=params.get("threshold", 1e-3), - min_lr=params.get("min_lr", base_LR / 20), - cooldown=params.get("cooldown", 50), - ) - elif sched_type in ["exp", "gamma", "exponential"]: - if "gamma" in params: - gamma = params["gamma"] - elif num_iter is not None: - fac = params.get("factor", 0.01) - gamma = fac ** (1.0 / num_iter) - else: - gamma = 0.9 - self._scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) - elif sched_type == "linear": - self._scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=params.get("start_factor", 0.1), - end_factor=params.get("end_factor", 1.0), - total_iters=params.get("total_iters", num_iter), - ) - else: - raise ValueError(f"Unknown scheduler type: {sched_type}") + params = 
self._scheduler_params.params(base_LR, num_iter=num_iter) + match self.scheduler_params: + case SchedulerParams.NoneScheduler(): + self._scheduler = None + case SchedulerParams.Cyclic(): + self._scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + **params, + ) + case SchedulerParams.Plateau(): + self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + **params, + ) + case SchedulerParams.Exponential(): + self._scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, + **params, + ) + case SchedulerParams.Linear(): + self._scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + **params, + ) + case SchedulerParams.CosineAnnealing(): + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + **params, + ) + case _: + raise ValueError(f"Unknown scheduler type: {self.scheduler_params}") def step_optimizer(self) -> None: """Step the optimizer if it exists.""" @@ -214,9 +656,9 @@ def get_current_lr(self) -> float: def remove_optimizer(self) -> None: """Remove the optimizer and scheduler.""" self._optimizer = None - self._optimizer_params = {} + self._optimizer_params = OptimizerParams.NoneOptimizer() self._scheduler = None - self._scheduler_params = {} + self._scheduler_params = SchedulerParams.NoneScheduler() def reset_optimizer(self) -> None: """Reset the optimizer and scheduler.""" diff --git a/tests/ml/test_optimizermixin.py b/tests/ml/test_optimizermixin.py new file mode 100644 index 00000000..6fd6c22a --- /dev/null +++ b/tests/ml/test_optimizermixin.py @@ -0,0 +1,295 @@ +"""Tests for OptimizerParams and SchedulerParams dataclasses.""" + +import pytest + +# Now import the module under test — adjust the path if needed +from quantem.core.ml.optimizer_mixin import ( + OptimizerParams, + SchedulerParams, +) + +# ─── OptimizerParams defaults ─────────────────────────────────────────────── + + +class TestAdamDefaults: + def test_defaults(self): + adam = OptimizerParams.Adam() + assert adam.lr == 1e-3 + assert adam.betas == (0.9, 0.999) + assert adam.eps == 1e-8 + assert adam.weight_decay == 0 + assert adam._name == "adam" + + def test_params_dict(self): + adam = OptimizerParams.Adam(lr=0.01, weight_decay=1e-4) + p = adam.params() + assert p == { + "lr": 0.01, + "betas": (0.9, 0.999), + "eps": 1e-8, + "weight_decay": 1e-4, + } + + def test_custom_betas(self): + adam = OptimizerParams.Adam(betas=(0.8, 0.99)) + assert adam.params()["betas"] == (0.8, 0.99) + + +class TestAdamWDefaults: + def test_defaults(self): + adamw = OptimizerParams.AdamW() + assert adamw.lr == 1e-3 + assert adamw._name == "adamw" + + def test_params_dict(self): + adamw = OptimizerParams.AdamW(lr=5e-4, eps=1e-7) + p = adamw.params() + assert p["lr"] == 5e-4 + assert p["eps"] == 1e-7 + + +class TestSGDDefaults: + def test_defaults(self): + sgd = OptimizerParams.SGD() + assert sgd.lr == 1e-3 + assert sgd.momentum == 0 + assert sgd.dampening == 0 + assert sgd.nesterov is False + assert sgd._name == "sgd" + + def test_params_dict(self): + sgd = OptimizerParams.SGD(lr=0.1, momentum=0.9, nesterov=True) + p = sgd.params() + assert p == { + "lr": 0.1, + "momentum": 0.9, + "dampening": 0, + "weight_decay": 0, + "nesterov": True, + } + + +class TestNoneOptimizer: + def test_defaults(self): + none_opt = OptimizerParams.NoneOptimizer() + assert none_opt._name == "none" + assert none_opt.params() == {} + + +# ─── OptimizerParams.parse_dict ───────────────────────────────────────────── + + +class TestOptimizerParseDict: + def test_parse_adam(self): + result = 
OptimizerParams.parse_dict({"name": "adam", "lr": 0.01}) + assert isinstance(result, OptimizerParams.Adam) + assert result.lr == 0.01 + + def test_parse_adamw(self): + result = OptimizerParams.parse_dict({"name": "adamw", "weight_decay": 0.1}) + assert isinstance(result, OptimizerParams.AdamW) + assert result.weight_decay == 0.1 + + def test_parse_sgd(self): + result = OptimizerParams.parse_dict({"name": "sgd", "momentum": 0.9}) + assert isinstance(result, OptimizerParams.SGD) + assert result.momentum == 0.9 + + def test_parse_case_insensitive(self): + result = OptimizerParams.parse_dict({"name": "Adam"}) + assert isinstance(result, OptimizerParams.Adam) + + def test_parse_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown optimizer type"): + OptimizerParams.parse_dict({"name": "rmsprop"}) + + def test_parse_does_not_mutate_input(self): + d = {"name": "adam", "lr": 0.01} + original = dict(d) + OptimizerParams.parse_dict(d) + assert d == original + + def test_parse_invalid_name_type_raises(self): + with pytest.raises(ValueError, match="Unknown optimizer type"): + OptimizerParams.parse_dict({"name": 42}) + + +# ─── SchedulerParams defaults ─────────────────────────────────────────────── + + +class TestPlateauDefaults: + def test_defaults(self): + p = SchedulerParams.Plateau() + assert p.mode == "min" + assert p.factor == 0.5 + assert p.patience == 10 + assert p.cooldown == 50 + assert p.min_lr is None + assert p._name == "plateau" + + def test_params_computes_min_lr(self): + p = SchedulerParams.Plateau() + result = p.params(base_LR=0.01) + assert result["min_lr"] == pytest.approx(0.01 / 20) + + def test_params_explicit_min_lr(self): + p = SchedulerParams.Plateau(min_lr=1e-6) + result = p.params(base_LR=0.01) + assert result["min_lr"] == 1e-6 + + +class TestExponentialDefaults: + def test_defaults(self): + e = SchedulerParams.Exponential() + assert e.gamma == 0.9 + assert e._name == "exponential" + + def test_params_with_num_iter(self): + e = SchedulerParams.Exponential(factor=None) + result = e.params(base_LR=0.01, num_iter=100) + assert result == {"gamma": 0.9} + + def test_params_factor_overrides_gamma(self): + e = SchedulerParams.Exponential(factor=0.01) + result = e.params(base_LR=0.01, num_iter=100) + expected_gamma = 0.01 ** (1.0 / 100) + assert result["gamma"] == pytest.approx(expected_gamma) + + def test_params_no_num_iter_raises(self): + e = SchedulerParams.Exponential() + with pytest.raises(ValueError, match="num_iter must be set"): + e.params(base_LR=0.01, num_iter=None) + + def test_params_uses_own_num_iter(self): + e = SchedulerParams.Exponential(num_iter=50, factor=None) + result = e.params(base_LR=0.01, num_iter=None) + assert result == {"gamma": 0.9} + + +class TestCyclicDefaults: + def test_defaults(self): + c = SchedulerParams.Cyclic() + assert c.mode == "triangular2" + assert c.cycle_momentum is False + assert c._name == "cyclic" + + def test_params_computes_lr_bounds(self): + c = SchedulerParams.Cyclic() + result = c.params(base_LR=0.01) + assert result["base_lr"] == pytest.approx(0.01 / 4) + assert result["max_lr"] == pytest.approx(0.01 * 4) + + def test_params_explicit_lr_bounds(self): + c = SchedulerParams.Cyclic(base_lr=0.001, max_lr=0.1) + result = c.params(base_LR=999.0) # should be ignored + assert result["base_lr"] == 0.001 + assert result["max_lr"] == 0.1 + + +class TestLinearDefaults: + def test_defaults(self): + test = SchedulerParams.Linear() + assert test.start_factor == 0.1 + assert test.end_factor == 1.0 + assert test._name == "linear" + 
+ def test_params_uses_num_iter(self): + test = SchedulerParams.Linear() + result = test.params(base_LR=0.01, num_iter=200) + assert result["total_iters"] == 200 + + def test_params_explicit_total_iters(self): + test = SchedulerParams.Linear(total_iters=50) + result = test.params(base_LR=0.01, num_iter=200) + assert result["total_iters"] == 50 + + def test_params_no_iters_raises(self): + test = SchedulerParams.Linear() + with pytest.raises(ValueError, match="total_iters must be set"): + test.params(base_LR=0.01, num_iter=None) + + +class TestCosineAnnealingDefaults: + def test_defaults(self): + ca = SchedulerParams.CosineAnnealing() + assert ca.eta_min == 1e-7 + assert ca.T_max is None + assert ca._name == "cosine_annealing" + + def test_params_uses_num_iter(self): + ca = SchedulerParams.CosineAnnealing() + result = ca.params(base_LR=0.01, num_iter=300) + assert result["T_max"] == 300 + + def test_params_explicit_T_max(self): + ca = SchedulerParams.CosineAnnealing(T_max=150) + result = ca.params(base_LR=0.01, num_iter=300) + assert result["T_max"] == 150 + + def test_params_no_T_max_raises(self): + ca = SchedulerParams.CosineAnnealing() + with pytest.raises(ValueError, match="T_max must be set"): + ca.params(base_LR=0.01, num_iter=None) + + +class TestNoneScheduler: + def test_defaults(self): + ns = SchedulerParams.NoneScheduler() + assert ns._name == "none" + assert ns.params(base_LR=0.01) == {} + + +# ─── SchedulerParams.parse_dict ───────────────────────────────────────────── + + +class TestSchedulerParseDict: + def test_parse_plateau(self): + result = SchedulerParams.parse_dict({"name": "plateau", "patience": 20}) + assert isinstance(result, SchedulerParams.Plateau) + assert result.patience == 20 + + def test_parse_exponential(self): + result = SchedulerParams.parse_dict({"name": "exponential", "gamma": 0.95}) + assert isinstance(result, SchedulerParams.Exponential) + assert result.gamma == 0.95 + + def test_parse_cyclic(self): + result = SchedulerParams.parse_dict({"name": "cyclic", "step_size_up": 50}) + assert isinstance(result, SchedulerParams.Cyclic) + assert result.step_size_up == 50 + + def test_parse_linear(self): + result = SchedulerParams.parse_dict({"name": "linear", "start_factor": 0.5}) + assert isinstance(result, SchedulerParams.Linear) + assert result.start_factor == 0.5 + + def test_parse_cosine_annealing(self): + result = SchedulerParams.parse_dict({"name": "cosine_annealing", "T_max": 100}) + assert isinstance(result, SchedulerParams.CosineAnnealing) + assert result.T_max == 100 + + def test_parse_none(self): + result = SchedulerParams.parse_dict({"name": "none"}) + assert isinstance(result, SchedulerParams.NoneScheduler) + + def test_parse_case_insensitive(self): + result = SchedulerParams.parse_dict({"name": "Plateau"}) + assert isinstance(result, SchedulerParams.Plateau) + + def test_parse_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown scheduler type"): + SchedulerParams.parse_dict({"name": "warmup"}) + + def test_parse_does_not_mutate_input(self): + d = {"name": "plateau", "patience": 5} + original = dict(d) + SchedulerParams.parse_dict(d) + assert d == original + + def test_parse_invalid_name_type_raises(self): + with pytest.raises(ValueError, match="Unknown scheduler type"): + SchedulerParams.parse_dict({"name": 3.14}) + + def test_parse_default_name_is_none(self): + result = SchedulerParams.parse_dict({}) + assert isinstance(result, SchedulerParams.NoneScheduler)
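
Usage sketch (illustrative, not part of the diff): how the new OptimizerParams / SchedulerParams
dataclasses plug into an OptimizerMixin subclass. TinyModel, its forward pass, and the training
loop are hypothetical stand-ins; only set_optimizer, set_scheduler, step_optimizer, and the
optimizer / scheduler properties come from the mixin changes above.

import torch

from quantem.core.ml import OptimizerParams, SchedulerParams
from quantem.core.ml.optimizer_mixin import OptimizerMixin


class TinyModel(torch.nn.Module, OptimizerMixin):
    """Hypothetical model: any class that provides get_optimization_parameters works."""

    def __init__(self):
        torch.nn.Module.__init__(self)
        OptimizerMixin.__init__(self)
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

    def get_optimization_parameters(self):
        # The mixin asks subclasses which tensors to optimize.
        return list(self.parameters())


model = TinyModel()

# Dataclass and dict forms are interchangeable (the dict goes through parse_dict).
model.set_optimizer(OptimizerParams.AdamW(lr=1e-4, weight_decay=1e-2))
model.set_scheduler(SchedulerParams.CosineAnnealing(), num_iter=100)

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    model.optimizer.zero_grad()
    loss.backward()
    model.step_optimizer()
    model.scheduler.step()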