# Source code for the boost_loss.resuming module.

from __future__ import annotations

from typing import Sequence

import numpy as np
from numpy.typing import NDArray

from .base import LossBase


class ResumingLoss(LossBase):
    """Loss that switches between several sub-losses during boosting.

    A sub-loss is selected on the first :meth:`grad_hess` call and then
    again every ``interval`` calls:

    - if ``random_state`` is ``None``, sub-losses are visited in order,
      cycling back to the first after the last (round-robin);
    - otherwise a sub-loss is drawn at random, with probability
      proportional to ``weights``, from a ``np.random.RandomState`` seeded
      with ``random_state``.

    The selected sub-loss is used by both :meth:`grad_hess` and
    :meth:`loss` until the next switch.

    Parameters
    ----------
    losses : Sequence[LossBase]
        Sub-losses to switch between. Must be non-empty.
    weights : Sequence[float] | None, optional
        Relative (unnormalized) sampling weights, one per loss. Only
        allowed together with ``random_state``; defaults to uniform weights.
    interval : int, optional
        Number of consecutive :meth:`grad_hess` calls that reuse the same
        sub-loss. Must be >= 1. Defaults to 1.
    random_state : int | None, optional
        Seed for random selection. ``None`` selects the deterministic
        round-robin schedule.

    Raises
    ------
    ValueError
        If ``losses`` is empty, ``interval`` < 1, ``weights`` is given
        without ``random_state``, or ``weights`` does not have exactly one
        finite, non-negative entry per loss with a positive sum.
    """

    def __init__(
        self,
        losses: Sequence[LossBase],
        *,
        weights: Sequence[float] | None = None,
        interval: int = 1,
        random_state: int | None = None,
    ) -> None:
        if len(losses) == 0:
            raise ValueError("losses must not be empty")
        if interval < 1:
            raise ValueError(f"interval must be >= 1, got {interval}")
        if random_state is None and weights is not None:
            raise ValueError("weights must be None when random_state is None")

        self.losses = losses

        # Relative weights as supplied (uniform by default). np.ones(len(...))
        # is used instead of np.ones_like(losses), which would build an
        # object-dtype array from the loss instances themselves.
        if weights is None:
            self.weights = np.ones(len(losses))
        else:
            self.weights = np.asarray(weights, dtype=float)
        if self.weights.shape != (len(losses),):
            raise ValueError(
                f"weights must have exactly one entry per loss "
                f"({len(losses)}), got shape {self.weights.shape}"
            )
        if (
            not np.all(np.isfinite(self.weights))
            or np.any(self.weights < 0)
            or self.weights.sum() <= 0
        ):
            raise ValueError(
                "weights must be finite and non-negative with a positive sum"
            )
        # RandomState.choice requires probabilities that sum to exactly 1.
        self._probabilities = self.weights / self.weights.sum()

        self.interval = interval
        self.random_state = random_state
        # RNG used only in random-selection mode; None means round-robin.
        self.random = (
            None if random_state is None else np.random.RandomState(random_state)
        )
        self._count = 0  # number of grad_hess calls made so far
        self._idx = 0  # index into self.losses of the active sub-loss

    def grad_hess(self, y_true: NDArray, y_pred: NDArray) -> tuple[NDArray, NDArray]:
        """Return the gradient and hessian of the currently active sub-loss.

        On every ``interval``-th call (starting with the first) a new
        sub-loss is selected before delegating. The call counter advances
        by one on every call.

        Parameters
        ----------
        y_true : NDArray
            Targets, passed through unchanged to the sub-loss.
        y_pred : NDArray
            Predictions, passed through unchanged to the sub-loss.

        Returns
        -------
        tuple[NDArray, NDArray]
            Whatever the selected sub-loss's ``grad_hess`` returns.
        """
        if self._count % self.interval == 0:
            if self.random is None:
                # Round-robin: the k-th window of `interval` calls uses
                # loss number k modulo the number of losses.
                self._idx = (self._count // self.interval) % len(self.losses)
            else:
                self._idx = int(
                    self.random.choice(len(self.losses), p=self._probabilities)
                )
        self._count += 1
        return self.losses[self._idx].grad_hess(y_true=y_true, y_pred=y_pred)

    def loss(self, y_true: NDArray, y_pred: NDArray) -> NDArray | float:
        """Evaluate the sub-loss selected by the most recent grad_hess call.

        Before any :meth:`grad_hess` call this is the first loss (index 0).

        Parameters
        ----------
        y_true : NDArray
            Targets, passed through unchanged to the sub-loss.
        y_pred : NDArray
            Predictions, passed through unchanged to the sub-loss.

        Returns
        -------
        NDArray | float
            Whatever the selected sub-loss's ``loss`` returns.
        """
        return self.losses[self._idx].loss(y_true=y_true, y_pred=y_pred)