From d96808ba4994b434a51fbe16e893d3be35ee48dd Mon Sep 17 00:00:00 2001 From: Cedric Lim Date: Wed, 4 Mar 2026 20:00:42 -0800 Subject: [PATCH 1/3] Hot-fix for OptimizerMixin bug with or statement --- src/quantem/core/ml/optimizer_mixin.py | 618 +++++++++++++++++++++---- 1 file changed, 531 insertions(+), 87 deletions(-) diff --git a/src/quantem/core/ml/optimizer_mixin.py b/src/quantem/core/ml/optimizer_mixin.py index e2d2e89a..94c12518 100644 --- a/src/quantem/core/ml/optimizer_mixin.py +++ b/src/quantem/core/ml/optimizer_mixin.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Generator, Iterator, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generator, Iterator, Literal, Sequence from quantem.core import config @@ -10,6 +11,474 @@ import torch +class OptimizerParams: + """ + Container for optimizer parameter dataclasses. + + Each nested class configures a specific PyTorch optimizer and can be passed + directly to ``OptimizerMixin.set_optimizer``, or constructed from a dict via + ``OptimizerParams.parse_dict``. + + Supported optimizers + -------------------- + Adam + ``torch.optim.Adam`` — adaptive moment estimation. + AdamW + ``torch.optim.AdamW`` — Adam with decoupled weight decay. + SGD + ``torch.optim.SGD`` — stochastic gradient descent with optional momentum and Nesterov. + NoneOptimizer + Sentinel that disables / removes the optimizer. + + Examples + -------- + >>> obj.set_optimizer(OptimizerParams.Adam(lr=1e-4)) + >>> obj.set_optimizer({"name": "adam", "lr": 1e-4}) # equivalent dict form + """ + + @dataclass + class Adam: + """ + Adam optimizer (``torch.optim.Adam``). + + Parameters + ---------- + lr : float + Learning rate. Default: 1e-3. + betas : tuple[float, float] + Coefficients for computing running averages of the gradient and its square. + Default: (0.9, 0.999). + eps : float + Term added to the denominator for numerical stability. Default: 1e-8. + weight_decay : float + L2 regularisation penalty. Default: 0. + """ + + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + _name: str = "adam" + + def params(self) -> dict: + return { + "lr": self.lr, + "betas": self.betas, + "eps": self.eps, + "weight_decay": self.weight_decay, + } + + @dataclass + class AdamW: + """ + AdamW optimizer (``torch.optim.AdamW``). + + Identical to Adam but applies weight decay directly to the parameters + rather than folding it into the gradient update (decoupled weight decay). + + Parameters + ---------- + lr : float + Learning rate. Default: 1e-3. + betas : tuple[float, float] + Coefficients for computing running averages of the gradient and its square. + Default: (0.9, 0.999). + eps : float + Term added to the denominator for numerical stability. Default: 1e-8. + weight_decay : float + Decoupled L2 regularisation penalty. Default: 0. + """ + + lr: float = 1e-3 + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + _name: str = "adamw" + + def params(self) -> dict: + return { + "lr": self.lr, + "betas": self.betas, + "eps": self.eps, + "weight_decay": self.weight_decay, + } + + @dataclass + class SGD: + """ + SGD optimizer (``torch.optim.SGD``). + + Parameters + ---------- + lr : float + Learning rate. Default: 1e-3. + momentum : float + Momentum factor. Default: 0. + dampening : float + Dampening for momentum. Default: 0. + weight_decay : float + L2 regularisation penalty. Default: 0. + nesterov : bool + Enables Nesterov momentum. Default: False. + """ + + lr: float = 1e-3 + momentum: float = 0 + dampening: float = 0 + weight_decay: float = 0 + nesterov: bool = False + _name: str = "sgd" + + def params(self) -> dict: + return { + "lr": self.lr, + "momentum": self.momentum, + "dampening": self.dampening, + "weight_decay": self.weight_decay, + "nesterov": self.nesterov, + } + + @dataclass + class NoneOptimizer: + """ + Sentinel optimizer that disables optimization. + + Passing this to ``set_optimizer`` will call ``remove_optimizer``, + clearing both the optimizer and scheduler. + """ + + _name: str = "none" + + def params(self) -> dict: + return {} + + @classmethod + def parse_dict(cls, d: dict): + """ + Parse dictionary to a optimizer params object. + Accepts either ``"name"`` or ``"type"`` as the optimizer key. + """ + d = dict(d) # avoid mutating caller's dict + name = d.pop("name", None) or d.pop("type", "none") + if isinstance(name, type): + name = name.__name__.lower() + elif isinstance(name, str): + name = name.lower() + else: + raise ValueError(f"Unknown optimizer type: {name}") + if name == "adam": + return OptimizerParams.Adam(**d) + elif name == "adamw": + return OptimizerParams.AdamW(**d) + elif name == "sgd": + return OptimizerParams.SGD(**d) + elif name == "none": + return OptimizerParams.NoneOptimizer() + else: + raise ValueError(f"Unknown optimizer type: {name.lower()}") + + +OptimizerType = ( + OptimizerParams.Adam + | OptimizerParams.AdamW + | OptimizerParams.SGD + | OptimizerParams.NoneOptimizer +) + + +class SchedulerParams: + """ + Container for learning-rate scheduler parameter dataclasses. + + Each nested class configures a specific PyTorch LR scheduler and can be passed + directly to ``OptimizerMixin.set_scheduler``, or constructed from a dict via + ``SchedulerParams.parse_dict``. + + Supported schedulers + -------------------- + Plateau + ``torch.optim.lr_scheduler.ReduceLROnPlateau`` — reduce LR when a metric stops improving. + Exponential + ``torch.optim.lr_scheduler.ExponentialLR`` — multiply LR by ``gamma`` each step. + Cyclic + ``torch.optim.lr_scheduler.CyclicLR`` — cycle LR between ``base_lr`` and ``max_lr``. + Linear + ``torch.optim.lr_scheduler.LinearLR`` — linearly interpolate LR over ``total_iters`` steps. + CosineAnnealing + ``torch.optim.lr_scheduler.CosineAnnealingLR`` — cosine-annealing LR schedule. + NoneScheduler + Sentinel that disables / removes the scheduler. + + Examples + -------- + >>> obj.set_scheduler(SchedulerParams.Plateau(factor=0.5, patience=10, cooldown=20)) + >>> obj.set_scheduler({"name": "plateau", "factor": 0.5}) # equivalent dict form + """ + + @dataclass + class Plateau: + """ + ReduceLROnPlateau scheduler (``torch.optim.lr_scheduler.ReduceLROnPlateau``). + + Reduces the learning rate when a monitored metric stops improving. + + Parameters + ---------- + mode : {'min', 'max'} + Whether the monitored metric should be minimised or maximised. Default: 'min'. + min_lr_factor : float + Sets ``min_lr = min_lr_factor * base_lr`` when ``min_lr`` is not provided. + Default: 1/20. + min_lr : float or None + Absolute lower bound on the learning rate. Overrides ``min_lr_factor`` when set. + Default: None. + factor : float + Factor by which the LR is reduced: ``new_lr = lr * factor``. Default: 0.5. + patience : int + Number of epochs with no improvement before reducing LR. Default: 10. + threshold : float + Minimum change in the monitored metric to qualify as an improvement. Default: 1e-5. + cooldown : int + Number of epochs to wait after a LR reduction before resuming normal operation. + Default: 50. + """ + + mode: Literal["min", "max"] = "min" + min_lr_factor: float = 1 / 20 + min_lr: float | None = None + factor: float = 0.5 + patience: int = 10 + threshold: float = 1e-5 + cooldown: int = 50 + _name: str = "plateau" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if self.min_lr is None: + self.min_lr = self.min_lr_factor * base_LR + return { + "mode": self.mode, + "factor": self.factor, + "patience": self.patience, + "threshold": self.threshold, + "min_lr": self.min_lr, + "cooldown": self.cooldown, + } + + @dataclass + class Exponential: + """ + ExponentialLR scheduler (``torch.optim.lr_scheduler.ExponentialLR``). + + Multiplies the learning rate by ``gamma`` after each step. + + Parameters + ---------- + gamma : float + Multiplicative decay factor applied each step. Default: 0.9. + factor : float or None + Reserved for future use. Default: 0.5. + num_iter : int or None + Reserved for future use. Default: None. + """ + + gamma: float = 0.9 + factor: float | None = None + num_iter: int | None = None + _name: str = "exponential" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + effective_num_iter = self.num_iter if self.num_iter is not None else num_iter + if effective_num_iter is None: + raise ValueError("num_iter must be set if num_iter is not provided") + + self.num_iter = effective_num_iter + + if self.factor is not None: + self.gamma = self.factor ** (1.0 / effective_num_iter) + + return { + "gamma": self.gamma, + } + + @dataclass + class Cyclic: + """ + CyclicLR scheduler (``torch.optim.lr_scheduler.CyclicLR``). + + Cycles the learning rate between a lower bound (``base_lr``) and an upper + bound (``max_lr``). Bounds can be set directly or derived from the optimizer's + base LR via the factor parameters. + + Parameters + ---------- + base_lr_factor : float + Sets ``base_lr = base_lr_factor * optimizer_lr`` when ``base_lr`` is not provided. + Default: 1/4. + max_lr_factor : float + Sets ``max_lr = max_lr_factor * optimizer_lr`` when ``max_lr`` is not provided. + Default: 4. + base_lr : float or None + Absolute lower bound of the LR cycle. Overrides ``base_lr_factor`` when set. + Default: None. + max_lr : float or None + Absolute upper bound of the LR cycle. Overrides ``max_lr_factor`` when set. + Default: None. + step_size_up : int + Number of steps in the increasing half of each cycle. Default: 100. + step_size_down : int + Number of steps in the decreasing half of each cycle. Default: 100. + mode : {'triangular2', 'triangular', 'exp_range'} + Cycling policy. Default: 'triangular2'. + cycle_momentum : bool + If True, cycles momentum inversely to the learning rate. Default: False. + """ + + base_lr_factor: float = 1 / 4 + max_lr_factor: float = 4 + base_lr: float | None = None + max_lr: float | None = None + step_size_up: int = 100 + step_size_down: int = 100 + mode: Literal["triangular2", "triangular", "exp_range"] = "triangular2" + cycle_momentum: bool = False + _name: str = "cyclic" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if self.base_lr is None: + self.base_lr = self.base_lr_factor * base_LR + if self.max_lr is None: + self.max_lr = self.max_lr_factor * base_LR + return { + "base_lr": self.base_lr, + "max_lr": self.max_lr, + "step_size_up": self.step_size_up, + "step_size_down": self.step_size_down, + "mode": self.mode, + "cycle_momentum": self.cycle_momentum, + } + + @dataclass + class Linear: + """ + LinearLR scheduler (``torch.optim.lr_scheduler.LinearLR``). + + Linearly interpolates the learning rate from ``start_factor * base_lr`` to + ``end_factor * base_lr`` over ``total_iters`` steps. + + Parameters + ---------- + total_iters : int + Number of steps over which to interpolate the LR. Required. + start_factor : float + Multiplicative factor applied to the LR at the first step. Default: 0.1. + end_factor : float + Multiplicative factor applied to the LR at the final step. Default: 1.0. + """ + + total_iters: int | None = None + start_factor: float = 0.1 + end_factor: float = 1.0 + _name: str = "linear" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if num_iter is None and self.total_iters is None: + raise ValueError( + "total_iters must be set if num_iter is not provided" + ) # Should never be reached + if self.total_iters is None: + self.total_iters = num_iter + return { + "start_factor": self.start_factor, + "end_factor": self.end_factor, + "total_iters": self.total_iters, + } + + @dataclass + class CosineAnnealing: + """ + CosineAnnealingLR scheduler (``torch.optim.lr_scheduler.CosineAnnealingLR``). + + Anneals the learning rate following a cosine curve from the base LR down to + ``eta_min`` over ``T_max`` steps, then restarts. + + Parameters + ---------- + T_max : int + Maximum number of iterations (half-period of the cosine cycle). Required. + eta_min : float + Minimum learning rate at the trough of the cosine curve. Default: 1e-7. + """ + + eta_min: float = 1e-7 + T_max: int | None = None + _name: str = "cosine_annealing" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + if num_iter is None and self.T_max is None: + raise ValueError( + "T_max must be set if num_iter is not provided" + ) # Should never be reached + if self.T_max is None: + self.T_max = num_iter + return { + "T_max": self.T_max, + "eta_min": self.eta_min, + } + + @dataclass + class NoneScheduler: + """ + Sentinel scheduler that disables LR scheduling. + + Passing this to ``set_scheduler`` clears the active scheduler without + affecting the optimizer. + """ + + _name: str = "none" + + def params(self, base_LR: float, num_iter: int | None = None) -> dict: + return {} + + @classmethod + def parse_dict(cls, d: dict): + """ + Parse dictionary to a scheduler params object. + Accepts either ``"name"`` or ``"type"`` as the scheduler key. + """ + d = dict(d) # avoid mutating caller's dict + name = d.pop("name", None) or d.pop("type", "none") + if isinstance(name, type): + name = name.__name__.lower() + elif isinstance(name, str): + name = name.lower() + else: + raise ValueError(f"Unknown scheduler type: {name}") + if name == "plateau": + return SchedulerParams.Plateau(**d) + elif name == "exponential": + return SchedulerParams.Exponential(**d) + elif name == "cyclic": + return SchedulerParams.Cyclic(**d) + elif name == "linear": + return SchedulerParams.Linear(**d) + elif name == "cosine_annealing": + return SchedulerParams.CosineAnnealing(**d) + elif name == "none": + return SchedulerParams.NoneScheduler() + else: + raise ValueError(f"Unknown scheduler type: {name}") + + +SchedulerType = ( + SchedulerParams.Plateau + | SchedulerParams.Exponential + | SchedulerParams.Cyclic + | SchedulerParams.Linear + | SchedulerParams.CosineAnnealing + | SchedulerParams.NoneScheduler +) + + class OptimizerMixin: """ Mixin class for handling optimizer and scheduler management. @@ -22,8 +491,8 @@ def __init__(self): """Initialize the optimizer mixin.""" self._optimizer = None self._scheduler = None - self._optimizer_params = {} - self._scheduler_params = {} + self._optimizer_params: OptimizerType = OptimizerParams.NoneOptimizer() + self._scheduler_params: SchedulerType = SchedulerParams.NoneScheduler() # Don't call super().__init__() in mixin classes to avoid MRO issues @property @@ -37,38 +506,32 @@ def scheduler(self) -> "torch.optim.lr_scheduler.LRScheduler | None": return self._scheduler @property - def optimizer_params(self) -> dict: + def optimizer_params(self) -> OptimizerType: """Get the optimizer parameters.""" return self._optimizer_params @optimizer_params.setter - def optimizer_params(self, params: dict): + def optimizer_params(self, params: OptimizerType | dict): """Set the optimizer parameters.""" - self._optimizer_params = params.copy() if params else {} + if isinstance(params, dict): + params = OptimizerParams.parse_dict(d=params) + if not isinstance(params, OptimizerType): + raise TypeError(f"optimizer parameters must be a OptimizerType, got {type(params)}") + self._optimizer_params = params @property - def scheduler_params(self) -> dict: + def scheduler_params(self) -> SchedulerType: """Get the scheduler parameters.""" return self._scheduler_params @scheduler_params.setter - def scheduler_params(self, params: dict): + def scheduler_params(self, params: SchedulerType | dict): """Set the scheduler parameters.""" - if params: - if params["type"].lower() not in [ - "cyclic", - "plateau", - "exp", - "gamma", - "linear", - "none", - ]: - raise ValueError( - f"Unknown scheduler type: {params['type']}, expected one of ['cyclic', 'plateau', 'exp', 'gamma', 'none']" - ) - self._scheduler_params = params.copy() - else: - self._scheduler_params = {} + if isinstance(params, dict): + params = SchedulerParams.parse_dict(d=params) + if not isinstance(params, SchedulerType): + raise TypeError(f"scheduler parameters must be a SchedulerType, got {type(params)}") + self._scheduler_params = params @abstractmethod def get_optimization_parameters( @@ -81,7 +544,7 @@ def get_optimization_parameters( """ raise NotImplementedError("Subclasses must implement get_optimization_parameters") - def set_optimizer(self, opt_params: dict | None = None) -> None: + def set_optimizer(self, opt_params: OptimizerType | dict | None = None) -> None: """ Set the optimizer for this model. Currently supports single LR for all parameters, TODO allow for per parameter LRs by @@ -94,10 +557,7 @@ def set_optimizer(self, opt_params: dict | None = None) -> None: self._optimizer = None return - opt_params = self._optimizer_params.copy() - opt_type = opt_params.pop("type", self.DEFAULT_OPTIMIZER_TYPE) - - if opt_type == "none": + if isinstance(self._optimizer_params, OptimizerParams.NoneOptimizer): self.remove_optimizer() return @@ -111,22 +571,18 @@ def set_optimizer(self, opt_params: dict | None = None) -> None: for p in params: p.requires_grad_(True) - if isinstance(opt_type, type): - self._optimizer = opt_type(params, **opt_params) - elif isinstance(opt_type, str): - if opt_type.lower() == "adam": - self._optimizer = torch.optim.Adam(params, **opt_params) - elif opt_type.lower() == "adamw": - self._optimizer = torch.optim.AdamW(params, **opt_params) - elif opt_type.lower() == "sgd": - self._optimizer = torch.optim.SGD(params, **opt_params) - else: - raise NotImplementedError(f"Unknown optimizer type: {opt_type}") - else: - raise TypeError(f"optimizer type must be string or type, got {type(opt_type)}") + match self._optimizer_params: + case OptimizerParams.Adam(): + self._optimizer = torch.optim.Adam(params, **self._optimizer_params.params()) + case OptimizerParams.AdamW(): + self._optimizer = torch.optim.AdamW(params, **self._optimizer_params.params()) + case OptimizerParams.SGD(): + self._optimizer = torch.optim.SGD(params, **self._optimizer_params.params()) + case _: + raise NotImplementedError(f"Unknown optimizer type: {self._optimizer_params}") def set_scheduler( - self, scheduler_params: dict | None = None, num_iter: int | None = None + self, scheduler_params: SchedulerType | dict | None = None, num_iter: int | None = None ) -> None: """Set the scheduler for this model.""" if scheduler_params is not None: @@ -136,51 +592,39 @@ def set_scheduler( self._scheduler = None return - params = self._scheduler_params - sched_type = params.get("type", "none").lower() optimizer = self._optimizer base_LR = optimizer.param_groups[0]["lr"] - - if sched_type == "none": - self._scheduler = None - elif sched_type == "cyclic": - self._scheduler = torch.optim.lr_scheduler.CyclicLR( - optimizer, - base_lr=params.get("base_lr", base_LR / 4), - max_lr=params.get("max_lr", base_LR * 4), - step_size_up=params.get("step_size_up", 100), - step_size_down=params.get("step_size_down", params.get("step_size_up", 100)), - mode=params.get("mode", "triangular2"), - cycle_momentum=params.get("momentum", False), - ) - elif sched_type.startswith(("plat", "reducelronplat")): - self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="min", - factor=params.get("factor", 0.5), - patience=params.get("patience", 10), - threshold=params.get("threshold", 1e-3), - min_lr=params.get("min_lr", base_LR / 20), - cooldown=params.get("cooldown", 50), - ) - elif sched_type in ["exp", "gamma", "exponential"]: - if "gamma" in params: - gamma = params["gamma"] - elif num_iter is not None: - fac = params.get("factor", 0.01) - gamma = fac ** (1.0 / num_iter) - else: - gamma = 0.9 - self._scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) - elif sched_type == "linear": - self._scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, - start_factor=params.get("start_factor", 0.1), - end_factor=params.get("end_factor", 1.0), - total_iters=params.get("total_iters", num_iter), - ) - else: - raise ValueError(f"Unknown scheduler type: {sched_type}") + params = self._scheduler_params.params(base_LR, num_iter=num_iter) + match self.scheduler_params: + case SchedulerParams.NoneScheduler(): + self._scheduler = None + case SchedulerParams.Cyclic(): + self._scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + **params, + ) + case SchedulerParams.Plateau(): + self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + **params, + ) + case SchedulerParams.Exponential(): + self._scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, + **params, + ) + case SchedulerParams.Linear(): + self._scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + **params, + ) + case SchedulerParams.CosineAnnealing(): + self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + **params, + ) + case _: + raise ValueError(f"Unknown scheduler type: {self.scheduler_params}") def step_optimizer(self) -> None: """Step the optimizer if it exists.""" @@ -214,9 +658,9 @@ def get_current_lr(self) -> float: def remove_optimizer(self) -> None: """Remove the optimizer and scheduler.""" self._optimizer = None - self._optimizer_params = {} + self._optimizer_params = OptimizerParams.NoneOptimizer() self._scheduler = None - self._scheduler_params = {} + self._scheduler_params = SchedulerParams.NoneScheduler() def reset_optimizer(self) -> None: """Reset the optimizer and scheduler.""" From 895cb1b50c9f3feb8b21ff8c64fd765ed0af7430 Mon Sep 17 00:00:00 2001 From: Cedric Lim Date: Wed, 4 Mar 2026 20:03:34 -0800 Subject: [PATCH 2/3] Hotfix from tomo_refactor --- src/quantem/core/ml/optimizer_mixin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/quantem/core/ml/optimizer_mixin.py b/src/quantem/core/ml/optimizer_mixin.py index e9dc9599..94c12518 100644 --- a/src/quantem/core/ml/optimizer_mixin.py +++ b/src/quantem/core/ml/optimizer_mixin.py @@ -159,7 +159,7 @@ def parse_dict(cls, d: dict): Accepts either ``"name"`` or ``"type"`` as the optimizer key. """ d = dict(d) # avoid mutating caller's dict - name = d.pop("name", "none") or d.pop("type", "none") + name = d.pop("name", None) or d.pop("type", "none") if isinstance(name, type): name = name.__name__.lower() elif isinstance(name, str): @@ -172,6 +172,8 @@ def parse_dict(cls, d: dict): return OptimizerParams.AdamW(**d) elif name == "sgd": return OptimizerParams.SGD(**d) + elif name == "none": + return OptimizerParams.NoneOptimizer() else: raise ValueError(f"Unknown optimizer type: {name.lower()}") @@ -444,7 +446,7 @@ def parse_dict(cls, d: dict): Accepts either ``"name"`` or ``"type"`` as the scheduler key. """ d = dict(d) # avoid mutating caller's dict - name = d.pop("name", "none") or d.pop("type", "none") + name = d.pop("name", None) or d.pop("type", "none") if isinstance(name, type): name = name.__name__.lower() elif isinstance(name, str): From 8f5d3cb66669c13b79f0c81e556065ba1e3bceda Mon Sep 17 00:00:00 2001 From: Cedric Lim Date: Wed, 4 Mar 2026 20:13:00 -0800 Subject: [PATCH 3/3] Hotfix w/ better name or type checking --- src/quantem/core/ml/optimizer_mixin.py | 12 +++++- tests/ml/test_optimizermixin.py | 59 +++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/src/quantem/core/ml/optimizer_mixin.py b/src/quantem/core/ml/optimizer_mixin.py index 94c12518..d30229da 100644 --- a/src/quantem/core/ml/optimizer_mixin.py +++ b/src/quantem/core/ml/optimizer_mixin.py @@ -159,7 +159,11 @@ def parse_dict(cls, d: dict): Accepts either ``"name"`` or ``"type"`` as the optimizer key. """ d = dict(d) # avoid mutating caller's dict - name = d.pop("name", None) or d.pop("type", "none") + name = d.pop("name", None) + type_ = d.pop("type", None) + name = name or type_ + if name is None: + raise ValueError("Must provide either 'name' or 'type' key") if isinstance(name, type): name = name.__name__.lower() elif isinstance(name, str): @@ -446,7 +450,11 @@ def parse_dict(cls, d: dict): Accepts either ``"name"`` or ``"type"`` as the scheduler key. """ d = dict(d) # avoid mutating caller's dict - name = d.pop("name", None) or d.pop("type", "none") + name = d.pop("name", None) + type_ = d.pop("type", None) + name = name or type_ + if name is None: + raise ValueError("Must provide either 'name' or 'type' key") if isinstance(name, type): name = name.__name__.lower() elif isinstance(name, str): diff --git a/tests/ml/test_optimizermixin.py b/tests/ml/test_optimizermixin.py index 6fd6c22a..3682d452 100644 --- a/tests/ml/test_optimizermixin.py +++ b/tests/ml/test_optimizermixin.py @@ -114,6 +114,61 @@ def test_parse_invalid_name_type_raises(self): OptimizerParams.parse_dict({"name": 42}) +# ─── parse_dict "name" vs "type" key handling ─────────────────────────────── + + +class TestOptimizerParseDictKeyHandling: + def test_parse_with_type_key(self): + result = OptimizerParams.parse_dict({"type": "adam", "lr": 0.01}) + assert isinstance(result, OptimizerParams.Adam) + assert result.lr == 0.01 + + def test_name_takes_precedence_over_type(self): + result = OptimizerParams.parse_dict({"name": "adam", "type": "sgd"}) + assert isinstance(result, OptimizerParams.Adam) + + def test_neither_name_nor_type_raises(self): + with pytest.raises(ValueError, match="Must provide either"): + OptimizerParams.parse_dict({"lr": 0.01}) + + def test_type_key_not_leaked_into_constructor(self): + """'type' should be popped from d so it doesn't become an unexpected kwarg.""" + result = OptimizerParams.parse_dict({"type": "sgd", "momentum": 0.9}) + assert isinstance(result, OptimizerParams.SGD) + assert result.momentum == 0.9 + + def test_both_keys_popped_when_name_used(self): + """Even when 'name' is used, 'type' should be popped so it doesn't leak.""" + result = OptimizerParams.parse_dict({"name": "adam", "type": "ignored", "lr": 0.05}) + assert isinstance(result, OptimizerParams.Adam) + assert result.lr == 0.05 + + +class TestSchedulerParseDictKeyHandling: + def test_parse_with_type_key(self): + result = SchedulerParams.parse_dict({"type": "plateau", "patience": 20}) + assert isinstance(result, SchedulerParams.Plateau) + assert result.patience == 20 + + def test_name_takes_precedence_over_type(self): + result = SchedulerParams.parse_dict({"name": "plateau", "type": "linear"}) + assert isinstance(result, SchedulerParams.Plateau) + + def test_neither_name_nor_type_raises(self): + with pytest.raises(ValueError, match="Must provide either"): + SchedulerParams.parse_dict({"patience": 20}) + + def test_type_key_not_leaked_into_constructor(self): + result = SchedulerParams.parse_dict({"type": "cyclic", "step_size_up": 50}) + assert isinstance(result, SchedulerParams.Cyclic) + assert result.step_size_up == 50 + + def test_both_keys_popped_when_name_used(self): + result = SchedulerParams.parse_dict({"name": "plateau", "type": "ignored", "patience": 5}) + assert isinstance(result, SchedulerParams.Plateau) + assert result.patience == 5 + + # ─── SchedulerParams defaults ─────────────────────────────────────────────── @@ -291,5 +346,5 @@ def test_parse_invalid_name_type_raises(self): SchedulerParams.parse_dict({"name": 3.14}) def test_parse_default_name_is_none(self): - result = SchedulerParams.parse_dict({}) - assert isinstance(result, SchedulerParams.NoneScheduler) + with pytest.raises(ValueError, match="Must provide either"): + SchedulerParams.parse_dict({})