"""Source code for ray.train.backend."""
import logging
from contextlib import nullcontext
from typing import TypeVar
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
from ray.widgets import make_table_html_repr
# Type variable for backend-encoded payloads; not referenced in this view —
# presumably used by Backend subclasses for encode/decode hooks (TODO confirm).
EncodedData = TypeVar("EncodedData")
# Module-level logger, following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)
@DeveloperAPI
class BackendConfig:
    """Base configuration class for a training backend.

    Subclasses override :attr:`backend_cls` (and optionally
    :attr:`train_func_context`) to pair a configuration with its
    concrete :class:`Backend` implementation.
    """

    @property
    def backend_cls(self):
        # The Backend class this configuration instantiates; the base
        # config pairs with the no-op ``Backend``.
        return Backend

    @property
    def train_func_context(self):
        # Context manager wrapped around the user's training function.
        # The base config applies no context at all.
        return nullcontext

    def _repr_html_(self) -> str:
        # Rich HTML repr for notebook environments: render this config
        # as a table titled with the concrete class name.
        title = type(self).__name__
        return make_table_html_repr(obj=self, title=title)
@DeveloperAPI
class Backend(metaclass=Singleton):
    """Singleton for distributed communication backend.

    All hooks are no-ops by default; subclasses override them to perform
    backend-specific setup and teardown on the worker group.

    Attributes:
        share_cuda_visible_devices: If True, each worker
            process will have CUDA_VISIBLE_DEVICES set as the visible device
            IDs of all workers on the same node for this training instance.
            If False, each worker will have CUDA_VISIBLE_DEVICES set to the
            device IDs allocated by Ray for that worker.
    """

    # Whether workers on the same node should share visibility of each
    # other's GPU device IDs (see class docstring).
    share_cuda_visible_devices: bool = False

    def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
        """Logic for starting this backend.

        No-op by default; subclasses override to perform setup
        (e.g. initializing process groups) on the worker group.
        """
        pass

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
        """Logic for shutting down the backend.

        No-op by default; subclasses override to release backend resources.
        """
        pass

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: BackendConfig
    ):
        """Logic ran right before training is started.

        Session API is available at this point. No-op by default.
        """
        pass