Source code for ray.train

# Try import ray[train] core requirements (defined in setup.py)
# isort: off
# Fail fast with an actionable message if any package required by the
# `ray[train]` extra is missing. The imported names are not used in this
# module (hence the `noqa: F401` markers).
try:
    import fsspec  # noqa: F401
    import pandas  # noqa: F401
    import pyarrow  # noqa: F401
    import requests  # noqa: F401
except ImportError as exc:
    # Chain the original error so the traceback still names the package
    # that could not be imported.
    raise ImportError(
        "Can't import ray.train as some dependencies are missing. "
        'Run `pip install "ray[train]"` to fix.'
    ) from exc
# isort: on


from ray._private.usage import usage_lib
from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
from ray.air.result import Result

# Import this first so it can be used in other modules
from ray.train._checkpoint import Checkpoint
from ray.train._internal.data_config import DataConfig
from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
from ray.train._internal.syncer import SyncConfig
from ray.train.backend import BackendConfig
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.context import get_context
from ray.train.trainer import TrainingIterator

from ray.train.v2._internal.constants import is_v2_enabled

# When Ray Train v2 is enabled, rebind the public names to the implementations
# under `ray.train.v2.api`. Every name here except `UserCallback` was already
# imported from a non-v2 module above, hence the `noqa: F811` (redefinition)
# markers; the later binding is the one that ends up exported.
if is_v2_enabled():
    from ray.train.v2.api.callback import UserCallback  # noqa: F811
    from ray.train.v2.api.config import (  # noqa: F811
        FailureConfig,
        RunConfig,
        ScalingConfig,
    )
    from ray.train.v2.api.result import Result  # noqa: F811
    from ray.train.v2.api.train_fn_utils import (  # noqa: F811
        get_checkpoint,
        get_context,
        get_dataset_shard,
        report,
    )


# Runs once, as a side effect of importing `ray.train`.
# NOTE(review): presumably registers "train" with Ray's usage reporting --
# confirm against `ray._private.usage.usage_lib`.
usage_lib.record_library_usage("train")


# Public API of `ray.train`. Every object listed here other than
# TRAIN_DATASET_KEY also has its `__module__` set to "ray.train" further down.
# "UserCallback" is appended later, only when Ray Train v2 is enabled.
__all__ = [
    "get_checkpoint",
    "get_context",
    "get_dataset_shard",
    "report",
    "BackendConfig",
    "Checkpoint",
    "CheckpointConfig",
    "DataConfig",
    "FailureConfig",
    "Result",
    "RunConfig",
    "ScalingConfig",
    "SyncConfig",
    "TrainingIterator",
    "TRAIN_DATASET_KEY",
]

# Make every re-exported function and class report "ray.train" as its
# `__module__`, regardless of the submodule it was actually imported from.
for _exported in (
    get_checkpoint,
    get_context,
    get_dataset_shard,
    report,
    BackendConfig,
    Checkpoint,
    CheckpointConfig,
    DataConfig,
    FailureConfig,
    Result,
    RunConfig,
    ScalingConfig,
    SyncConfig,
    TrainingIterator,
):
    _exported.__module__ = "ray.train"
# Drop the loop variable so it does not linger in the package namespace.
del _exported

# `UserCallback` is only imported when Ray Train v2 is enabled, so it is
# exported and relabelled under the same condition.
if is_v2_enabled():
    UserCallback.__module__ = "ray.train"
    __all__.append("UserCallback")


# DO NOT ADD ANYTHING AFTER THIS LINE.