Source code for ray.tune.search.search_algorithm
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.tune.experiment import Experiment
@DeveloperAPI
class SearchAlgorithm:
    """Event-handler interface for hyperparameter search algorithms.

    In contrast to TrialSchedulers, a SearchAlgorithm is not able to modify
    the execution (i.e., it cannot stop or pause trials).

    Trials that were added manually (i.e., via the Client API) also notify
    this class upon new events, so custom search algorithms should keep a
    list of the trial IDs that were generated from this class.

    See also: `ray.tune.search.BasicVariantGenerator`.
    """

    # Class-level defaults. Assigning through ``self`` (see ``set_finished``
    # and ``set_search_properties``) shadows them on the instance.
    _finished = False
    _metric = None

    @property
    def metric(self):
        """The metric stored on this search algorithm, or None if unset."""
        return self._metric

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Pass search properties to the search algorithm.

        This is an alternative to instantiating search algorithms with
        their own specific search spaces: a Tune config can be handed over
        through this method instead.

        Search algorithms will usually forward this call to their
        ``Searcher`` instance. This base implementation only records
        ``metric``; ``mode``, ``config`` and ``spec`` are not used here.

        Args:
            metric: Metric to optimize.
            mode: One of ["min", "max"]. Direction to optimize.
            config: Tune config dict.
            **spec: Any kwargs for forward compatibility.
                Info like Experiment.PUBLIC_KEYS is provided through here.

        Returns:
            False if a metric is already stored and another one was passed
            (the stored metric is left untouched); True otherwise.
        """
        if metric:
            # Refuse to overwrite a metric that has already been set.
            if self._metric:
                return False
            self._metric = metric
        return True

    @property
    def total_samples(self):
        """Number of total trials to be generated (0 in this base class)."""
        return 0

    def add_configurations(
        self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
    ):
        """Track the given experiment specifications.

        Args:
            experiments: Experiments to run.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def next_trial(self):
        """Return a single Trial object to be queued into the TrialRunner.

        Returns:
            trial: A Trial object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_trial_result(self, trial_id: str, result: Dict):
        """Hook invoked for each intermediate result returned by a trial.

        This will only be called when the trial is in the RUNNING state.
        The base implementation does nothing.

        Args:
            trial_id: Identifier for the trial.
            result: Result dictionary.
        """

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Hook notifying the search algorithm that a trial has completed.

        The base implementation does nothing.

        Args:
            trial_id: Identifier for the trial.
            result: Defaults to None. A dict will be provided with this
                notification when the trial is in the RUNNING state AND
                either completes naturally or by manual termination.
            error: Defaults to False. True if the trial is in the RUNNING
                state and errors.
        """

    def is_finished(self) -> bool:
        """Return True if no trials are left to be queued into TrialRunner.

        Can return True before all trials have finished executing.
        """
        return self._finished

    def set_finished(self):
        """Mark the search algorithm as finished."""
        self._finished = True

    def has_checkpoint(self, dirpath: str) -> bool:
        """Should return False if restoring is not implemented.

        The base implementation always returns False.
        """
        return False

    def save_to_dir(self, dirpath: str, **kwargs):
        """Save a search algorithm (no-op in this base class)."""

    def restore_from_dir(self, dirpath: str):
        """Restore a search algorithm along with its wrapped state.

        The base implementation does nothing.
        """