# Source code for ray.train.v2.api.callback
from typing import Any, Dict, List, Optional
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class RayTrainCallback:
    """Root interface shared by Ray Train callbacks.

    The class declares no hooks of its own; it only provides a common base
    type for callback classes such as `UserCallback`.
    """
@DeveloperAPI
class UserCallback(RayTrainCallback):
    """Callback interface for custom user-defined callbacks that handle events
    during training.

    This callback is called on the Ray Train controller process, not on the
    worker processes.

    Every hook defined here is a no-op; subclasses override only the hooks
    they are interested in.
    """

    def after_report(
        self,
        run_context: TrainRunContext,
        metrics: List[Dict[str, Any]],
        checkpoint: Optional[Checkpoint],
    ) -> None:
        """Called after all workers have reported a metric + checkpoint
        via `ray.train.report`.

        The base implementation does nothing.

        Args:
            run_context: The `TrainRunContext` for the current training run.
            metrics: A list of metric dictionaries reported by workers,
                where metrics[i] is the metrics dict reported by worker i.
            checkpoint: A Checkpoint object that has been persisted to
                storage. This is None if no workers reported a checkpoint
                (e.g. `ray.train.report(metrics, checkpoint=None)`).
        """

    def after_exception(
        self, run_context: TrainRunContext, worker_exceptions: Dict[int, Exception]
    ) -> None:
        """Called after one or more workers have raised an exception.

        The base implementation does nothing.

        Args:
            run_context: The `TrainRunContext` for the current training run.
            worker_exceptions: A dict from worker world rank to the exception
                raised by that worker.
        """