import inspect
import logging
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union

import ray
from ray.tune.execution.placement_groups import (
    PlacementGroupFactory,
    resource_dict_to_pg_factory,
)
from ray.tune.registry import _ParameterRegistry
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
    """Wrapper for trainables to pass arbitrary large data objects.
    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainables.
    This can also be used as an alternative to ``functools.partial`` to pass
    default arguments to trainables.
    When used with the function API, the trainable function is called with
    the passed parameters as keyword arguments. When used with the class API,
    the ``Trainable.setup()`` method is called with the respective kwargs.
    If the data already exists in the object store (are instances of
    ObjectRef), using ``tune.with_parameters()`` is not necessary. You can
    instead pass the object refs to the training function via the ``config``
    or use Python partials.
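
    For example, if your objects are already ``ObjectRef`` instances, you can
    pass the refs through the ``config`` instead (a minimal sketch;
    ``train_fn`` and ``HugeDataset`` are illustrative stand-ins):

    .. code-block:: python

        import ray
        from ray import tune

        data_ref = ray.put(HugeDataset(download=True))

        def train_fn(config):
            # Resolve the object ref inside the trainable.
            data = ray.get(config["data_ref"])
            ...

        tuner = tune.Tuner(train_fn, param_space={"data_ref": data_ref})
        tuner.fit()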

    Args:
        trainable: Trainable to wrap.
        **kwargs: parameters to store in object store.

    Function API example:

    .. code-block:: python

        from ray import train, tune

        def train_fn(config, data=None):
            for sample in data:
                loss = update_model(sample)
                train.report({"loss": loss})

        data = HugeDataset(download=True)

        tuner = tune.Tuner(
            tune.with_parameters(train_fn, data=data),
            # ...
        )
        tuner.fit()

    Class API example:

    .. code-block:: python

        from ray import tune

        class MyTrainable(tune.Trainable):
            def setup(self, config, data=None):
                self.data = data
                self.iter = iter(self.data)
                self.next_sample = next(self.iter)

            def step(self):
                loss = update_model(self.next_sample)
                try:
                    self.next_sample = next(self.iter)
                except StopIteration:
                    return {"loss": loss, "done": True}
                return {"loss": loss}

        data = HugeDataset(download=True)

        tuner = tune.Tuner(
            tune.with_parameters(MyTrainable, data=data),
            # ...
        )
        tuner.fit()
    """

    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            "`tune.with_parameters()` only works with function trainables "
            "or classes that inherit from `tune.Trainable`. Got type: "
            f"{type(trainable)}."
        )
    parameter_registry = _ParameterRegistry()
    # The registry defers `ray.put()` until Ray is initialized; flush the
    # stored parameters into the object store on worker init.
    ray._private.worker._post_init_hooks.append(parameter_registry.flush)

    # Objects are moved into the object store
    prefix = f"{str(trainable)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    trainable_name = getattr(trainable, "__name__", "tune_with_parameters")
    keys = set(kwargs.keys())

    if inspect.isclass(trainable):
        # Class trainable: inject the stored parameters into `setup()`.
        class _Inner(trainable):
            def setup(self, config):
                setup_kwargs = {}
                for k in keys:
                    setup_kwargs[k] = parameter_registry.get(prefix + k)
                super(_Inner, self).setup(config, **setup_kwargs)

        trainable_with_params = _Inner
    else:
        # Function trainable: inject the stored parameters as kwargs.
        def inner(config):
            fn_kwargs = {}
            for k in keys:
                fn_kwargs[k] = parameter_registry.get(prefix + k)
            return trainable(config, **fn_kwargs)

        trainable_with_params = inner

        if hasattr(trainable, "__mixins__"):
            trainable_with_params.__mixins__ = trainable.__mixins__

        # If the trainable has been wrapped with `tune.with_resources`, we should
        # keep the `_resources` attribute around
        if hasattr(trainable, "_resources"):
            trainable_with_params._resources = trainable._resources

    trainable_with_params.__name__ = trainable_name
    return trainable_with_params


@PublicAPI(stability="beta")
def with_resources(
    trainable: Union[Type["Trainable"], Callable],
    resources: Union[
        Dict[str, float],
        PlacementGroupFactory,
        Callable[[dict], PlacementGroupFactory],
    ],
):
    """Wrapper for trainables to specify resource requests.
    This wrapper allows specification of resource requirements for a specific
    trainable. It will override potential existing resource requests (use
    with caution!).
    The main use case is to request resources for function trainables when used
    with the Tuner() API.
    Class trainables should usually just implement the ``default_resource_request()``
    method.
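
    For instance (a minimal sketch; the resource values are illustrative):

    .. code-block:: python

        from ray import tune
        from ray.tune.execution.placement_groups import PlacementGroupFactory

        class MyTrainable(tune.Trainable):
            @classmethod
            def default_resource_request(cls, config):
                # Request one CPU and one GPU for the trainable itself.
                return PlacementGroupFactory([{"CPU": 1, "GPU": 1}])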

    Args:
        trainable: Trainable to wrap.
        resources: Resource dict, placement group factory, or callable that takes
            in a config dict and returns a placement group factory.

    Example:

    .. code-block:: python

        import ray
        from ray import tune
        from ray.tune.tuner import Tuner

        def train_fn(config):
            return len(ray.get_gpu_ids())  # Returns 2

        tuner = Tuner(
            tune.with_resources(train_fn, resources={"gpu": 2}),
            # ...
        )
        results = tuner.fit()
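
    ``resources`` can also be a ``PlacementGroupFactory``, or a callable that
    maps the trial config to one (a minimal sketch reusing ``train_fn`` and
    ``Tuner`` from above; the config keys are illustrative):

    .. code-block:: python

        from ray.tune.execution.placement_groups import PlacementGroupFactory

        tuner = Tuner(
            tune.with_resources(
                train_fn,
                resources=lambda config: PlacementGroupFactory(
                    [{"CPU": config["num_cpus"], "GPU": config.get("num_gpus", 0)}]
                ),
            ),
            # ...
        )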
    """

    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            "`tune.with_resources()` only works with function trainables "
            "or classes that inherit from `tune.Trainable`. Got type: "
            f"{type(trainable)}."
        )

    if isinstance(resources, PlacementGroupFactory):
        pgf = resources
    elif isinstance(resources, dict):
        pgf = resource_dict_to_pg_factory(resources)
    elif callable(resources):
        pgf = resources
    else:
        raise ValueError(
            f"Invalid resource type for `with_resources()`: {type(resources)}"
        )

    if not inspect.isclass(trainable):
        if isinstance(trainable, types.MethodType):
            # Methods cannot set arbitrary attributes, so we have to wrap them
            def _trainable(config):
                return trainable(config)

            _trainable._resources = pgf
            return _trainable

        # Just set an attribute. This will be resolved later in `wrap_function()`.
        try:
            trainable._resources = pgf
        except AttributeError as e:
            raise RuntimeError(
                "Could not use `tune.with_resources()` on the supplied trainable. "
                "Wrap your trainable in a regular function before passing it "
                "to Ray Tune."
            ) from e
    else:

        class ResourceTrainable(trainable):
            @classmethod
            def default_resource_request(
                cls, config: Dict[str, Any]
            ) -> Optional[PlacementGroupFactory]:
                if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
                    # A callable was passed in: resolve it with the trial config.
                    return pgf(config)
                return pgf

        ResourceTrainable.__name__ = trainable.__name__
        trainable = ResourceTrainable

    return trainable