Source code for ray.tune.stopper.trial_plateau
from collections import defaultdict, deque
from typing import Dict, Optional
import numpy as np
from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI
[docs]
@PublicAPI
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric: Metric to check for convergence.
        std: Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results: Number of most recent results to consider for the
            stdev calculation.
        grace_period: Minimum number of results a trial must have
            reported before it can be early stopped.
        metric_threshold: Optional value the latest result has to reach
            before the trial can be stopped early. A threshold of ``0``
            is a valid threshold.
        mode: If a `metric_threshold` argument has been passed, this
            must be one of [min, max]. Specifies if we optimize for a
            large metric (max) or a small metric (min). If max, the
            latest result has to be at or above `metric_threshold`, if
            min it has to be at or below `metric_threshold` in order to
            early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not
            one of [min, max].
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.01,
        num_results: int = 4,
        grace_period: int = 4,
        metric_threshold: Optional[float] = None,
        mode: Optional[str] = None,
    ):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None rather than relying on truthiness: a
        # threshold of 0 / 0.0 is legitimate and must still require a
        # valid `mode`, otherwise `__call__` would silently ignore it.
        if self._metric_threshold is not None:
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"When specifying a `metric_threshold`, the `mode` "
                    f"argument has to be one of [min, max]. "
                    f"Got: {mode}"
                )

        # Per-trial count of results seen so far.
        self._iter = defaultdict(int)
        # Per-trial sliding window holding the last `num_results` values.
        self._trial_results = defaultdict(lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record the latest result of a trial and decide whether to stop it.

        Args:
            trial_id: Identifier of the trial that reported `result`.
            result: Result dict; the value stored under `metric` is used.

        Returns:
            True if the trial should be stopped because the standard
            deviation of its last `num_results` metric values is below
            `std`, False otherwise.
        """
        metric_result = result.get(self._metric)
        window = self._trial_results[trial_id]
        window.append(metric_result)
        self._iter[trial_id] += 1

        # A result without the metric can never indicate a plateau. Bail
        # out here so the threshold comparison below is not attempted on
        # None (which would raise a TypeError). The None stays in the
        # window, so the std check keeps failing until it is rotated out.
        if metric_result is None:
            return False

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(window) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            elif self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results
        try:
            current_std = np.std(window)
        except Exception:
            # Best effort: a window that cannot be reduced numerically
            # (e.g. it still contains None or non-numeric values) is
            # treated as "not plateaued" instead of failing the run.
            current_std = float("inf")

        # If stdev is lower than threshold, stop early.
        # Coerce to a plain bool so callers do not receive a numpy.bool_.
        return bool(current_std < self._std)

    def stop_all(self) -> bool:
        """Return False: this stopper only ever stops individual trials."""
        return False