Source code for ray.tune.schedulers
import inspect
from ray._private.utils import get_function_args
from ray.tune.schedulers.async_hyperband import ASHAScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.schedulers.hyperband import HyperBandScheduler
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
from ray.tune.schedulers.pbt import (
    PopulationBasedTraining,
    PopulationBasedTrainingReplay,
)
from ray.tune.schedulers.resource_changing_scheduler import ResourceChangingScheduler
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.util import PublicAPI
def _pb2_importer():
    """Import the PB2 scheduler on demand and return its class.

    PB2 introduces a GPy dependency which can be expensive, so the import
    is deferred until this function is actually called instead of being
    done at module import time.

    Returns:
        The ``PB2`` class from ``ray.tune.schedulers.pb2``.
    """
    from ray.tune.schedulers.pb2 import PB2 as pb2_class

    return pb2_class
# Registry used by `create_scheduler` to resolve a lower-case scheduler name.
# Values in this dictionary will be one of two kinds:
#    class of the scheduler object to create
#    wrapper function to support a lazy import of the scheduler class
#    (called with no arguments; its return value is used as the class)
# Several names are aliases that map to the same class.
SCHEDULER_IMPORT = {
    "fifo": FIFOScheduler,
    "async_hyperband": AsyncHyperBandScheduler,
    "asynchyperband": AsyncHyperBandScheduler,
    "median_stopping_rule": MedianStoppingRule,
    "medianstopping": MedianStoppingRule,
    "hyperband": HyperBandScheduler,
    "hb_bohb": HyperBandForBOHB,
    "pbt": PopulationBasedTraining,
    "pbt_replay": PopulationBasedTrainingReplay,
    # Lazy wrapper rather than a class: PB2 is only imported when requested.
    "pb2": _pb2_importer,
    "resource_changing": ResourceChangingScheduler,
}
[docs]
@PublicAPI(stability="beta")
def create_scheduler(
    scheduler,
    **kwargs,
):
    """Build a scheduler instance from its registered name.

    Handy when the scheduler should be selectable through a plain string,
    e.g. to switch between schedulers without touching the calling code.

    Args:
        scheduler: Name of the scheduler to create. The name is lower-cased
            before it is looked up in ``SCHEDULER_IMPORT``.
        **kwargs: Scheduler parameters.
            Only the entries whose names are reported by
            ``get_function_args`` for the chosen scheduler class are passed
            to its initialization function; all other entries are ignored.

    Returns:
        ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.

    Raises:
        ValueError: If the lower-cased name is not a key of
            ``SCHEDULER_IMPORT``.

    Example:
        >>> from ray import tune
        >>> pbt_kwargs = {}
        >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs) # doctest: +SKIP
    """
    name = scheduler.lower()
    registry_entry = SCHEDULER_IMPORT.get(name)
    if registry_entry is None:
        raise ValueError(
            f"The `scheduler` argument must be one of "
            f"{list(SCHEDULER_IMPORT)}. "
            f"Got: {name}"
        )

    # A plain function in the registry is a lazy-import wrapper: calling it
    # yields the actual scheduler class. Anything else is the class itself.
    scheduler_cls = (
        registry_entry() if inspect.isfunction(registry_entry) else registry_entry
    )

    # Forward only the keyword arguments the chosen scheduler accepts, so the
    # same kwargs can be reused across different schedulers.
    accepted_names = get_function_args(scheduler_cls)
    forwarded = {}
    for arg_name, arg_value in kwargs.items():
        if arg_name in accepted_names:
            forwarded[arg_name] = arg_value
    return scheduler_cls(**forwarded)
# Names exported by a wildcard import (`from <this package> import *`).
# PB2 is not listed: it is never bound at module level and is only reachable
# through the lazy `_pb2_importer` wrapper.
# NOTE(review): `create_scheduler` is defined in this module but is not listed
# here either -- confirm whether that omission is intentional.
__all__ = [
    "TrialScheduler",
    "HyperBandScheduler",
    "AsyncHyperBandScheduler",
    "ASHAScheduler",
    "MedianStoppingRule",
    "FIFOScheduler",
    "PopulationBasedTraining",
    "PopulationBasedTrainingReplay",
    "HyperBandForBOHB",
    "ResourceChangingScheduler",
]