import itertools
import logging
import os
import pathlib
import re
from typing import (
    TYPE_CHECKING,
    Callable,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
import numpy as np
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import RetryingPyFileSystem
from ray.data.block import BlockMetadata
from ray.data.datasource.partitioning import Partitioning
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
    import pyarrow
logger = logging.getLogger(__name__)
def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> None:
    # NOTE: this is not comprehensive yet, and should be extended as more errors arise.
    # NOTE: The first pattern is raised by Arrow < 10, while the latter two
    # patterns are raised by Arrow 10+.
    aws_error_pattern = (
        r"^(?:(.*)AWS Error \[code \d+\]: No response body\.(.*))|"
        r"(?:(.*)AWS Error UNKNOWN \(HTTP status 400\) during HeadObject operation: "
        r"No response body\.(.*))|"
        r"(?:(.*)AWS Error ACCESS_DENIED during HeadObject operation: No response "
        r"body\.(.*))$"
    )
    if re.match(aws_error_pattern, str(error)):
        # Special-case AWS errors when reading files to give a clearer error
        # message and avoid confusing users. The most likely cause is that the
        # AWS S3 credentials have not been configured properly.
        if isinstance(paths, str):
            # Quote a single file path to highlight it in the error message.
            # A list of file paths is already rendered as ['foo', 'bar'], so
            # only quote the single-path case here.
            paths = f'"{paths}"'
        raise OSError(
            f"Failed to read AWS S3 file(s): {paths}. "
            "Please check that the file(s) exist and that access is properly "
            "configured. You can also run an AWS CLI command to get a more "
            "detailed error message (e.g., aws s3 ls <file-name>). "
            "See https://awscli.amazonaws.com/v2/documentation/api/latest/reference/s3/index.html "  # noqa
            "and https://docs.ray.io/en/latest/data/creating-datasets.html#reading-from-remote-storage "  # noqa
            "for more information."
        )
    else:
        raise error
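
# Example of the `_handle_read_os_error` translation above (illustrative only;
# the error text and S3 path are hypothetical, but the message matches
# `aws_error_pattern`):
#
#   try:
#       raise OSError("... AWS Error [code 15]: No response body.")
#   except OSError as e:
#       _handle_read_os_error(e, "s3://my-bucket/my-file.parquet")
#   # -> raises OSError("Failed to read AWS S3 file(s): ...") with guidance on
#   # checking S3 credentials; non-matching OSErrors are re-raised unchanged.
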
def _expand_paths(
    paths: List[str],
    filesystem: "RetryingPyFileSystem",
    partitioning: Optional[Partitioning],
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    """Get the file sizes for all provided file paths."""
    from pyarrow.fs import LocalFileSystem
    from ray.data.datasource.file_based_datasource import (
        FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD,
    )
    from ray.data.datasource.path_util import _is_http_url, _unwrap_protocol
    # We break down our processing paths into a few key cases:
    # 1. If len(paths) < threshold, fetch the file info for the individual files/paths
    #    serially.
    # 2. If all paths are contained under the same parent directory (or base directory,
    #    if using partitioning), fetch all file infos at this prefix and filter to the
    #    provided paths on the client; this should be a single file info request.
    # 3. If more than the threshold number of requests are required, parallelize
    #    them via Ray tasks (an illustrative sketch follows this function).
    # 1. Small # of paths case.
    is_local = isinstance(filesystem, LocalFileSystem)
    if isinstance(filesystem, RetryingPyFileSystem):
        is_local = isinstance(filesystem.unwrap(), LocalFileSystem)
    if (
        len(paths) < FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD
        # Local file systems are very fast to hit.
        or is_local
    ):
        yield from _get_file_infos_serial(paths, filesystem, ignore_missing_paths)
    else:
        # 2. Common path prefix case.
        # Get longest common path of all paths.
        common_path = os.path.commonpath(paths)
        # If parent directory (or base directory, if using partitioning) is common to
        # all paths, fetch all file infos at that prefix and filter the response to the
        # provided paths.
        if not _is_http_url(common_path) and (
            (
                partitioning is not None
                and common_path == _unwrap_protocol(partitioning.base_dir)
            )
            or all(str(pathlib.Path(path).parent) == common_path for path in paths)
        ):
            yield from _get_file_infos_common_path_prefix(
                paths, common_path, filesystem, ignore_missing_paths
            )
        # 3. Parallelization case.
        else:
            # Parallelize requests via Ray tasks.
            yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)
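
# Illustrative sketch of the three `_expand_paths` branches above. The paths,
# sizes, and filesystem handle `fs` are hypothetical; `fs` is assumed to be a
# RetryingPyFileSystem wrapping a pyarrow filesystem:
#
#   # Case 1: a handful of paths (below the parallelization threshold, or on a
#   # local filesystem) is resolved serially.
#   list(_expand_paths(["/data/a.csv", "/data/b.csv"], fs, None))
#   # -> [("/data/a.csv", 128), ("/data/b.csv", 256)]
#
#   # Case 2: many paths that all share the same parent directory are resolved
#   # with a single listing of that prefix.
#   # Case 3: many paths scattered across different directories are resolved
#   # by parallel Ray tasks.
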
def _get_file_infos_serial(
    paths: List[str],
    filesystem: "RetryingPyFileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    for path in paths:
        yield from _get_file_infos(path, filesystem, ignore_missing_paths)
def _get_file_infos_common_path_prefix(
    paths: List[str],
    common_path: str,
    filesystem: "pyarrow.fs.FileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    path_to_size = {path: None for path in paths}
    for path, file_size in _get_file_infos(
        common_path, filesystem, ignore_missing_paths
    ):
        if path in path_to_size:
            path_to_size[path] = file_size
    # Check whether all `paths` have file size metadata. If any path is missing
    # a file size, fall back to fetching file metadata in parallel. This can
    # happen when a path refers to a directory rather than a file.
    have_missing_path = False
    for path in paths:
        if path_to_size[path] is None:
            logger.debug(
                f"Path {path} does not have file size metadata. Falling back "
                "to fetching file metadata in parallel for all paths."
            )
            have_missing_path = True
            break
    if have_missing_path:
        # Parallelize requests via Ray tasks.
        yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)
    else:
        # Iterate over `paths` to yield each path in its original order.
        # NOTE: do not iterate over `path_to_size`, because the dictionary
        # drops duplicates, while `paths` may contain duplicate entries if the
        # same file is read multiple times.
        for path in paths:
            yield path, path_to_size[path]
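
# Illustrative sketch of the common-prefix fast path above (paths and sizes are
# hypothetical, `fs` is assumed to be an existing filesystem handle): a single
# listing of "/data" is filtered down to the requested paths, preserving input
# order and duplicates.
#
#   list(_get_file_infos_common_path_prefix(
#       ["/data/a.csv", "/data/a.csv", "/data/b.csv"], "/data", fs))
#   # -> [("/data/a.csv", 128), ("/data/a.csv", 128), ("/data/b.csv", 256)]
#   # If any requested path turns out to be a directory (no file size found),
#   # the function falls back to the parallel fetch path.
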
def _get_file_infos_parallel(
    paths: List[str],
    filesystem: "RetryingPyFileSystem",
    ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
    from ray.data.datasource.file_based_datasource import (
        PATHS_PER_FILE_SIZE_FETCH_TASK,
        _unwrap_s3_serialization_workaround,
        _wrap_s3_serialization_workaround,
    )
    logger.warning(
        f"Expanding {len(paths)} path(s). This may be a HIGH LATENCY "
        f"operation on some cloud storage services. Moving all the "
        "paths to a common parent directory will lead to faster "
        "metadata fetching."
    )
    # Capture the filesystem in the fetcher func closure, but wrap it in our
    # serialization workaround to make sure that the pickle roundtrip works as
    # expected.
    filesystem = _wrap_s3_serialization_workaround(filesystem)
    def _file_infos_fetcher(paths: List[str]) -> List[Tuple[str, int]]:
        fs = _unwrap_s3_serialization_workaround(filesystem)
        return list(
            itertools.chain.from_iterable(
                _get_file_infos(path, fs, ignore_missing_paths) for path in paths
            )
        )
    yield from _fetch_metadata_parallel(
        paths, _file_infos_fetcher, PATHS_PER_FILE_SIZE_FETCH_TASK
    )
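
# Illustrative sketch of the parallel path above (bucket, paths, and `fs` are
# hypothetical): a large, scattered path list is split into chunks of
# PATHS_PER_FILE_SIZE_FETCH_TASK paths, each chunk is resolved by one Ray task,
# and the (path, size) pairs are yielded in chunk order once the tasks finish.
#
#   paths = [f"s3://bucket/dir-{i}/part.parquet" for i in range(10_000)]
#   sizes = dict(_get_file_infos_parallel(paths, fs))
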
Uri = TypeVar("Uri")
Meta = TypeVar("Meta")
def _fetch_metadata_parallel(
    uris: List[Uri],
    fetch_func: Callable[[List[Uri]], List[Meta]],
    desired_uris_per_task: int,
    **ray_remote_args,
) -> Iterator[Meta]:
    """Fetch file metadata in parallel using Ray tasks."""
    remote_fetch_func = cached_remote_fn(fetch_func)
    if ray_remote_args:
        remote_fetch_func = remote_fetch_func.options(**ray_remote_args)
    # Choose a parallelism that results in a # of metadata fetches per task that
    # dominates the Ray task overhead while ensuring good parallelism.
    # Always launch at least 2 parallel fetch tasks.
    parallelism = max(len(uris) // desired_uris_per_task, 2)
    metadata_fetch_bar = ProgressBar(
        "Metadata Fetch Progress", total=parallelism, unit="task"
    )
    fetch_tasks = []
    for uri_chunk in np.array_split(uris, parallelism):
        if len(uri_chunk) == 0:
            continue
        fetch_tasks.append(remote_fetch_func.remote(uri_chunk))
    results = metadata_fetch_bar.fetch_until_complete(fetch_tasks)
    yield from itertools.chain.from_iterable(results)
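
# Worked example of the parallelism choice above (numbers are hypothetical):
# with 10,000 URIs and desired_uris_per_task=250, parallelism =
# max(10_000 // 250, 2) = 40, so np.array_split produces 40 chunks of roughly
# 250 URIs and each chunk becomes one remote fetch task. With only 100 URIs,
# parallelism = max(100 // 250, 2) = 2, so at least two tasks are launched.
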
def _get_file_infos(
    path: str, filesystem: "RetryingPyFileSystem", ignore_missing_path: bool = False
) -> List[Tuple[str, int]]:
    """Get the file info for all files at or under the provided path."""
    from pyarrow.fs import FileType
    file_infos = []
    try:
        file_info = filesystem.get_file_info(path)
    except OSError as e:
        _handle_read_os_error(e, path)
    if file_info.type == FileType.Directory:
        for file_path, file_size in _expand_directory(path, filesystem):
            file_infos.append((file_path, file_size))
    elif file_info.type == FileType.File:
        file_infos.append((path, file_info.size))
    elif file_info.type == FileType.NotFound and ignore_missing_path:
        pass
    else:
        raise FileNotFoundError(path)
    return file_infos
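
# Illustrative behavior of `_get_file_infos` above (paths, sizes, and `fs` are
# hypothetical):
#
#   _get_file_infos("/data/a.csv", fs)     # -> [("/data/a.csv", 128)]
#   _get_file_infos("/data", fs)           # -> all files under /data, recursively
#   _get_file_infos("/missing", fs)        # -> raises FileNotFoundError
#   _get_file_infos("/missing", fs, True)  # -> []
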
def _expand_directory(
    path: str,
    filesystem: "RetryingPyFileSystem",
    exclude_prefixes: Optional[List[str]] = None,
    ignore_missing_path: bool = False,
) -> List[Tuple[str, int]]:
    """
    Expand the provided directory path to a list of file paths.
    Args:
        path: The directory path to expand.
        filesystem: The filesystem implementation that should be used for
            reading these files.
        exclude_prefixes: The file relative path prefixes that should be
            excluded from the returned file set. Default excluded prefixes are
            "." and "_".
    Returns:
        An iterator of (file_path, file_size) tuples.
    """
    if exclude_prefixes is None:
        exclude_prefixes = [".", "_"]
    from pyarrow.fs import FileSelector
    selector = FileSelector(path, recursive=True, allow_not_found=ignore_missing_path)
    files = filesystem.get_file_info(selector)
    base_path = selector.base_dir
    out = []
    for file_ in files:
        if not file_.is_file:
            continue
        file_path = file_.path
        if not file_path.startswith(base_path):
            continue
        relative = file_path[len(base_path) :]
        if any(relative.startswith(prefix) for prefix in exclude_prefixes):
            continue
        out.append((file_path, file_.size))
    # We sort the paths to guarantee a stable order.
    return sorted(out)
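
# Illustrative sketch of `_expand_directory` above (directory layout and sizes
# are hypothetical, `fs` is assumed to be an existing filesystem handle): the
# directory is listed recursively, only regular files are kept, entries whose
# path relative to the base directory starts with one of the exclude prefixes
# ("." or "_" by default) are meant to be skipped, and the result is sorted by
# path.
#
#   _expand_directory("/data", fs)
#   # -> [("/data/nested/part-1.parquet", 2048), ("/data/part-0.parquet", 1024)]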