import asyncio
from typing import Any, List, Optional
import ray
from ray.exceptions import (
    GetTimeoutError,
    RayChannelError,
    RayChannelTimeoutError,
    RayTaskError,
)
from ray.util.annotations import PublicAPI
def _process_return_vals(return_vals: List[Any], return_single_output: bool):
    """
    Process return values for return to the DAG caller. Any exceptions found in
    return_vals will be raised. If return_single_output=True, it indicates that
    the original DAG did not have a MultiOutputNode, so the DAG caller expects
    a single return value instead of a list.
    """
    # Check for exceptions.
    if isinstance(return_vals, Exception):
        raise return_vals
    for val in return_vals:
        if isinstance(val, RayTaskError):
            raise val.as_instanceof_cause()
    if return_single_output:
        assert len(return_vals) == 1
        return return_vals[0]
    return return_vals
@PublicAPI(stability="alpha")
class CompiledDAGRef:
    """
    A reference to a compiled DAG execution result.

    This class is not a subclass of ObjectRef, but it resembles ObjectRef.
    For example, similar to ObjectRef, ray.get() can be called on it to
    retrieve the result. However, there are several major differences:
    1. ray.get() can only be called once per CompiledDAGRef.
    2. ray.wait() is not supported.
    3. CompiledDAGRef cannot be copied, deep copied, or pickled.
    4. CompiledDAGRef cannot be passed as an argument to another task.
    """

    def __init__(
        self,
        dag: "ray.experimental.CompiledDAG",
        execution_index: int,
        channel_index: Optional[int] = None,
    ):
        """
        Args:
            dag: The compiled DAG that generated this CompiledDAGRef.
            execution_index: The index of the execution for the DAG.
                A DAG can be executed multiple times, and execution index
                indicates which execution this CompiledDAGRef corresponds to.
            channel_index: The index of the DAG's output channel to fetch
                the result from. A DAG can have multiple output channels, and
                channel index indicates which channel this CompiledDAGRef
                corresponds to. If channel index is not provided, this
                CompiledDAGRef wraps the results from all output channels.
        """
        self._dag = dag
        self._execution_index = execution_index
        self._channel_index = channel_index
        # Whether ray.get() was called on this CompiledDAGRef.
        self._ray_get_called = False
        self._dag_output_channels = dag.dag_output_channels

    def __str__(self):
        return (
            f"CompiledDAGRef({self._dag.get_id()}, "
            f"execution_index={self._execution_index}, "
            f"channel_index={self._channel_index})"
        )

    def __copy__(self):
        raise ValueError("CompiledDAGRef cannot be copied.")

    def __deepcopy__(self, memo):
        raise ValueError("CompiledDAGRef cannot be deep copied.")

    def __reduce__(self):
        raise ValueError("CompiledDAGRef cannot be pickled.")

    def __del__(self):
        """
        Ask the DAG to delete this execution's results if they were never
        fetched through get().
        """
        # If the dag is already teardown, it should do nothing.
        if self._dag.is_teardown:
            return
        if self._ray_get_called:
            # get() was already called, no further cleanup is needed.
            return
        self._dag._delete_execution_results(
            self._execution_index, self._channel_index
        )

    def get(self, timeout: Optional[float] = None):
        """
        Retrieve the result of the execution this reference points to.

        Args:
            timeout: Forwarded to the DAG while it executes up to this
                reference's execution index (presumably seconds; confirm
                against CompiledDAG._execute_until).

        Returns:
            The single value produced for this execution/channel.

        Raises:
            ValueError: If get() was already called on this reference.
            RayChannelTimeoutError: Re-raised unchanged when reading the
                result times out.
            RayChannelError: When the channel failed and none of the actor
                execution loops reported an error of their own.
            Exception: The actor execution loop error when one is found
                (raised instead of the channel error), or a generic
                Exception when checking the loops itself times out.
        """
        if self._ray_get_called:
            raise ValueError(
                "ray.get() can only be called once "
                "on a CompiledDAGRef, and it was already called."
            )

        # NOTE(review): the flag is set before the fetch, so a get() that
        # raises (e.g. on timeout) cannot be retried and __del__ will skip
        # its cleanup. Confirm that this is intended.
        self._ray_get_called = True
        try:
            self._dag._execute_until(
                self._execution_index, self._channel_index, timeout
            )
            return_vals = self._dag._get_execution_results(
                self._execution_index, self._channel_index
            )
        except RayChannelTimeoutError:
            # Timeouts are re-raised as-is so that they never enter the
            # RayChannelError recovery path below.
            raise
        except RayChannelError as channel_error:
            # If we get a channel error, we'd like to call ray.get() on
            # the actor execution loop refs to check if this is a result
            # of task execution error which could not be passed down
            # (e.g., when a pure NCCL channel is used, it is only
            # able to send tensors, but not the wrapped exceptions).
            # In this case, we'd like to raise the task execution error
            # (which is the actual cause of the channel error) instead
            # of the channel error itself.
            # TODO(rui): determine which error to raise if multiple
            # actor task refs have errors.
            actor_execution_loop_refs = list(self._dag.worker_task_refs.values())
            try:
                ray.get(actor_execution_loop_refs, timeout=10)
            except GetTimeoutError as timeout_error:
                raise Exception(
                    "Timed out when getting the actor execution loop exception. "
                    "This should not happen, please file a GitHub issue."
                ) from timeout_error
            except Exception as execution_error:
                # Use 'from None' to suppress the context of the original
                # channel error, which is not useful to the user.
                raise execution_error from None
            else:
                # No execution loop reported a failure, so the channel error
                # is the best information available.
                raise channel_error

        # Any other exception propagates to the caller unchanged.
        return _process_return_vals(return_vals, True)
@PublicAPI(stability="alpha")
class CompiledDAGFuture:
    """
    An awaitable reference to a compiled DAG execution result, used when the
    DAG is executed with asyncio.

    Unlike CompiledDAGRef, the result is obtained by calling `await` on the
    future instead of `ray.get()`. This mirrors the async usage of
    ObjectRefs: `await` can be applied directly to the CompiledDAGFuture to
    retrieve the result. The major differences from ObjectRef are:
    1. `await` can only be called once per CompiledDAGFuture.
    2. ray.wait() is not supported.
    3. CompiledDAGFuture cannot be copied, deep copied, or pickled.
    4. CompiledDAGFuture cannot be passed as an argument to another task.
    """

    def __init__(
        self,
        dag: "ray.experimental.CompiledDAG",
        execution_index: int,
        fut: "asyncio.Future",
        channel_index: Optional[int] = None,
    ):
        """
        Args:
            dag: The compiled DAG this future reads its result from.
            execution_index: The index of the DAG execution this future
                corresponds to.
            fut: The asyncio future awaited to obtain the execution's
                results when the DAG does not already hold them.
            channel_index: The output channel index passed to the DAG when
                fetching or deleting this execution's results.
        """
        self._fut = fut
        self._dag = dag
        self._channel_index = channel_index
        self._execution_index = execution_index

    def __str__(self):
        dag_id = self._dag.get_id()
        return (
            f"CompiledDAGFuture({dag_id}, "
            f"execution_index={self._execution_index}, "
            f"channel_index={self._channel_index})"
        )

    def __copy__(self):
        raise ValueError("CompiledDAGFuture cannot be copied.")

    def __deepcopy__(self, memo):
        raise ValueError("CompiledDAGFuture cannot be deep copied.")

    def __reduce__(self):
        raise ValueError("CompiledDAGFuture cannot be pickled.")

    def __await__(self):
        """
        Wait for the execution to finish and return its processed result.

        Raises ValueError if this future has already been awaited.
        """
        pending = self._fut
        if pending is None:
            raise ValueError(
                "CompiledDAGFuture can only be awaited upon once, and it has "
                "already been awaited upon."
            )

        # NOTE(swang): A zero-copy deserialized object stays in scope for as
        # long as this future does, so the reference held on `self` is
        # dropped here, before the result is handed back to the user.
        self._fut = None

        dag = self._dag
        exec_index = self._execution_index
        if not dag._has_execution_results(exec_index):
            # Results are not available from the DAG yet: wait for them, then
            # record the finished execution and cache what was read.
            outputs = yield from pending.__await__()
            dag._max_finished_execution_index += 1
            dag._cache_execution_results(exec_index, outputs)

        return _process_return_vals(
            dag._get_execution_results(exec_index, self._channel_index),
            True,
        )

    def __del__(self):
        # Skip cleanup when the DAG is already torn down, or when the future
        # was awaited (`__await__` clears `_fut`).
        if self._dag.is_teardown or self._fut is None:
            return
        self._dag._delete_execution_results(
            self._execution_index, self._channel_index
        )