import atexit
import faulthandler
import functools
import inspect
import io
import json
import logging
import os
import sys
import threading
import time
import traceback
import urllib
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
    IO,
    Any,
    AnyStr,
    Callable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)
from urllib.parse import urlparse
import colorama
import setproctitle
from typing import Literal, Protocol
import ray
import ray._private.worker
import ray._private.node
import ray._private.parameter
import ray._private.profiling as profiling
import ray._private.ray_constants as ray_constants
import ray._private.serialization as serialization
import ray._private.services as services
import ray._private.state
import ray._private.storage as storage
from ray._private.ray_logging.logging_config import LoggingConfig
# Ray modules
import ray.actor
import ray.cloudpickle as pickle  # noqa
import ray.job_config
import ray.remote_function
from ray import ActorID, JobID, Language, ObjectRef
from ray._raylet import raise_sys_exit_with_custom_error_message
from ray._raylet import ObjectRefGenerator, TaskID
from ray.runtime_env.runtime_env import _merge_runtime_env
from ray._private import ray_option_utils
from ray._private.client_mode_hook import client_mode_hook
from ray._private.function_manager import FunctionActorManager
from ray._private.inspect_util import is_cython
from ray._private.ray_logging import (
    global_worker_stdstream_dispatcher,
    stdout_deduplicator,
    stderr_deduplicator,
    setup_logger,
)
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray._private.runtime_env.setup_hook import (
    upload_worker_process_setup_hook_if_needed,
)
from ray._private.storage import _load_class
from ray._private.utils import get_ray_doc_version
from ray.exceptions import ObjectStoreFullError, RayError, RaySystemError, RayTaskError
from ray.experimental.internal_kv import (
    _initialize_internal_kv,
    _internal_kv_get,
    _internal_kv_initialized,
    _internal_kv_reset,
)
from ray.experimental import tqdm_ray
from ray.experimental.compiled_dag_ref import CompiledDAGRef
from ray.experimental.tqdm_ray import RAY_TQDM_MAGIC
from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tracing.tracing_helper import _import_from_string
from ray.widgets import Template
from ray.widgets.util import repr_with_fallback
SCRIPT_MODE = 0
WORKER_MODE = 1
LOCAL_MODE = 2
SPILL_WORKER_MODE = 3
RESTORE_WORKER_MODE = 4
# Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray provides a default configuration at
# entry/init points.
logger = logging.getLogger(__name__)
T = TypeVar("T")
T0 = TypeVar("T0")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
T8 = TypeVar("T8")
T9 = TypeVar("T9")
R = TypeVar("R")
DAGNode = TypeVar("DAGNode")
# Only used for type annotations as a placeholder
Undefined: Any = object()
# TypeVar for self-referential generics in `RemoteFunction[N]`.
RF = TypeVar("RF", bound="HasOptions")
class HasOptions(Protocol):
    def options(self: RF, **task_options) -> RF:
        ...
class RemoteFunctionNoArgs(HasOptions, Generic[R]):
    def __init__(self, function: Callable[[], R]) -> None:
        pass
    def remote(
        self,
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
    ) -> "DAGNode[R]":
        ...
class RemoteFunction0(HasOptions, Generic[R, T0]):
    def __init__(self, function: Callable[[T0], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction1(HasOptions, Generic[R, T0, T1]):
    def __init__(self, function: Callable[[T0, T1], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction2(HasOptions, Generic[R, T0, T1, T2]):
    def __init__(self, function: Callable[[T0, T1, T2], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction3(HasOptions, Generic[R, T0, T1, T2, T3]):
    def __init__(self, function: Callable[[T0, T1, T2, T3], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction4(HasOptions, Generic[R, T0, T1, T2, T3, T4]):
    def __init__(self, function: Callable[[T0, T1, T2, T3, T4], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction5(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5]):
    def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction6(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6]):
    def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction7(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7]):
    def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
        __arg7: "Union[T7, ObjectRef[T7]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
        __arg7: "Union[T7, DAGNode[T7]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction8(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]):
    def __init__(
        self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]
    ) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
        __arg7: "Union[T7, ObjectRef[T7]]",
        __arg8: "Union[T8, ObjectRef[T8]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
        __arg7: "Union[T7, DAGNode[T7]]",
        __arg8: "Union[T8, DAGNode[T8]]",
    ) -> "DAGNode[R]":
        ...
class RemoteFunction9(HasOptions, Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
    def __init__(
        self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
    ) -> None:
        pass
    def remote(
        self,
        __arg0: "Union[T0, ObjectRef[T0]]",
        __arg1: "Union[T1, ObjectRef[T1]]",
        __arg2: "Union[T2, ObjectRef[T2]]",
        __arg3: "Union[T3, ObjectRef[T3]]",
        __arg4: "Union[T4, ObjectRef[T4]]",
        __arg5: "Union[T5, ObjectRef[T5]]",
        __arg6: "Union[T6, ObjectRef[T6]]",
        __arg7: "Union[T7, ObjectRef[T7]]",
        __arg8: "Union[T8, ObjectRef[T8]]",
        __arg9: "Union[T9, ObjectRef[T9]]",
    ) -> "ObjectRef[R]":
        ...
    def bind(
        self,
        __arg0: "Union[T0, DAGNode[T0]]",
        __arg1: "Union[T1, DAGNode[T1]]",
        __arg2: "Union[T2, DAGNode[T2]]",
        __arg3: "Union[T3, DAGNode[T3]]",
        __arg4: "Union[T4, DAGNode[T4]]",
        __arg5: "Union[T5, DAGNode[T5]]",
        __arg6: "Union[T6, DAGNode[T6]]",
        __arg7: "Union[T7, DAGNode[T7]]",
        __arg8: "Union[T8, DAGNode[T8]]",
        __arg9: "Union[T9, DAGNode[T9]]",
    ) -> "DAGNode[R]":
        ...
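# Illustrative sketch (not executed; the names and annotations below are only
# for demonstration): these arity-specific stubs exist so that type checkers
# can give `@ray.remote` functions precise signatures. A two-argument function
# maps to `RemoteFunction1[R, T0, T1]`:
#
#   @ray.remote
#   def add(x: int, y: int) -> int:
#       return x + y
#
#   ref: "ObjectRef[int]" = add.remote(1, 2)   # typed via RemoteFunction1.remote
#   node: "DAGNode[int]" = add.bind(1, 2)      # typed via RemoteFunction1.bind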
# Visible for testing.
def _unhandled_error_handler(e: Exception):
    logger.error(
        f"Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): {e}"
    )
class Worker:
    """A class used to define the control flow of a worker process.
    Note:
        The methods in this class are considered unexposed to the user. The
        functions outside of this class are considered exposed.
    Attributes:
        node (ray._private.node.Node): The node this worker is attached to.
        mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and
            WORKER_MODE.
    """
    def __init__(self):
        """Initialize a Worker object."""
        self.node = None
        self.mode = None
        self.actors = {}
        # When the worker is constructed, record the original values of the
        # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, ROCR_VISIBLE_DEVICES,
        # NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) environment variables.
        self.original_visible_accelerator_ids = (
            ray._private.utils.get_visible_accelerator_ids()
        )
        # A dictionary that maps from driver id to SerializationContext
        # TODO: clean up the SerializationContext once the job finishes.
        self.serialization_context_map = {}
        self.function_actor_manager = FunctionActorManager(self)
        # This event is checked regularly by all of the threads so that they
        # know when to exit.
        self.threads_stopped = threading.Event()
        # If this is set, the next .remote call should drop into the
        # debugger, at the specified breakpoint ID.
        self.debugger_breakpoint = b""
        # If this is set, ray.get calls invoked on the object ID returned
        # by the worker should drop into the debugger at the specified
        # breakpoint ID.
        self.debugger_get_breakpoint = b""
        # If True, make the debugger external to the node this worker is
        # running on.
        self.ray_debugger_external = False
        self._load_code_from_local = False
        # Opened file descriptor to stdout/stderr for this python worker.
        self._enable_record_actor_task_log = (
            ray_constants.RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING
        )
        self._out_file = None
        self._err_file = None
        # Create the lock here because the serializer will use it before
        # initializing Ray.
        self.lock = threading.RLock()
        # By default, don't show logs from other drivers. This is set to true by Serve
        # in order to stream logs from the controller and replica actors across
        # different drivers that connect to the same Serve instance.
        # See https://github.com/ray-project/ray/pull/35070.
        self._filter_logs_by_job = True
        # the debugger port for this worker
        self._debugger_port = None
        # Cache the job id from initialize_job_config() to optimize lookups.
        # This is on the critical path of ray.get()/put() calls.
        self._cached_job_id = None
        # Indicates whether the worker is connected to the Ray cluster.
        # It should be set to True in `connect` and False in `disconnect`.
        self._is_connected: bool = False
    @property
    def connected(self):
        """bool: True if Ray has been started and False otherwise."""
        return self._is_connected
    def set_is_connected(self, is_connected: bool):
        self._is_connected = is_connected
    @property
    def node_ip_address(self):
        self.check_connected()
        return self.node.node_ip_address
    @property
    def load_code_from_local(self):
        self.check_connected()
        return self._load_code_from_local
    @property
    def current_job_id(self):
        if self._cached_job_id is not None:
            return self._cached_job_id
        elif hasattr(self, "core_worker"):
            return self.core_worker.get_current_job_id()
        return JobID.nil()
    @property
    def actor_id(self):
        if hasattr(self, "core_worker"):
            return self.core_worker.get_actor_id()
        return ActorID.nil()
    @property
    def actor_name(self):
        if hasattr(self, "core_worker"):
            return self.core_worker.get_actor_name().decode("utf-8")
        return None
    @property
    def current_task_id(self):
        return self.core_worker.get_current_task_id()
    @property
    def current_task_name(self):
        return self.core_worker.get_current_task_name()
    @property
    def current_task_function_name(self):
        return self.core_worker.get_current_task_function_name()
    @property
    def current_node_id(self):
        return self.core_worker.get_current_node_id()
    @property
    def task_depth(self):
        return self.core_worker.get_task_depth()
    @property
    def namespace(self):
        return self.core_worker.get_job_config().ray_namespace
    @property
    def placement_group_id(self):
        return self.core_worker.get_placement_group_id()
    @property
    def worker_id(self):
        return self.core_worker.get_worker_id().binary()
    @property
    def should_capture_child_tasks_in_placement_group(self):
        return self.core_worker.should_capture_child_tasks_in_placement_group()
    @property
    def current_cluster_and_job(self):
        """Get the current session index and job id as pair."""
        assert isinstance(self.node.cluster_id, ray.ClusterID)
        assert isinstance(self.current_job_id, ray.JobID)
        return self.node.cluster_id, self.current_job_id
    @property
    def current_virtual_cluster_id(self):
        return os.environ.get(ray_constants.RAY_VIRTUAL_CLUSTER_ID_ENV_VAR, "")
    @property
    def runtime_env(self):
        """Get the runtime env in json format"""
        return self.core_worker.get_current_runtime_env()
    @property
    def debugger_port(self):
        """Get the debugger port for this worker"""
        worker_id = self.core_worker.get_worker_id()
        return ray._private.state.get_worker_debugger_port(worker_id)
    @property
    def job_logging_config(self):
        """Get the job's logging config for this worker"""
        if not hasattr(self, "core_worker"):
            return None
        job_config = self.core_worker.get_job_config()
        if not job_config.serialized_py_logging_config:
            return None
        logging_config = pickle.loads(job_config.serialized_py_logging_config)
        return logging_config
    def set_debugger_port(self, port):
        worker_id = self.core_worker.get_worker_id()
        ray._private.state.update_worker_debugger_port(worker_id, port)
    def set_cached_job_id(self, job_id):
        """Set the cached job id to speed `current_job_id()`."""
        self._cached_job_id = job_id
    @contextmanager
    def task_paused_by_debugger(self):
        """Use while the task is paused by debugger"""
        try:
            self.core_worker.update_task_is_debugger_paused(
                ray.get_runtime_context()._get_current_task_id(), True
            )
            yield
        finally:
            self.core_worker.update_task_is_debugger_paused(
                ray.get_runtime_context()._get_current_task_id(), False
            )
    @contextmanager
    def worker_paused_by_debugger(self):
        """
        Update the worker's number of paused threads while the worker is paused by the debugger.
        """
        try:
            worker_id = self.core_worker.get_worker_id()
            ray._private.state.update_worker_num_paused_threads(worker_id, 1)
            yield
        finally:
            ray._private.state.update_worker_num_paused_threads(worker_id, -1)
    def set_err_file(self, err_file: Optional[IO[AnyStr]] = None) -> None:
        """Set the worker's err file where stderr is redirected to"""
        self._err_file = err_file
    def set_out_file(self, out_file: Optional[IO[AnyStr]] = None) -> None:
        """Set the worker's out file where stdout is redirected to"""
        self._out_file = out_file
    def record_task_log_start(self, task_id: TaskID, attempt_number: int):
        """Record the task log info when task starts executing for
        non concurrent actor tasks."""
        if not self._enable_record_actor_task_log and not self.actor_id.is_nil():
            # Don't record actor task logs unless explicitly enabled.
            # Recording actor task logs is expensive and should be enabled
            # only when needed.
            # https://github.com/ray-project/ray/issues/35598
            return
        if not hasattr(self, "core_worker"):
            return
        self.core_worker.record_task_log_start(
            task_id,
            attempt_number,
            self.get_out_file_path(),
            self.get_err_file_path(),
            self.get_current_out_offset(),
            self.get_current_err_offset(),
        )
    def record_task_log_end(self, task_id: TaskID, attempt_number: int):
        """Record the task log info when task finishes executing for
        non concurrent actor tasks."""
        if not self._enable_record_actor_task_log and not self.actor_id.is_nil():
            # Don't record actor task logs unless explicitly enabled.
            # Recording actor task logs is expensive and should be enabled
            # only when needed.
            # https://github.com/ray-project/ray/issues/35598
            return
        if not hasattr(self, "core_worker"):
            return
        self.core_worker.record_task_log_end(
            task_id,
            attempt_number,
            self.get_current_out_offset(),
            self.get_current_err_offset(),
        )
    def get_err_file_path(self) -> str:
        """Get the err log file path"""
        return self._err_file.name if self._err_file is not None else ""
    def get_out_file_path(self) -> str:
        """Get the out log file path"""
        return self._out_file.name if self._out_file is not None else ""
    def get_current_out_offset(self) -> int:
        """Get the current offset of the out file if seekable, else 0"""
        if self._out_file is not None and self._out_file.seekable():
            return self._out_file.tell()
        return 0
    def get_current_err_offset(self) -> int:
        """Get the current offset of the err file if seekable, else 0"""
        if self._err_file is not None and self._err_file.seekable():
            return self._err_file.tell()
        return 0
    def get_serialization_context(self):
        """Get the SerializationContext of the job that this worker is processing.
        Returns:
            The serialization context of the given job.
        """
        # This function needs to be protected by a lock, because it will be
        # called by `register_class_for_serialization`, as well as the import
        # thread, from different threads. Also, this function will recursively
        # call itself, so we use RLock here.
        job_id = self.current_job_id
        context_map = self.serialization_context_map
        with self.lock:
            if job_id not in context_map:
                # The job ID is nil before initializing Ray.
                if JobID.nil() in context_map:
                    # Transfer the serializer context used before initializing Ray.
                    context_map[job_id] = context_map.pop(JobID.nil())
                else:
                    context_map[job_id] = serialization.SerializationContext(self)
            return context_map[job_id]
    def check_connected(self):
        """Check if the worker is connected.
        Raises:
          Exception: An exception is raised if the worker is not connected.
        """
        if not self.connected:
            raise RaySystemError(
                "Ray has not been started yet. You can start Ray with 'ray.init()'."
            )
    def set_mode(self, mode):
        """Set the mode of the worker.
        The mode SCRIPT_MODE should be used if this Worker is a driver that is
        being run as a Python script or interactively in a shell. It will print
        information about task failures.
        The mode WORKER_MODE should be used if this Worker is not a driver. It
        will not print information about tasks.
        The mode LOCAL_MODE should be used if this Worker is a driver and if
        you want to run the driver in a manner equivalent to serial Python for
        debugging purposes. It will not send remote function calls to the
        scheduler and will instead execute them in a blocking fashion.
        Args:
            mode: One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE.
        """
        self.mode = mode
    def set_load_code_from_local(self, load_code_from_local):
        self._load_code_from_local = load_code_from_local
    def put_object(
        self,
        value: Any,
        object_ref: Optional["ray.ObjectRef"] = None,
        owner_address: Optional[str] = None,
        _is_experimental_channel: bool = False,
    ):
        """Put value in the local object store with object reference `object_ref`.
        This assumes that the value for `object_ref` has not yet been placed in
        the local object store. If the plasma store is full, the worker will
        automatically retry up to DEFAULT_PUT_OBJECT_RETRIES times. Each
        retry will delay for an exponentially doubling amount of time,
        starting with DEFAULT_PUT_OBJECT_DELAY. After this, an exception
        will be raised.
        Args:
            value: The value to put in the object store.
            object_ref: The object ref of the value to be
                put. If None, one will be generated.
            owner_address: The serialized address of object's owner.
            _is_experimental_channel: An experimental flag for mutable
                objects. If True, then the returned object will not have a
                valid value. The object must be written to using the
                ray.experimental.channel API before readers can read.
        Returns:
            ObjectRef: The object ref the object was put under.
        Raises:
            ray.exceptions.ObjectStoreFullError: This is raised if the attempt
                to store the object fails because the object store is full even
                after multiple retries.
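        Example (an internal-usage sketch; application code should normally
        call `ray.put` instead):

        .. testcode::
            :skipif: True

            worker = ray._private.worker.global_worker
            ref = worker.put_object("hello")
            assert ray.get(ref) == "hello"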
        """
        # Make sure that the value is not an object ref.
        if isinstance(value, ObjectRef):
            raise TypeError(
                "Calling 'put' on an ray.ObjectRef is not allowed. "
                "If you really want to do this, you can wrap the "
                "ray.ObjectRef in a list and call 'put' on it."
            )
        if self.mode == LOCAL_MODE:
            assert (
                object_ref is None
            ), "Local Mode does not support inserting with an ObjectRef"
        try:
            serialized_value = self.get_serialization_context().serialize(value)
        except TypeError as e:
            sio = io.StringIO()
            ray.util.inspect_serializability(value, print_file=sio)
            msg = (
                "Could not serialize the put value "
                f"{repr(value)}:\n"
                f"{sio.getvalue()}"
            )
            raise TypeError(msg) from e
        # If the object is mutable, then the raylet should never read the
        # object. Instead, clients will keep the object pinned.
        pin_object = not _is_experimental_channel
        # This *must* be the first place that we construct this python
        # ObjectRef because an entry with 0 local references is created when
        # the object is Put() in the core worker, expecting that this python
        # reference will be created. If another reference is created and
        # removed before this one, it will corrupt the state in the
        # reference counter.
        return ray.ObjectRef(
            self.core_worker.put_serialized_object_and_increment_local_ref(
                serialized_value,
                object_ref=object_ref,
                pin_object=pin_object,
                owner_address=owner_address,
                _is_experimental_channel=_is_experimental_channel,
            ),
            # The initial local reference is already acquired internally.
            skip_adding_local_ref=True,
        )
    def raise_errors(self, data_metadata_pairs, object_refs):
        out = self.deserialize_objects(data_metadata_pairs, object_refs)
        if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
            return
        for e in out:
            _unhandled_error_handler(e)
    def deserialize_objects(self, data_metadata_pairs, object_refs):
        # Function actor manager or the import thread may call pickle.loads
        # at the same time which can lead to failed imports
        # TODO: We may be better off locking on all imports or injecting a lock
        # into pickle.loads (https://github.com/ray-project/ray/issues/16304)
        with self.function_actor_manager.lock:
            context = self.get_serialization_context()
            return context.deserialize_objects(data_metadata_pairs, object_refs)
    def get_objects(
        self,
        object_refs: list,
        timeout: Optional[float] = None,
        return_exceptions: bool = False,
        skip_deserialization: bool = False,
    ):
        """Get the values in the object store associated with the IDs.
        Return the values from the local object store for object_refs. This
        will block until all the values for object_refs have been written to
        the local object store.
        Args:
            object_refs: A list of the object refs
                whose values should be retrieved.
            timeout: The maximum amount of time in
                seconds to wait before returning.
            return_exceptions: If any of the objects deserialize to an
                Exception object, whether to return them as values in the
                returned list. If False, then the first found exception will be
                raised.
            skip_deserialization: If true, only the buffer will be released and
                the object associated with the buffer will not be deserialized.
        Returns:
            list: List of deserialized objects or None if skip_deserialization is True.
            bytes: UUID of the debugger breakpoint we should drop
                into or b"" if there is no breakpoint.
        """
        # Make sure that the values are object refs.
        for object_ref in object_refs:
            if not isinstance(object_ref, ObjectRef):
                raise TypeError(
                    f"Attempting to call `get` on the value {object_ref}, "
                    "which is not an ray.ObjectRef."
                )
        timeout_ms = (
            int(timeout * 1000) if timeout is not None and timeout != -1 else -1
        )
        data_metadata_pairs: List[
            Tuple[ray._raylet.Buffer, bytes]
        ] = self.core_worker.get_objects(
            object_refs,
            timeout_ms,
        )
        debugger_breakpoint = b""
        for data, metadata in data_metadata_pairs:
            if metadata:
                metadata_fields = metadata.split(b",")
                if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
                    ray_constants.OBJECT_METADATA_DEBUG_PREFIX
                ):
                    debugger_breakpoint = metadata_fields[1][
                        len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) :
                    ]
        if skip_deserialization:
            return None, debugger_breakpoint
        values = self.deserialize_objects(data_metadata_pairs, object_refs)
        if not return_exceptions:
            # Raise exceptions instead of returning them to the user.
            for i, value in enumerate(values):
                if isinstance(value, RayError):
                    if isinstance(value, ray.exceptions.ObjectLostError):
                        global_worker.core_worker.dump_object_store_memory_usage()
                    if isinstance(value, RayTaskError):
                        raise value.as_instanceof_cause()
                    else:
                        raise value
        return values, debugger_breakpoint
    def main_loop(self):
        """The main loop a worker runs to receive and execute tasks."""
        def sigterm_handler(signum, frame):
            raise_sys_exit_with_custom_error_message(
                "The process receives a SIGTERM.", exit_code=1
            )
            # Note: shutdown() function is called from atexit handler.
        ray._private.utils.set_sigterm_handler(sigterm_handler)
        self.core_worker.run_task_loop()
        sys.exit(0)
    def print_logs(self):
        """Prints log messages from workers on all nodes in the same job."""
        subscriber = self.gcs_log_subscriber
        subscriber.subscribe()
        exception_type = ray.exceptions.RpcError
        localhost = services.get_node_ip_address()
        try:
            # Number of messages received from the last polling. When the batch
            # size exceeds 100 and keeps increasing, the worker and the user
            # probably will not be able to consume the log messages as rapidly
            # as they are coming in.
            # This is meaningful only for GCS subscriber.
            last_polling_batch_size = 0
            job_id_hex = self.current_job_id.hex()
            while True:
                # Exit if we received a signal that we should stop.
                if self.threads_stopped.is_set():
                    return
                data = subscriber.poll()
                # GCS subscriber only returns None on unavailability.
                if data is None:
                    last_polling_batch_size = 0
                    continue
                if (
                    self._filter_logs_by_job
                    and data["job"]
                    and data["job"] != job_id_hex
                ):
                    last_polling_batch_size = 0
                    continue
                data["localhost"] = localhost
                global_worker_stdstream_dispatcher.emit(data)
                lagging = 100 <= last_polling_batch_size < subscriber.last_batch_size
                if lagging:
                    logger.warning(
                        "The driver may not be able to keep up with the "
                        "stdout/stderr of the workers. To avoid forwarding "
                        "logs to the driver, use "
                        "'ray.init(log_to_driver=False)'."
                    )
                last_polling_batch_size = subscriber.last_batch_size
        except (OSError, exception_type) as e:
            logger.error(f"print_logs: {e}")
        finally:
            # Close the pubsub client to avoid leaking file descriptors.
            subscriber.close()
    def get_accelerator_ids_for_accelerator_resource(
        self, resource_name: str, resource_regex: str
    ) -> Union[List[str], List[int]]:
        """Get the accelerator IDs that are assigned to the given accelerator resource.
        Args:
            resource_name: The name of the resource.
            resource_regex: The regex of the resource.
        Returns:
            The IDs assigned to the given resource: strings (the user's
            pre-configured visible IDs) if the corresponding visibility
            environment variable is set, otherwise integer indices.
        """
        resource_ids = self.core_worker.resource_ids()
        assigned_ids = set()
        # Handle both normal and placement group accelerator resources.
        # Note: We should only get the accelerator ids from the placement
        # group resource that does not contain the bundle index!
        import re
        for resource, assignment in resource_ids.items():
            if resource == resource_name or re.match(resource_regex, resource):
                for resource_id, _ in assignment:
                    assigned_ids.add(resource_id)
        # If the user had already set the environment variables
        # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, NEURON_RT_VISIBLE_CORES,
        # TPU_VISIBLE_CHIPS, ..) then respect that in the sense that only IDs
        # that appear in (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR,
        # ROCR_VISIBLE_DEVICES, NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..)
        # should be returned.
        if self.original_visible_accelerator_ids.get(resource_name, None) is not None:
            original_ids = self.original_visible_accelerator_ids[resource_name]
            assigned_ids = {str(original_ids[i]) for i in assigned_ids}
            # Give all accelerator ids in local_mode.
            if self.mode == LOCAL_MODE:
                if resource_name == ray_constants.GPU:
                    max_accelerators = self.node.get_resource_spec().num_gpus
                else:
                    max_accelerators = self.node.get_resource_spec().resources.get(
                        resource_name, None
                    )
                if max_accelerators:
                    assigned_ids = original_ids[:max_accelerators]
        return list(assigned_ids)
@PublicAPI
@client_mode_hook
def get_gpu_ids() -> Union[List[int], List[str]]:
    """Get the IDs of the GPUs that are available to the worker.
    This method should only be called inside of a task or actor, and not a driver.
    If the CUDA_VISIBLE_DEVICES environment variable was set when the worker
    started up, then the IDs returned by this method will be a subset of the
    IDs in CUDA_VISIBLE_DEVICES. If not, the IDs will fall in the range
    [0, NUM_GPUS - 1], where NUM_GPUS is the number of GPUs that the node has.
    Returns:
        A list of GPU IDs.
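    Example (a minimal sketch, assuming at least one GPU is available in the
    cluster):

    .. testcode::
        :skipif: True

        @ray.remote(num_gpus=1)
        def use_gpu():
            return ray.get_gpu_ids()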
    """
    worker = global_worker
    worker.check_connected()
    return worker.get_accelerator_ids_for_accelerator_resource(
        ray_constants.GPU, f"^{ray_constants.GPU}_group_[0-9A-Za-z]+$"
    )
@Deprecated(
    message="Use ray.get_runtime_context().get_assigned_resources() instead.",
    warning=True,
)
def get_resource_ids():
    """Get the IDs of the resources that are available to the worker.
    Returns:
        A dictionary mapping the name of a resource to a list of pairs, where
        each pair consists of the ID of a resource and the fraction of that
        resource reserved for this worker.
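    Example of the returned structure (illustrative values):
        {"CPU": [(0, 1.0)], "GPU": [(0, 0.5)]}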
    """
    worker = global_worker
    worker.check_connected()
    if _mode() == LOCAL_MODE:
        raise RuntimeError(
            "ray._private.worker.get_resource_ids() does not work in local_mode."
        )
    return global_worker.core_worker.resource_ids()
@Deprecated(message="Use ray.init().address_info['webui_url'] instead.")
def get_dashboard_url():
    """Get the URL to access the Ray dashboard.
    Note that the URL does not specify which node the dashboard is on.
    Returns:
        The URL of the dashboard as a string.
    """
    if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ:
        return _remove_protocol_from_url(
            os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL)
        )
    else:
        worker = global_worker
        worker.check_connected()
        return _global_node.webui_url
def _remove_protocol_from_url(url: Optional[str]) -> str:
    """
    Helper function to remove protocol from URL if it exists.
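    For example, "http://127.0.0.1:8265" becomes "127.0.0.1:8265", while a URL
    without a protocol such as "127.0.0.1:8265" is returned unchanged.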
    """
    if not url:
        return url
    parsed_url = urllib.parse.urlparse(url)
    if parsed_url.scheme:
        # Construct URL without protocol
        scheme = f"{parsed_url.scheme}://"
        return parsed_url.geturl().replace(scheme, "", 1)
    return url
class BaseContext(metaclass=ABCMeta):
    """
    Base class for RayContext and ClientContext
    """
    dashboard_url: Optional[str]
    python_version: str
    ray_version: str
    @abstractmethod
    def disconnect(self):
        """
        If this context is for directly attaching to a cluster, disconnect
        will call ray.shutdown(). Otherwise, if the context is for a ray
        client connection, the client will be disconnected.
        """
        pass
    @abstractmethod
    def __enter__(self):
        pass
    @abstractmethod
    def __exit__(self, *exc):
        pass
    def _context_table_template(self):
        if self.dashboard_url:
            dashboard_row = Template("context_dashrow.html.j2").render(
                dashboard_url="http://" + self.dashboard_url
            )
        else:
            dashboard_row = None
        return Template("context_table.html.j2").render(
            python_version=self.python_version,
            ray_version=self.ray_version,
            dashboard_row=dashboard_row,
        )
    def _repr_html_(self):
        return Template("context.html.j2").render(
            context_logo=Template("context_logo.html.j2").render(),
            context_table=self._context_table_template(),
        )
    @repr_with_fallback(["ipywidgets", "8"])
    def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]:
        """Get the mimebundle for the widget representation of the context.
        Args:
            **kwargs: Passed to the _repr_mimebundle_() function for the widget
        Returns:
            Dictionary ("mimebundle") of the widget representation of the context.
        """
        import ipywidgets
        disconnect_button = ipywidgets.Button(
            description="Disconnect",
            disabled=False,
            button_style="",
            tooltip="Disconnect from the Ray cluster",
            layout=ipywidgets.Layout(margin="auto 0px 0px 0px"),
        )
        def disconnect_callback(button):
            button.disabled = True
            button.description = "Disconnecting..."
            self.disconnect()
            button.description = "Disconnected"
        disconnect_button.on_click(disconnect_callback)
        left_content = ipywidgets.VBox(
            [
                ipywidgets.HTML(Template("context_logo.html.j2").render()),
                disconnect_button,
            ],
            layout=ipywidgets.Layout(),
        )
        right_content = ipywidgets.HTML(self._context_table_template())
        widget = ipywidgets.HBox(
            [left_content, right_content], layout=ipywidgets.Layout(width="100%")
        )
        return widget._repr_mimebundle_(**kwargs)
    def _repr_mimebundle_(self, **kwargs):
        bundle = self._get_widget_bundle(**kwargs)
        # Overwrite the widget html repr and default repr with those of the BaseContext
        bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)})
        return bundle
@dataclass
class RayContext(BaseContext, Mapping):
    """
    Context manager for attached drivers.
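    Example (a minimal sketch):

    .. testcode::
        :skipif: True

        with ray.init() as ctx:
            print(ctx.dashboard_url)
        # ray.shutdown() is called automatically when the block exits.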
    """
    dashboard_url: Optional[str]
    python_version: str
    ray_version: str
    ray_commit: str
    def __init__(self, address_info: Dict[str, Optional[str]]):
        super().__init__()
        self.dashboard_url = get_dashboard_url()
        self.python_version = "{}.{}.{}".format(*sys.version_info[:3])
        self.ray_version = ray.__version__
        self.ray_commit = ray.__commit__
        self.address_info = address_info
    def __getitem__(self, key):
        if log_once("ray_context_getitem"):
            warnings.warn(
                f'Accessing values through ctx["{key}"] is deprecated. '
                f'Use ctx.address_info["{key}"] instead.',
                DeprecationWarning,
                stacklevel=2,
            )
        return self.address_info[key]
    def __len__(self):
        if log_once("ray_context_len"):
            warnings.warn("len(ctx) is deprecated. Use len(ctx.address_info) instead.")
        return len(self.address_info)
    def __iter__(self):
        if log_once("ray_context_len"):
            warnings.warn(
                "iter(ctx) is deprecated. Use iter(ctx.address_info) instead."
            )
        return iter(self.address_info)
    def __enter__(self) -> "RayContext":
        return self
    def __exit__(self, *exc):
        ray.shutdown()
    def disconnect(self):
        # Include disconnect() to stay consistent with ClientContext
        ray.shutdown()
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
We use a global Worker object to ensure that there is a single worker object
per worker process.
"""
_global_node = None
"""ray._private.node.Node: The global node object that is created by ray.init()."""
@PublicAPI
@client_mode_hook
def init(
    address: Optional[str] = None,
    *,
    num_cpus: Optional[int] = None,
    num_gpus: Optional[int] = None,
    resources: Optional[Dict[str, float]] = None,
    labels: Optional[Dict[str, str]] = None,
    object_store_memory: Optional[int] = None,
    local_mode: bool = False,
    ignore_reinit_error: bool = False,
    include_dashboard: Optional[bool] = None,
    dashboard_host: str = ray_constants.DEFAULT_DASHBOARD_IP,
    dashboard_port: Optional[int] = None,
    job_config: "ray.job_config.JobConfig" = None,
    configure_logging: bool = True,
    logging_level: int = ray_constants.LOGGER_LEVEL,
    logging_format: Optional[str] = None,
    logging_config: Optional[LoggingConfig] = None,
    log_to_driver: Optional[bool] = None,
    namespace: Optional[str] = None,
    runtime_env: Optional[Union[Dict[str, Any], "RuntimeEnv"]] = None,  # noqa: F821
    storage: Optional[str] = None,
    **kwargs,
) -> BaseContext:
    """
    Connect to an existing Ray cluster or start one and connect to it.
    This method handles two cases: either a Ray cluster already exists and we
    just attach this driver to it, or we start all of the processes associated
    with a Ray cluster and attach to the newly started cluster.
    Note: This method overwrites the SIGTERM handler of the driver process.
    In most cases, it is enough to just call this method with no arguments.
    This will autodetect an existing Ray cluster or start a new Ray instance if
    no existing cluster is found:
    .. testcode::
        ray.init()
    To explicitly connect to an existing local cluster, use this as follows. A
    ConnectionError will be thrown if no existing local cluster is found.
    .. testcode::
        :skipif: True
        ray.init(address="auto")
    To connect to an existing remote cluster, use this as follows (substituting
    in the appropriate address). Note the addition of "ray://" at the beginning
    of the address. This requires `ray[client]`.
    .. testcode::
        :skipif: True
        ray.init(address="ray://123.45.67.89:10001")
    More details for starting and connecting to a remote cluster can be found
    here: https://docs.ray.io/en/master/cluster/getting-started.html
    You can also define an environment variable called `RAY_ADDRESS` in
    the same format as the `address` parameter to connect to an existing
    cluster with ray.init() or ray.init(address="auto").
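    You can also pass other options, for example a namespace and a runtime
    environment for the job (both described under Args below):

    .. testcode::
        :skipif: True

        ray.init(namespace="my_namespace", runtime_env={"pip": ["requests"]})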
    Args:
        address: The address of the Ray cluster to connect to. The provided
            address is resolved as follows:
            1. If a concrete address (e.g., localhost:<port>) is provided, try to
            connect to it. Concrete addresses can be prefixed with "ray://" to
            connect to a remote cluster. For example, passing in the address
            "ray://123.45.67.89:50005" will connect to the cluster at the given
            address.
            2. If no address is provided, try to find an existing Ray instance
            to connect to. This is done by first checking the environment
            variable `RAY_ADDRESS`. If this is not defined, check the address
            of the latest cluster started (found in
            /tmp/ray/ray_current_cluster) if available. If this is also empty,
            then start a new local Ray instance.
            3. If the provided address is "auto", then follow the same process
            as above. However, if there is no existing cluster found, this will
            throw a ConnectionError instead of starting a new local Ray
            instance.
            4. If the provided address is "local", start a new local Ray
            instance, even if there is already an existing local Ray instance.
        num_cpus: Number of CPUs the user wishes to assign to each
            raylet. By default, this is set based on virtual cores.
        num_gpus: Number of GPUs the user wishes to assign to each
            raylet. By default, this is set based on detected GPUs.
        resources: A dictionary mapping the names of custom resources to the
            quantities for them available.
        labels: [Experimental] The key-value labels of the node.
        object_store_memory: The amount of memory (in bytes) to start the
            object store with.
            By default, this is 30% of available system memory
            (ray_constants.DEFAULT_OBJECT_STORE_MEMORY_PROPORTION), capped by
            the shm size and by 200GB
            (ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES), but it can
            be set higher.
        local_mode: Deprecated: consider using the Ray Debugger instead.
        ignore_reinit_error: If true, Ray suppresses errors from calling
            ray.init() a second time. Ray won't be restarted.
        include_dashboard: Boolean flag indicating whether or not to start the
            Ray dashboard, which displays the status of the Ray
            cluster. If this argument is None, then the UI will be started if
            the relevant dependencies are present.
        dashboard_host: The host to bind the dashboard server to. Can either be
            localhost (127.0.0.1) or 0.0.0.0 (available from all interfaces).
            By default, this is set to localhost to prevent access from
            external machines.
        dashboard_port(int, None): The port to bind the dashboard server to.
            Defaults to 8265 and Ray will automatically find a free port if
            8265 is not available.
        job_config (ray.job_config.JobConfig): The job configuration.
        configure_logging: True (default) if configuration of logging is
            allowed here. Otherwise, the user may want to configure it
            separately.
        logging_level: Logging level for the "ray" logger of the driver process,
            defaults to logging.INFO. Ignored unless "configure_logging" is true.
        logging_format: Logging format for the "ray" logger of the driver process,
            defaults to a string containing a timestamp, filename, line number, and
            message. See the source file ray_constants.py for details. Ignored unless
            "configure_logging" is true.
        logging_config: [Experimental] Logging configuration will be applied to the
            root loggers for both the driver process and all worker processes belonging
            to the current job. See :class:`~ray.LoggingConfig` for details.
        log_to_driver: If true, the output from all of the worker
            processes on all nodes will be directed to the driver.
        namespace: A namespace is a logical grouping of jobs and named actors.
        runtime_env: The runtime environment to use
            for this job (see :ref:`runtime-environments` for details).
        storage: [Experimental] Specify a URI for persistent cluster-wide storage.
            This storage path must be accessible by all nodes of the cluster, otherwise
            an error will be raised. This option can also be specified as the
            RAY_STORAGE env var.
        _enable_object_reconstruction: If True, when an object stored in
            the distributed plasma store is lost due to node failure, Ray will
            attempt to reconstruct the object by re-executing the task that
            created the object. Arguments to the task will be recursively
            reconstructed. If False, then ray.ObjectLostError will be
            thrown.
        _redis_max_memory: Redis max memory.
        _plasma_directory: Override the plasma mmap file directory.
        _node_ip_address: The IP address of the node that we are on.
        _driver_object_store_memory: Deprecated.
        _memory: Amount of reservable memory resource in bytes rounded
            down to the nearest integer.
        _redis_username: Prevents external clients without the username
            from connecting to Redis if provided.
        _redis_password: Prevents external clients without the password
            from connecting to Redis if provided.
        _temp_dir: If provided, specifies the root temporary
            directory for the Ray process. Must be an absolute path. Defaults to an
            OS-specific conventional location, e.g., "/tmp/ray".
        _metrics_export_port: Port number Ray exposes system metrics
            through a Prometheus endpoint. It is currently under active
            development, and the API is subject to change.
        _system_config: Configuration for overriding
            RayConfig defaults. For testing purposes ONLY.
        _tracing_startup_hook: If provided, turns on and sets up tracing
            for Ray. Must be the name of a function that takes no arguments and
            sets up a Tracer Provider, Remote Span Processors, and
            (optional) additional instruments. See more at
            docs.ray.io/tracing.html. It is currently under active development,
            and the API is subject to change.
        _node_name: User-provided node name or identifier. Defaults to
            the node IP address.
    Returns:
        If the provided address includes a protocol, for example by prepending
        "ray://" to the address to get "ray://1.2.3.4:10001", then a
        ClientContext is returned with information such as settings, server
        versions for ray and python, and the dashboard_url. Otherwise,
        a RayContext is returned with ray and python versions, and address
        information about the started processes.
    Raises:
        Exception: An exception is raised if an inappropriate combination of
            arguments is passed in.
    """
    if log_to_driver is None:
        log_to_driver = ray_constants.RAY_LOG_TO_DRIVER
    # Configure the "ray" logger for the driver process.
    if configure_logging:
        setup_logger(logging_level, logging_format or ray_constants.LOGGER_FORMAT)
    else:
        logging.getLogger("ray").handlers.clear()
    # Configure the logging settings for the driver process.
    if logging_config or ray_constants.RAY_LOGGING_CONFIG_ENCODING:
        logging_config = logging_config or LoggingConfig(
            encoding=ray_constants.RAY_LOGGING_CONFIG_ENCODING
        )
        logging_config._apply()
    # Parse the hidden options:
    _enable_object_reconstruction: bool = kwargs.pop(
        "_enable_object_reconstruction", False
    )
    _redis_max_memory: Optional[int] = kwargs.pop("_redis_max_memory", None)
    _plasma_directory: Optional[str] = kwargs.pop("_plasma_directory", None)
    _node_ip_address: str = kwargs.pop("_node_ip_address", None)
    _driver_object_store_memory: Optional[int] = kwargs.pop(
        "_driver_object_store_memory", None
    )
    _memory: Optional[int] = kwargs.pop("_memory", None)
    _redis_username: str = kwargs.pop(
        "_redis_username", ray_constants.REDIS_DEFAULT_USERNAME
    )
    _redis_password: str = kwargs.pop(
        "_redis_password", ray_constants.REDIS_DEFAULT_PASSWORD
    )
    _temp_dir: Optional[str] = kwargs.pop("_temp_dir", None)
    _metrics_export_port: Optional[int] = kwargs.pop("_metrics_export_port", None)
    _system_config: Optional[Dict[str, str]] = kwargs.pop("_system_config", None)
    _tracing_startup_hook: Optional[Callable] = kwargs.pop(
        "_tracing_startup_hook", None
    )
    _node_name: str = kwargs.pop("_node_name", None)
    # Fix for https://github.com/ray-project/ray/issues/26729
    _skip_env_hook: bool = kwargs.pop("_skip_env_hook", False)
    # Install a SIGTERM handler before connecting the driver so that a
    # termination signal exits the process cleanly.
    def sigterm_handler(signum, frame):
        sys.exit(signum)
    if threading.current_thread() is threading.main_thread():
        ray._private.utils.set_sigterm_handler(sigterm_handler)
    else:
        logger.warning(
            "SIGTERM handler is not set because current thread "
            "is not the main thread."
        )
    # If available, use RAY_ADDRESS to override if the address was left
    # unspecified, or set to "auto" in the call to init
    address_env_var = os.environ.get(ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE)
    if address_env_var and (address is None or address == "auto"):
        address = address_env_var
        logger.info(
            f"Using address {address_env_var} set in the environment "
            f"variable {ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE}"
        )
    if address is not None and "://" in address:
        # Address specified a protocol, use ray client
        builder = ray.client(address, _deprecation_warn_enabled=False)
        # Forward any keyword arguments that were changed from their default
        # values to the builder
        init_sig = inspect.signature(init)
        passed_kwargs = {}
        for argument_name, param_obj in init_sig.parameters.items():
            if argument_name in {"kwargs", "address"}:
                # kwargs and address are handled separately
                continue
            default_value = param_obj.default
            passed_value = locals()[argument_name]
            if passed_value != default_value:
                # passed value is different than default, pass to the client
                # builder
                passed_kwargs[argument_name] = passed_value
        passed_kwargs.update(kwargs)
        builder._init_args(**passed_kwargs)
        ctx = builder.connect()
        from ray._private.usage import usage_lib
        if passed_kwargs.get("allow_multiple") is True:
            with ctx:
                usage_lib.put_pre_init_usage_stats()
        else:
            usage_lib.put_pre_init_usage_stats()
        usage_lib.record_library_usage("client")
        return ctx
    if kwargs.get("allow_multiple"):
        raise RuntimeError(
            "`allow_multiple` argument is passed to `ray.init` when the "
            "ray client is not used ("
            f"https://docs.ray.io/en/{get_ray_doc_version()}/cluster"
            "/running-applications/job-submission"
            "/ray-client.html#connect-to-multiple-ray-clusters-experimental). "
            "Do not pass the `allow_multiple` to `ray.init` to fix the issue."
        )
    if kwargs:
        # The user passed extra keyword arguments but isn't connecting through
        # Ray Client. Raise an error, since this is most likely a typo in a
        # keyword argument.
        unknown = ", ".join(kwargs)
        raise RuntimeError(f"Unknown keyword argument(s): {unknown}")
    # Try to increase the file descriptor limit, which is too low by
    # default for Ray: https://github.com/ray-project/ray/issues/11239
    try:
        import resource
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        if soft < hard:
            # https://github.com/ray-project/ray/issues/12059
            soft = max(soft, min(hard, 65536))
            logger.debug(
                f"Automatically increasing RLIMIT_NOFILE to max value of {hard}"
            )
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))
            except ValueError:
                logger.debug("Failed to raise limit.")
        soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
        if soft < 4096:
            logger.warning(
                "File descriptor limit {} is too low for production "
                "servers and may result in connection errors. "
                "At least 8192 is recommended. --- "
                "Fix with 'ulimit -n 8192'".format(soft)
            )
    except ImportError:
        logger.debug("Could not import resource module (on Windows)")
    if job_config is None:
        job_config = ray.job_config.JobConfig()
    if RAY_JOB_CONFIG_JSON_ENV_VAR in os.environ:
        injected_job_config_json = json.loads(
            os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR)
        )
        injected_job_config: ray.job_config.JobConfig = (
            ray.job_config.JobConfig.from_json(injected_job_config_json)
        )
        driver_runtime_env = runtime_env
        runtime_env = _merge_runtime_env(
            injected_job_config.runtime_env,
            driver_runtime_env,
            override=os.getenv("RAY_OVERRIDE_JOB_RUNTIME_ENV") == "1",
        )
        if runtime_env is None:
            # None means there was a conflict.
            raise ValueError(
                "Failed to merge the Job's runtime env "
                f"{injected_job_config.runtime_env} with "
                f"a ray.init's runtime env {driver_runtime_env} because "
                "of a conflict. Specifying the same runtime_env fields "
                "or the same environment variable keys is not allowed. "
                "Use RAY_OVERRIDE_JOB_RUNTIME_ENV=1 to instruct Ray to "
                "combine Job and Driver's runtime environment in the event of "
                "a conflict."
            )
        if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
            runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
                runtime_env
            )
        job_config.set_runtime_env(runtime_env)
        # Similarly, we prefer metadata provided via job submission API
        for key, value in injected_job_config.metadata.items():
            job_config.set_metadata(key, value)
    # RAY_JOB_CONFIG_JSON_ENV_VAR is only set at the Ray Job Manager level and
    # takes higher priority in case the user also provided a runtime_env to ray.init().
    else:
        if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
            runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
                runtime_env
            )
        if runtime_env:
            # Set runtime_env in job_config if passed in as part of ray.init()
            job_config.set_runtime_env(runtime_env)
    # Pass the logging_config to job_config to configure loggers of all worker
    # processes belonging to the job.
    if logging_config is not None:
        job_config.set_py_logging_config(logging_config)
    redis_address, gcs_address = None, None
    bootstrap_address = services.canonicalize_bootstrap_address(address, _temp_dir)
    if bootstrap_address is not None:
        gcs_address = bootstrap_address
        logger.info("Connecting to existing Ray cluster at address: %s...", gcs_address)
    if local_mode:
        driver_mode = LOCAL_MODE
        warnings.warn(
            "DeprecationWarning: local mode is an experimental feature that is no "
            "longer maintained and will be removed in the future."
            "For debugging consider using Ray debugger. ",
            DeprecationWarning,
            stacklevel=2,
        )
    else:
        driver_mode = SCRIPT_MODE
    global _global_node
    if global_worker.connected:
        if ignore_reinit_error:
            logger.info("Calling ray.init() again after it has already been called.")
            node_id = global_worker.core_worker.get_current_node_id()
            return RayContext(dict(_global_node.address_info, node_id=node_id.hex()))
        else:
            raise RuntimeError(
                "Maybe you called ray.init twice by accident? "
                "This error can be suppressed by passing in "
                "'ignore_reinit_error=True' or by calling "
                "'ray.shutdown()' prior to 'ray.init()'."
            )
    _system_config = _system_config or {}
    if not isinstance(_system_config, dict):
        raise TypeError("The _system_config must be a dict.")
    if bootstrap_address is None:
        # In this case, we need to start a new cluster.
        # Don't collect usage stats in ray.init() unless it's a nightly wheel.
        from ray._private.usage import usage_lib
        if usage_lib.is_nightly_wheel():
            usage_lib.show_usage_stats_prompt(cli=False)
        else:
            usage_lib.set_usage_stats_enabled_via_env_var(False)
        # Use a random port by not specifying Redis port / GCS server port.
        ray_params = ray._private.parameter.RayParams(
            node_ip_address=_node_ip_address,
            object_ref_seed=None,
            driver_mode=driver_mode,
            redirect_output=None,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            resources=resources,
            labels=labels,
            num_redis_shards=None,
            redis_max_clients=None,
            redis_username=_redis_username,
            redis_password=_redis_password,
            plasma_directory=_plasma_directory,
            huge_pages=None,
            include_dashboard=include_dashboard,
            dashboard_host=dashboard_host,
            dashboard_port=dashboard_port,
            memory=_memory,
            object_store_memory=object_store_memory,
            redis_max_memory=_redis_max_memory,
            plasma_store_socket_name=None,
            temp_dir=_temp_dir,
            storage=storage,
            _system_config=_system_config,
            enable_object_reconstruction=_enable_object_reconstruction,
            metrics_export_port=_metrics_export_port,
            tracing_startup_hook=_tracing_startup_hook,
            node_name=_node_name,
        )
        # Start the Ray processes. We set shutdown_at_exit=False because we
        # shutdown the node in the ray.shutdown call that happens in the atexit
        # handler. We still spawn a reaper process in case the atexit handler
        # isn't called.
        _global_node = ray._private.node.Node(
            ray_params=ray_params,
            head=True,
            shutdown_at_exit=False,
            spawn_reaper=True,
            ray_init_cluster=True,
        )
    else:
        # In this case, we are connecting to an existing cluster.
        if num_cpus is not None or num_gpus is not None:
            raise ValueError(
                "When connecting to an existing cluster, num_cpus "
                "and num_gpus must not be provided."
            )
        if resources is not None:
            raise ValueError(
                "When connecting to an existing cluster, "
                "resources must not be provided."
            )
        if labels is not None:
            raise ValueError(
                "When connecting to an existing cluster, "
                "labels must not be provided."
            )
        if object_store_memory is not None:
            raise ValueError(
                "When connecting to an existing cluster, "
                "object_store_memory must not be provided."
            )
        if storage is not None:
            raise ValueError(
                "When connecting to an existing cluster, "
                "storage must not be provided."
            )
        if _system_config is not None and len(_system_config) != 0:
            raise ValueError(
                "When connecting to an existing cluster, "
                "_system_config must not be provided."
            )
        if _enable_object_reconstruction:
            raise ValueError(
                "When connecting to an existing cluster, "
                "_enable_object_reconstruction must not be provided."
            )
        if _node_name is not None:
            raise ValueError(
                "_node_name cannot be configured when connecting to "
                "an existing cluster."
            )
        # In this case, we only need to connect the node.
        ray_params = ray._private.parameter.RayParams(
            node_ip_address=_node_ip_address,
            gcs_address=gcs_address,
            redis_address=redis_address,
            redis_username=_redis_username,
            redis_password=_redis_password,
            object_ref_seed=None,
            temp_dir=_temp_dir,
            _system_config=_system_config,
            enable_object_reconstruction=_enable_object_reconstruction,
            metrics_export_port=_metrics_export_port,
        )
        try:
            _global_node = ray._private.node.Node(
                ray_params,
                head=False,
                shutdown_at_exit=False,
                spawn_reaper=False,
                connect_only=True,
            )
        except (ConnectionError, RuntimeError):
            if gcs_address == ray._private.utils.read_ray_address(_temp_dir):
                logger.info(
                    "Failed to connect to the default Ray cluster address at "
                    f"{gcs_address}. This is most likely due to a previous Ray "
                    "instance that has since crashed. To reset the default "
                    "address to connect to, run `ray stop` or restart Ray with "
                    "`ray start`."
                )
            raise ConnectionError
    # Log a message to find the Ray address that we connected to and the
    # dashboard URL.
    if ray_constants.RAY_OVERRIDE_DASHBOARD_URL in os.environ:
        dashboard_url = os.environ.get(ray_constants.RAY_OVERRIDE_DASHBOARD_URL)
    else:
        dashboard_url = _global_node.webui_url
    # Add http protocol to dashboard URL if it doesn't
    # already contain a protocol.
    if dashboard_url and not urlparse(dashboard_url).scheme:
        dashboard_url = "http://" + dashboard_url
    # We logged the address before attempting the connection, so we don't need
    # to log it again.
    info_str = "Connected to Ray cluster."
    if gcs_address is None:
        info_str = "Started a local Ray instance."
    if dashboard_url:
        logger.info(
            info_str + " View the dashboard at %s%s%s %s%s",
            colorama.Style.BRIGHT,
            colorama.Fore.GREEN,
            dashboard_url,
            colorama.Fore.RESET,
            colorama.Style.NORMAL,
        )
    else:
        logger.info(info_str)
    connect(
        _global_node,
        _global_node.session_name,
        mode=driver_mode,
        log_to_driver=log_to_driver,
        worker=global_worker,
        driver_object_store_memory=_driver_object_store_memory,
        job_id=None,
        namespace=namespace,
        job_config=job_config,
        entrypoint=ray._private.utils.get_entrypoint_name(),
    )
    if job_config and job_config.code_search_path:
        global_worker.set_load_code_from_local(True)
    else:
        # Because `ray.shutdown()` doesn't reset this flag, for multiple
        # sessions in one process, the 2nd `ray.init()` will reuse the
        # flag of last session. For example:
        #     ray.init(load_code_from_local=True)
        #     ray.shutdown()
        #     ray.init()
        #     # Here the flag `load_code_from_local` would still be True if we
        #     # didn't have this `else` branch.
        #     ray.shutdown()
        global_worker.set_load_code_from_local(False)
    for hook in _post_init_hooks:
        hook()
    node_id = global_worker.core_worker.get_current_node_id()
    global_node_address_info = _global_node.address_info.copy()
    global_node_address_info["webui_url"] = _remove_protocol_from_url(dashboard_url)
    return RayContext(dict(global_node_address_info, node_id=node_id.hex())) 
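# Example (illustrative sketch; the resource count is arbitrary): a driver
# typically uses this entry point as follows:
#
#   import ray
#
#   ray.init(num_cpus=2)              # start a local Ray instance
#   print(ray.cluster_resources())    # inspect what the cluster advertises
#   ray.shutdown()                    # tear it down when done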
# Functions to run as callbacks after a successful ray.init().
_post_init_hooks = []
@PublicAPI
@client_mode_hook
def shutdown(_exiting_interpreter: bool = False):
    """Disconnect the worker, and terminate processes started by ray.init().
    This will automatically run at the end when a Python process that uses Ray
    exits. It is ok to run this twice in a row. The primary use case for this
    function is to cleanup state between tests.
    Note that this will clear any remote function definitions, actor
    definitions, and existing actors, so if you wish to use any previously
    defined remote functions or actors after calling ray.shutdown(), then you
    need to redefine them. If they were defined in an imported module, then you
    will need to reload the module.
    Args:
        _exiting_interpreter: True if this is called by the atexit hook
            and false otherwise. If we are exiting the interpreter, we will
            wait a little while to print any extra error messages.
    """
    # Make sure to clean up compiled dag node if exists.
    from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags
    _shutdown_all_compiled_dags()
    if _exiting_interpreter and global_worker.mode == SCRIPT_MODE:
        # This is a duration to sleep before shutting down everything in order
        # to make sure that log messages finish printing.
        time.sleep(0.5)
    disconnect(_exiting_interpreter)
    # disconnect internal kv
    if hasattr(global_worker, "gcs_client"):
        del global_worker.gcs_client
    _internal_kv_reset()
    # We need to destruct the core worker here because after this function,
    # we will tear down any processes spawned by ray.init() and the background
    # IO thread in the core worker doesn't currently handle that gracefully.
    if hasattr(global_worker, "core_worker"):
        if global_worker.mode == SCRIPT_MODE or global_worker.mode == LOCAL_MODE:
            global_worker.core_worker.shutdown_driver()
        del global_worker.core_worker
    # We need to reset function actor manager to clear the context
    global_worker.function_actor_manager = FunctionActorManager(global_worker)
    # Disconnect global state from GCS.
    ray._private.state.state.disconnect()
    # Shut down the Ray processes.
    global _global_node
    if _global_node is not None:
        if _global_node.is_head():
            _global_node.destroy_external_storage()
        _global_node.kill_all_processes(check_alive=False, allow_graceful=True)
        _global_node = None
    storage._reset()
    # TODO(rkn): Instead of manually resetting some of the worker fields, we
    # should simply set "global_worker" to equal "None" or something like that.
    global_worker.set_mode(None)
    global_worker.set_cached_job_id(None) 
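# Example (illustrative sketch of the "cleanup between tests" use case from the
# docstring above; the test body is hypothetical):
#
#   import ray
#
#   def test_something():
#       ray.init()
#       ...              # run code that uses Ray
#       ray.shutdown()   # reset state so the next test starts from scratch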
atexit.register(shutdown, True)
# Define a custom excepthook so that if the driver exits with an exception, we
# can push that exception to Redis.
normal_excepthook = sys.excepthook
def custom_excepthook(type, value, tb):
    import ray.core.generated.common_pb2 as common_pb2
    # If this is a driver, push the exception to GCS worker table.
    if global_worker.mode == SCRIPT_MODE and hasattr(global_worker, "worker_id"):
        error_message = "".join(traceback.format_tb(tb))
        worker_id = global_worker.worker_id
        worker_type = common_pb2.DRIVER
        worker_info = {"exception": error_message}
        ray._private.state.state._check_connected()
        ray._private.state.state.add_worker(worker_id, worker_type, worker_info)
    # Call the normal excepthook.
    normal_excepthook(type, value, tb)
sys.excepthook = custom_excepthook
def print_to_stdstream(data, ignore_prefix: bool):
    should_dedup = data.get("pid") not in ["autoscaler"]
    if data["is_err"]:
        if should_dedup:
            batches = stderr_deduplicator.deduplicate(data)
        else:
            batches = [data]
        sink = sys.stderr
    else:
        if should_dedup:
            batches = stdout_deduplicator.deduplicate(data)
        else:
            batches = [data]
        sink = sys.stdout
    for batch in batches:
        print_worker_logs(batch, sink, ignore_prefix)
# Start time of this process, used for relative time logs.
t0 = time.time()
autoscaler_log_fyi_printed = False
def filter_autoscaler_events(lines: List[str]) -> Iterator[str]:
    """Given raw log lines from the monitor, return only autoscaler events.
    For Autoscaler V1:
        Autoscaler events are denoted by the ":event_summary:" magic token.
    For Autoscaler V2:
        Autoscaler events are published by log_monitor.py, which reads
        them from `event_AUTOSCALER.log`.
    """
    if not ray_constants.AUTOSCALER_EVENTS:
        return
    AUTOSCALER_LOG_FYI = (
        "Tip: use `ray status` to view detailed "
        "cluster status. To disable these "
        "messages, set RAY_SCHEDULER_EVENTS=0."
    )
    def autoscaler_log_fyi_needed() -> bool:
        global autoscaler_log_fyi_printed
        if not autoscaler_log_fyi_printed:
            autoscaler_log_fyi_printed = True
            return True
        return False
    from ray.autoscaler.v2.utils import is_autoscaler_v2
    if is_autoscaler_v2():
        from ray._private.event.event_logger import parse_event, filter_event_by_level
        for event_line in lines:
            if autoscaler_log_fyi_needed():
                yield AUTOSCALER_LOG_FYI
            event = parse_event(event_line)
            if not event or not event.message:
                continue
            if filter_event_by_level(
                event, ray_constants.RAY_LOG_TO_DRIVER_EVENT_LEVEL
            ):
                continue
            yield event.message
    else:
        # Print out autoscaler events only, ignoring other messages.
        for line in lines:
            if ray_constants.LOG_PREFIX_EVENT_SUMMARY in line:
                if autoscaler_log_fyi_needed():
                    yield AUTOSCALER_LOG_FYI
                # The event text immediately follows the ":event_summary:"
                # magic token.
                yield line.split(ray_constants.LOG_PREFIX_EVENT_SUMMARY)[1]
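# Example (illustrative sketch of the V1 branch above; the log line and message
# are made up, and ":event_summary:" is the magic token named in the docstring):
#
#   line = "... :event_summary:Resized to 5 CPUs."
#   line.split(":event_summary:")[1]   # -> "Resized to 5 CPUs."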
def time_string() -> str:
    """Return the relative time from the start of this job.
    For example, 15m30s.
    """
    delta = time.time() - t0
    hours = 0
    minutes = 0
    while delta > 3600:
        hours += 1
        delta -= 3600
    while delta > 60:
        minutes += 1
        delta -= 60
    output = ""
    if hours:
        output += f"{hours}h"
    if minutes:
        output += f"{minutes}m"
    output += f"{int(delta)}s"
    return output
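# Worked example of the format above: a delta of 3725.4 seconds yields "1h2m5s",
# and a delta of 42.9 seconds yields "42s".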
# When we enter a breakpoint, worker logs are automatically disabled via this flag.
_worker_logs_enabled = True
def print_worker_logs(
    data: Dict[str, str], print_file: Any, ignore_prefix: bool = False
):
    if not _worker_logs_enabled:
        return
    def prefix_for(data: Dict[str, str]) -> str:
        """The PID prefix for this log line."""
        if data.get("pid") in ["autoscaler", "raylet"]:
            return ""
        else:
            res = "pid="
            if data.get("actor_name"):
                res = f"{data['actor_name']} {res}"
            elif data.get("task_name"):
                res = f"{data['task_name']} {res}"
            return res
    def message_for(data: Dict[str, str], line: str) -> str:
        """The printed message of this log line."""
        if ray_constants.LOG_PREFIX_INFO_MESSAGE in line:
            return line.split(ray_constants.LOG_PREFIX_INFO_MESSAGE)[1]
        return line
    def color_for(data: Dict[str, str], line: str) -> str:
        """The color for this log line."""
        if (
            data.get("pid") == "raylet"
            and ray_constants.LOG_PREFIX_INFO_MESSAGE not in line
        ):
            return colorama.Fore.YELLOW
        elif data.get("pid") == "autoscaler":
            if "Error:" in line or "Warning:" in line:
                return colorama.Fore.YELLOW
            else:
                return colorama.Fore.CYAN
        elif os.getenv("RAY_COLOR_PREFIX") == "1":
            colors = [
                # colorama.Fore.BLUE, # Too dark
                colorama.Fore.MAGENTA,
                colorama.Fore.CYAN,
                colorama.Fore.GREEN,
                # colorama.Fore.WHITE, # Too light
                # colorama.Fore.RED,
                colorama.Fore.LIGHTBLACK_EX,
                colorama.Fore.LIGHTBLUE_EX,
                # colorama.Fore.LIGHTCYAN_EX, # Too light
                # colorama.Fore.LIGHTGREEN_EX, # Too light
                colorama.Fore.LIGHTMAGENTA_EX,
                # colorama.Fore.LIGHTWHITE_EX, # Too light
                # colorama.Fore.LIGHTYELLOW_EX, # Too light
            ]
            pid = data.get("pid", 0)
            try:
                i = int(pid)
            except ValueError:
                i = 0
            return colors[i % len(colors)]
        else:
            return colorama.Fore.CYAN
    if data.get("pid") == "autoscaler":
        pid = "autoscaler +{}".format(time_string())
        lines = filter_autoscaler_events(data.get("lines", []))
    else:
        pid = data.get("pid")
        lines = data.get("lines", [])
    ip = data.get("ip")
    ip_prefix = "" if ip == data.get("localhost") else f", ip={ip}"
    for line in lines:
        if RAY_TQDM_MAGIC in line:
            process_tqdm(line)
        else:
            hide_tqdm()
            # If RAY_COLOR_PREFIX=0, do not wrap with any color codes
            if os.getenv("RAY_COLOR_PREFIX") == "0":
                color_pre = ""
                color_post = ""
            else:
                color_pre = color_for(data, line)
                color_post = colorama.Style.RESET_ALL
            if ignore_prefix:
                print(
                    f"{message_for(data, line)}",
                    file=print_file,
                )
            else:
                print(
                    f"{color_pre}({prefix_for(data)}{pid}{ip_prefix}){color_post} "
                    f"{message_for(data, line)}",
                    file=print_file,
                )
    # Restore once at end of batch to avoid excess hiding/unhiding of tqdm.
    restore_tqdm()
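# Example (illustrative sketch of the resulting line format; the PID, IP, and
# messages are made up):
#
#   (pid=12345, ip=10.0.0.2) hello from a task
#   (MyActor pid=12345, ip=10.0.0.2) hello from an actor method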
def process_tqdm(line):
    """Experimental distributed tqdm: see ray.experimental.tqdm_ray."""
    try:
        data = json.loads(line)
        tqdm_ray.instance().process_state_update(data)
    except Exception:
        if log_once("tqdm_corruption"):
            logger.warning(
                f"[tqdm_ray] Failed to decode {line}, this may be due to "
                "logging too fast. This warning will not be printed again."
            )
def hide_tqdm():
    """Hide distributed tqdm bars temporarily to avoid conflicts with other logs."""
    tqdm_ray.instance().hide_bars()
def restore_tqdm():
    """Undo hide_tqdm()."""
    tqdm_ray.instance().unhide_bars()
def listen_error_messages(worker, threads_stopped):
    """Listen to error messages in the background on the driver.
    This runs in a separate thread on the driver and pushes (error, time)
    tuples to be published.
    Args:
        worker: The worker class that this thread belongs to.
        threads_stopped (threading.Event): A threading event used to signal to
            the thread that it should exit.
    """
    # TODO: we should just subscribe to the errors for this specific job.
    worker.gcs_error_subscriber.subscribe()
    try:
        if _internal_kv_initialized():
            # Get any autoscaler errors that occurred before the call to
            # subscribe.
            error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR)
            if error_message is not None:
                logger.warning(error_message.decode())
        while True:
            # Exit if received a signal that the thread should stop.
            if threads_stopped.is_set():
                return
            _, error_data = worker.gcs_error_subscriber.poll()
            if error_data is None:
                continue
            if error_data["job_id"] not in [
                worker.current_job_id.binary(),
                JobID.nil().binary(),
            ]:
                continue
            error_message = error_data["error_message"]
            print_to_stdstream(
                {
                    "lines": [error_message],
                    "pid": "raylet",
                    "is_err": False,
                },
                ignore_prefix=False,
            )
    except (OSError, ConnectionError) as e:
        logger.error(f"listen_error_messages: {e}")
@PublicAPI
@client_mode_hook
def is_initialized() -> bool:
    """Check if ray.init has been called yet.
    Returns:
        True if ray.init has already been called and false otherwise.
    """
    return ray._private.worker.global_worker.connected 
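# Example (illustrative usage sketch):
#
#   import ray
#
#   assert not ray.is_initialized()
#   ray.init()
#   assert ray.is_initialized()
#   ray.shutdown()
#   assert not ray.is_initialized()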
def connect(
    node,
    session_name: str,
    mode=WORKER_MODE,
    log_to_driver: bool = False,
    worker=global_worker,
    driver_object_store_memory: Optional[int] = None,
    job_id=None,
    namespace: Optional[str] = None,
    job_config=None,
    runtime_env_hash: int = 0,
    startup_token: int = 0,
    ray_debugger_external: bool = False,
    entrypoint: str = "",
    worker_launch_time_ms: int = -1,
    worker_launched_time_ms: int = -1,
):
    """Connect this worker to the raylet, to Plasma, and to GCS.
    Args:
        node (ray._private.node.Node): The node to connect.
        session_name: The session name (cluster id) of this cluster.
        mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE.
        log_to_driver: If true, then output from all of the worker
            processes on all nodes will be directed to the driver.
        worker: The ray.Worker instance.
        driver_object_store_memory: Deprecated.
        job_id: The ID of the job. If it is None, one will be generated.
        namespace: Namespace to use.
        job_config (ray.job_config.JobConfig): The job configuration.
        runtime_env_hash: The hash of the runtime env for this worker.
        startup_token: The startup token assigned to this worker process
            during startup, passed as a command line argument.
        ray_debugger_external: If True, make the debugger external to the
            node this worker is running on.
        entrypoint: The name of the entrypoint script. Ignored unless
            mode == SCRIPT_MODE.
        worker_launch_time_ms: The time when the worker process for this worker
            is launched. If the worker is not launched by raylet (e.g.,
            driver), this must be -1 (default value).
        worker_launched_time_ms: The time when the worker process for this worker
            finishes launching. If the worker is not launched by raylet (e.g.,
            driver), this must be -1 (default value).
    """
    # Do some basic checking to make sure we didn't call ray.init twice.
    error_message = "Perhaps you called ray.init twice by accident?"
    assert not worker.connected, error_message
    # Enable nice stack traces on SIGSEGV etc.
    try:
        if not faulthandler.is_enabled():
            faulthandler.enable(all_threads=False)
    except io.UnsupportedOperation:
        pass  # ignore
    worker.gcs_client = node.get_gcs_client()
    assert worker.gcs_client is not None
    _initialize_internal_kv(worker.gcs_client)
    ray._private.state.state._initialize_global_state(
        ray._raylet.GcsClientOptions.create(
            node.gcs_address,
            node.cluster_id.hex(),
            allow_cluster_id_nil=False,
            fetch_cluster_id_if_nil=False,
        )
    )
    worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address)
    # Initialize some fields.
    if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
        # We should not specify the job_id if it's `WORKER_MODE`.
        assert job_id is None
        job_id = JobID.nil()
    else:
        # This is the code path of driver mode.
        if job_id is None:
            job_id = ray._private.state.next_job_id()
    if mode is not SCRIPT_MODE and mode is not LOCAL_MODE and setproctitle:
        process_name = ray_constants.WORKER_PROCESS_TYPE_IDLE_WORKER
        if mode is SPILL_WORKER_MODE:
            process_name = ray_constants.WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE
        elif mode is RESTORE_WORKER_MODE:
            process_name = ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE
        setproctitle.setproctitle(process_name)
    if not isinstance(job_id, JobID):
        raise TypeError("The type of given job id must be JobID.")
    # All workers start out as non-actors. A worker can be turned into an actor
    # after it is created.
    worker.node = node
    worker.set_mode(mode)
    # For drivers, check that the version information matches the version
    # information that the Ray cluster was started with.
    try:
        node.check_version_info()
    except Exception as e:
        if mode == SCRIPT_MODE:
            raise e
        elif mode == WORKER_MODE:
            traceback_str = traceback.format_exc()
            ray._private.utils.publish_error_to_driver(
                ray_constants.VERSION_MISMATCH_PUSH_ERROR,
                traceback_str,
                gcs_publisher=worker.gcs_publisher,
                num_retries=1,
            )
    driver_name = ""
    log_stdout_file_path = ""
    log_stderr_file_path = ""
    interactive_mode = False
    if mode == SCRIPT_MODE:
        import __main__ as main
        if hasattr(main, "__file__"):
            driver_name = main.__file__
        else:
            interactive_mode = True
            driver_name = "INTERACTIVE MODE"
    elif not LOCAL_MODE:
        raise ValueError("Invalid worker mode. Expected DRIVER, WORKER or LOCAL.")
    gcs_options = ray._raylet.GcsClientOptions.create(
        node.gcs_address,
        node.cluster_id.hex(),
        allow_cluster_id_nil=False,
        fetch_cluster_id_if_nil=False,
    )
    if job_config is None:
        job_config = ray.job_config.JobConfig()
    if namespace is not None:
        ray._private.utils.validate_namespace(namespace)
        # The namespace field of job config may have already been set in code
        # paths such as the client.
        job_config.set_ray_namespace(namespace)
    # Make sure breakpoint() in the user's code will
    # invoke the Ray debugger if we are in a worker or actor process
    # (but not on the driver).
    if mode == WORKER_MODE:
        os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb.set_trace"
    else:
        # Add hook to suppress worker logs during breakpoint.
        os.environ["PYTHONBREAKPOINT"] = "ray.util.rpdb._driver_set_trace"
    worker.ray_debugger_external = ray_debugger_external
    # If it's a driver and it's not coming from ray client, we'll prepare the
    # environment here. If it's ray client, the environment will be prepared
    # at the server side.
    if mode == SCRIPT_MODE and not job_config._client_job and job_config.runtime_env:
        scratch_dir: str = worker.node.get_runtime_env_dir_path()
        runtime_env = job_config.runtime_env or {}
        runtime_env = upload_py_modules_if_needed(
            runtime_env, scratch_dir, logger=logger
        )
        runtime_env = upload_working_dir_if_needed(
            runtime_env, scratch_dir, logger=logger
        )
        runtime_env = upload_worker_process_setup_hook_if_needed(
            runtime_env,
            worker,
        )
        # Remove excludes, it isn't relevant after the upload step.
        runtime_env.pop("excludes", None)
        job_config.set_runtime_env(runtime_env)
    if mode == SCRIPT_MODE:
        # Add the directory containing the script that is running to the Python
        # paths of the workers. Also add the current directory. Note that this
        # assumes that the directory structures on the machines in the clusters
        # are the same.
        # When using an interactive shell, there is no script directory.
        # We also want to skip adding script directory when running from dashboard.
        code_paths = []
        if not interactive_mode and not (
            namespace and namespace == ray_constants.RAY_INTERNAL_DASHBOARD_NAMESPACE
        ):
            script_directory = os.path.dirname(os.path.realpath(sys.argv[0]))
            # If driver's sys.path doesn't include the script directory
            # (e.g., the driver is started via `python -m`,
            # see https://peps.python.org/pep-0338/),
            # then we shouldn't add it to the workers.
            if script_directory in sys.path:
                code_paths.append(script_directory)
        # In client mode, if we use runtime envs with "working_dir", then
        # it'll be handled automatically.  Otherwise, add the current dir.
        if not job_config._client_job and not job_config._runtime_env_has_working_dir():
            current_directory = os.path.abspath(os.path.curdir)
            code_paths.append(current_directory)
        if len(code_paths) != 0:
            job_config._py_driver_sys_path.extend(code_paths)
    serialized_job_config = job_config._serialize()
    if not node.should_redirect_logs():
        # Logging to stderr, so give core worker empty logs directory.
        logs_dir = ""
    else:
        logs_dir = node.get_logs_dir_path()
    worker.core_worker = ray._raylet.CoreWorker(
        mode,
        node.plasma_store_socket_name,
        node.raylet_socket_name,
        job_id,
        gcs_options,
        logs_dir,
        node.node_ip_address,
        node.node_manager_port,
        node.raylet_ip_address,
        (mode == LOCAL_MODE),
        driver_name,
        log_stdout_file_path,
        log_stderr_file_path,
        serialized_job_config,
        node.metrics_agent_port,
        runtime_env_hash,
        startup_token,
        session_name,
        node.cluster_id.hex(),
        "" if mode != SCRIPT_MODE else entrypoint,
        worker_launch_time_ms,
        worker_launched_time_ms,
    )
    if mode == SCRIPT_MODE:
        worker_id = worker.worker_id
        worker.gcs_error_subscriber = ray._raylet.GcsErrorSubscriber(
            worker_id=worker_id, address=worker.gcs_client.address
        )
        worker.gcs_log_subscriber = ray._raylet.GcsLogSubscriber(
            worker_id=worker_id, address=worker.gcs_client.address
        )
    if driver_object_store_memory is not None:
        logger.warning(
            "`driver_object_store_memory` is deprecated"
            " and will be removed in the future."
        )
    # If this is a driver running in SCRIPT_MODE, start a thread to print error
    # messages asynchronously in the background. Ideally the scheduler would
    # push messages to the driver's worker service, but we ran into bugs when
    # trying to properly shutdown the driver's worker service, so we are
    # temporarily using this implementation which constantly queries the
    # scheduler for new error messages.
    if mode == SCRIPT_MODE:
        worker.listener_thread = threading.Thread(
            target=listen_error_messages,
            name="ray_listen_error_messages",
            args=(worker, worker.threads_stopped),
        )
        worker.listener_thread.daemon = True
        worker.listener_thread.start()
        # If the job's logging config is set, don't add the prefix
        # (task/actor's name and its PID) to the logs.
        ignore_prefix = global_worker.job_logging_config is not None
        if log_to_driver:
            global_worker_stdstream_dispatcher.add_handler(
                "ray_print_logs",
                functools.partial(print_to_stdstream, ignore_prefix=ignore_prefix),
            )
            worker.logger_thread = threading.Thread(
                target=worker.print_logs, name="ray_print_logs"
            )
            worker.logger_thread.daemon = True
            worker.logger_thread.start()
    # Setup tracing here
    tracing_hook_val = worker.gcs_client.internal_kv_get(
        b"tracing_startup_hook", ray_constants.KV_NAMESPACE_TRACING
    )
    if tracing_hook_val is not None:
        ray.util.tracing.tracing_helper._enable_tracing()
        if not getattr(ray, "__traced__", False):
            _setup_tracing = _import_from_string(tracing_hook_val.decode("utf-8"))
            _setup_tracing()
            ray.__traced__ = True
    # Mark the worker as connected.
    worker.set_is_connected(True)
def disconnect(exiting_interpreter=False):
    """Disconnect this worker from the raylet and object store."""
    # Reset the list of cached remote functions and actors so that if more
    # remote functions or actors are defined and then connect is called again,
    # the remote functions will be exported. This is mostly relevant for the
    # tests.
    worker = global_worker
    if worker.connected:
        # Shutdown all of the threads that we've started. TODO(rkn): This
        # should be handled cleanly in the worker object's destructor and not
        # in this disconnect method.
        worker.threads_stopped.set()
        if hasattr(worker, "gcs_error_subscriber"):
            worker.gcs_error_subscriber.close()
        if hasattr(worker, "gcs_log_subscriber"):
            worker.gcs_log_subscriber.close()
        if hasattr(worker, "listener_thread"):
            worker.listener_thread.join()
        if hasattr(worker, "logger_thread"):
            worker.logger_thread.join()
        worker.threads_stopped.clear()
        # Ignore the prefix if the logging config is set.
        ignore_prefix = worker.job_logging_config is not None
        for leftover in stdout_deduplicator.flush():
            print_worker_logs(leftover, sys.stdout, ignore_prefix)
        for leftover in stderr_deduplicator.flush():
            print_worker_logs(leftover, sys.stderr, ignore_prefix)
        global_worker_stdstream_dispatcher.remove_handler("ray_print_logs")
    worker.node = None  # Disconnect the worker from the node.
    worker.serialization_context_map.clear()
    try:
        ray_actor = ray.actor
    except AttributeError:
        ray_actor = None  # This can occur during program termination
    if ray_actor is not None:
        ray_actor._ActorClassMethodMetadata.reset_cache()
    # Mark the worker as disconnected.
    worker.set_is_connected(False)
@contextmanager
def _changeproctitle(title, next_title):
    if _mode() is not LOCAL_MODE:
        setproctitle.setproctitle(title)
    try:
        yield
    finally:
        if _mode() is not LOCAL_MODE:
            setproctitle.setproctitle(next_title)
@DeveloperAPI
def show_in_dashboard(message: str, key: str = "", dtype: str = "text"):
    """Display message in dashboard.
    Display message for the current task or actor in the dashboard.
    For example, this can be used to display the status of a long-running
    computation.
    Args:
        message: Message to be displayed.
        key: The key name for the message. Multiple messages under
            different keys will be displayed at the same time. Messages
            under the same key will be overwritten.
        dtype: The type of message for rendering. One of the
            following: text, html.
    """
    worker = global_worker
    worker.check_connected()
    acceptable_dtypes = {"text", "html"}
    assert dtype in acceptable_dtypes, f"dtype accepts only: {acceptable_dtypes}"
    message_wrapped = {"message": message, "dtype": dtype}
    message_encoded = json.dumps(message_wrapped).encode()
    worker.core_worker.set_webui_display(key.encode(), message_encoded)
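# Example (illustrative sketch; the remote function is hypothetical, and it is
# assumed that show_in_dashboard is importable in the worker's environment):
#
#   @ray.remote
#   def long_running_job():
#       show_in_dashboard("step 1/3 done", key="progress", dtype="text")
#       ...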
# Global variable to make sure we only send out the warning once.
blocking_get_inside_async_warned = False
@overload
def get(
    object_refs: "Sequence[ObjectRef[Any]]", *, timeout: Optional[float] = None
) -> List[Any]:
    ...
@overload
def get(
    object_refs: "Sequence[ObjectRef[R]]", *, timeout: Optional[float] = None
) -> List[R]:
    ...
@overload
def get(object_refs: "ObjectRef[R]", *, timeout: Optional[float] = None) -> R:
    ...
@overload
def get(
    object_refs: Sequence[CompiledDAGRef], *, timeout: Optional[float] = None
) -> List[Any]:
    ...
@overload
def get(object_refs: CompiledDAGRef, *, timeout: Optional[float] = None) -> Any:
    ...
@PublicAPI
@client_mode_hook
def get(
    object_refs: Union[
        "ObjectRef[Any]",
        Sequence["ObjectRef[Any]"],
        CompiledDAGRef,
        Sequence[CompiledDAGRef],
    ],
    *,
    timeout: Optional[float] = None,
) -> Union[Any, List[Any]]:
    """Get a remote object or a list of remote objects from the object store.
    This method blocks until the object corresponding to the object ref is
    available in the local object store. If this object is not in the local
    object store, it will be shipped from an object store that has it (once the
    object has been created). If object_refs is a list, then the objects
    corresponding to each object in the list will be returned.
    Ordering for an input list of object refs is preserved for each object
    returned. That is, if an object ref to A precedes an object ref to B in the
    input list, then A will precede B in the returned list.
    This method will issue a warning if it's running inside an async context;
    use ``await object_ref`` instead of ``ray.get(object_ref)``. For
    a list of object refs, you can use ``await asyncio.gather(*object_refs)``.
    Passing :class:`~ObjectRefGenerator` is not allowed.
    Related patterns and anti-patterns:
    - :doc:`/ray-core/patterns/ray-get-loop`
    - :doc:`/ray-core/patterns/unnecessary-ray-get`
    - :doc:`/ray-core/patterns/ray-get-submission-order`
    - :doc:`/ray-core/patterns/ray-get-too-many-objects`
    Args:
        object_refs: Object ref of the object to get or a list of object refs
            to get.
        timeout (Optional[float]): The maximum amount of time in seconds to
            wait before returning. Setting this to None will block until the
            corresponding object becomes available. Setting ``timeout=0`` will
            return the object immediately if it's available, and raise
            GetTimeoutError otherwise.
    Returns:
        A Python object or a list of Python objects.
    Raises:
        GetTimeoutError: A GetTimeoutError is raised if a timeout is set and
            the get takes longer than timeout to return.
        Exception: An exception is raised immediately if any task that created
            the object or that created one of the objects raised an exception,
            without waiting for the remaining ones to finish.
    """
    worker = global_worker
    worker.check_connected()
    if hasattr(worker, "core_worker") and worker.core_worker.current_actor_is_asyncio():
        global blocking_get_inside_async_warned
        if not blocking_get_inside_async_warned:
            logger.warning(
                "Using blocking ray.get inside async actor. "
                "This blocks the event loop. Please use `await` "
                "on object ref with asyncio.gather if you want to "
                "yield execution to the event loop instead."
            )
            blocking_get_inside_async_warned = True
    with profiling.profile("ray.get"):
        # TODO(sang): Should make ObjectRefGenerator
        # compatible to ray.get for dataset.
        if isinstance(object_refs, ObjectRefGenerator):
            return object_refs
        if isinstance(object_refs, CompiledDAGRef):
            return object_refs.get(timeout=timeout)
        if isinstance(object_refs, list):
            all_compiled_dag_refs = True
            any_compiled_dag_refs = False
            for object_ref in object_refs:
                is_dag_ref = isinstance(object_ref, CompiledDAGRef)
                all_compiled_dag_refs = all_compiled_dag_refs and is_dag_ref
                any_compiled_dag_refs = any_compiled_dag_refs or is_dag_ref
            if all_compiled_dag_refs:
                return [object_ref.get(timeout=timeout) for object_ref in object_refs]
            elif any_compiled_dag_refs:
                raise ValueError(
                    "Invalid type of object refs. 'object_refs' must be a list of "
                    "CompiledDAGRefs if there is any CompiledDAGRef within it. "
                )
        is_individual_id = isinstance(object_refs, ray.ObjectRef)
        if is_individual_id:
            object_refs = [object_refs]
        if not isinstance(object_refs, list):
            raise ValueError(
                f"Invalid type of object refs, {type(object_refs)}, is given. "
                "'object_refs' must either be an ObjectRef or a list of ObjectRefs. "
            )
        # TODO(ujvl): Consider how to allow user to retrieve the ready objects.
        values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
        for i, value in enumerate(values):
            if isinstance(value, RayError):
                if isinstance(value, ray.exceptions.ObjectLostError):
                    worker.core_worker.dump_object_store_memory_usage()
                if isinstance(value, RayTaskError):
                    raise value.as_instanceof_cause()
                else:
                    raise value
        if is_individual_id:
            values = values[0]
        if debugger_breakpoint != b"":
            frame = sys._getframe().f_back
            rdb = ray.util.pdb._connect_ray_pdb(
                host=None,
                port=None,
                patch_stdstreams=False,
                quiet=None,
                breakpoint_uuid=(
                    debugger_breakpoint.decode() if debugger_breakpoint else None
                ),
                debugger_external=worker.ray_debugger_external,
            )
            rdb.set_trace(frame=frame)
        return values 
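# Example (illustrative usage sketch; the remote function is hypothetical):
#
#   import ray
#
#   @ray.remote
#   def square(x):
#       return x * x
#
#   assert ray.get(square.remote(4)) == 16
#   assert ray.get([square.remote(i) for i in range(3)]) == [0, 1, 4]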
@PublicAPI
@client_mode_hook
def put(
    value: Any,
    *,
    _owner: Optional["ray.actor.ActorHandle"] = None,
) -> "ray.ObjectRef":
    """Store an object in the object store.
    The object may not be evicted while a reference to the returned ID exists.
    Related patterns and anti-patterns:
    - :doc:`/ray-core/patterns/return-ray-put`
    - :doc:`/ray-core/patterns/pass-large-arg-by-value`
    - :doc:`/ray-core/patterns/closure-capture-large-objects`
    Args:
        value: The Python object to be stored.
        _owner [Experimental]: The actor that should own this object. This
            allows creating objects with lifetimes decoupled from that of the
            creating process. The owner actor must be passed a reference to the
            object prior to the object creator exiting, otherwise the reference
            will still be lost. *Note that this argument is an experimental API
            and should be avoided if possible.*
    Returns:
        The object ref assigned to this value.
    """
    worker = global_worker
    worker.check_connected()
    if _owner is None:
        serialize_owner_address = None
    elif isinstance(_owner, ray.actor.ActorHandle):
        # Ensure `ray._private.state.state.global_state_accessor` is not None
        ray._private.state.state._check_connected()
        serialize_owner_address = (
            ray._raylet._get_actor_serialized_owner_address_or_none(
                ray._private.state.state.global_state_accessor.get_actor_info(
                    _owner._actor_id
                )
            )
        )
        if not serialize_owner_address:
            raise RuntimeError(f"{_owner} is not alive, it's worker_id is empty!")
    else:
        raise TypeError(f"Expect an `ray.actor.ActorHandle`, but got: {type(_owner)}")
    with profiling.profile("ray.put"):
        try:
            object_ref = worker.put_object(value, owner_address=serialize_owner_address)
        except ObjectStoreFullError:
            logger.info(
                "Put failed since the value was either too large or the "
                "store was full of pinned objects."
            )
            raise
        return object_ref 
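# Example (illustrative usage sketch):
#
#   import ray
#
#   obj_ref = ray.put([1, 2, 3])          # store the list once in the object store
#   assert ray.get(obj_ref) == [1, 2, 3]  # retrieve it via the returned ref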
# Global variable to make sure we only send out the warning once.
blocking_wait_inside_async_warned = False
@PublicAPI
@client_mode_hook
def wait(
    ray_waitables: List[Union[ObjectRef, ObjectRefGenerator]],
    *,
    num_returns: int = 1,
    timeout: Optional[float] = None,
    fetch_local: bool = True,
) -> Tuple[
    List[Union[ObjectRef, ObjectRefGenerator]],
    List[Union[ObjectRef, ObjectRefGenerator]],
]:
    """Return a list of IDs that are ready and a list of IDs that are not.
    If timeout is set, the function returns either when the requested number of
    IDs are ready or when the timeout is reached, whichever occurs first. If it
    is not set, the function simply waits until that number of objects is ready
    and returns that exact number of object refs.
    `ray_waitables` is a list of :class:`~ray.ObjectRef` and
    :class:`~ray.ObjectRefGenerator`.
    The method returns two lists, ready and unready `ray_waitables`.
    ObjectRef:
        object refs that correspond to objects that are available
        in the object store are in the first list.
        The rest of the object refs are in the second list.
    ObjectRefGenerator:
        Generators whose next reference (that will be obtained
        via `next(generator)`) has a corresponding object available
        in the object store are in the first list.
        All other generators are placed in the second list.
    Ordering of the input list of ray_waitables is preserved. That is, if A
    precedes B in the input list, and both are in the ready list, then A will
    precede B in the ready list. This also holds true if A and B are both in
    the remaining list.
    This method will issue a warning if it's running inside an async context.
    Instead of ``ray.wait(ray_waitables)``, you can use
    ``await asyncio.wait(ray_waitables)``.
    Related patterns and anti-patterns:
    - :doc:`/ray-core/patterns/limit-pending-tasks`
    - :doc:`/ray-core/patterns/ray-get-submission-order`
    Args:
        ray_waitables: List of :class:`~ObjectRef` or
            :class:`~ObjectRefGenerator` for objects that may or may
            not be ready. Note that these must be unique.
        num_returns: The number of ray_waitables that should be returned.
        timeout: The maximum amount of time in seconds to wait before
            returning.
        fetch_local: If True, wait for the object to be downloaded onto
            the local node before returning it as ready. If the `ray_waitable`
            is a generator, it will wait until the next object in the generator
            is downloaded. If False, ray.wait() will not trigger fetching of
            objects to the local node and will return immediately once the
            object is available anywhere in the cluster.
    Returns:
        A list of object refs that are ready and a list of the remaining object
        IDs.
    """
    worker = global_worker
    worker.check_connected()
    if (
        hasattr(worker, "core_worker")
        and worker.core_worker.current_actor_is_asyncio()
        and timeout != 0
    ):
        global blocking_wait_inside_async_warned
        if not blocking_wait_inside_async_warned:
            logger.debug(
                "Using blocking ray.wait inside async method. "
                "This blocks the event loop. Please use `await` "
                "on object ref with asyncio.wait. "
            )
            blocking_wait_inside_async_warned = True
    if isinstance(ray_waitables, ObjectRef) or isinstance(
        ray_waitables, ObjectRefGenerator
    ):
        raise TypeError(
            "wait() expected a list of ray.ObjectRef or ray.ObjectRefGenerator"
            ", got a single ray.ObjectRef or ray.ObjectRefGenerator "
            f"{ray_waitables}"
        )
    if not isinstance(ray_waitables, list):
        raise TypeError(
            "wait() expected a list of ray.ObjectRef or "
            "ray.ObjectRefGenerator, "
            f"got {type(ray_waitables)}"
        )
    if timeout is not None and timeout < 0:
        raise ValueError(
            "The 'timeout' argument must be nonnegative. " f"Received {timeout}"
        )
    for ray_waitable in ray_waitables:
        if not isinstance(ray_waitable, ObjectRef) and not isinstance(
            ray_waitable, ObjectRefGenerator
        ):
            raise TypeError(
                "wait() expected a list of ray.ObjectRef or "
                "ray.ObjectRefGenerator, "
                f"got list containing {type(ray_waitable)}"
            )
    worker.check_connected()
    # TODO(swang): Check main thread.
    with profiling.profile("ray.wait"):
        # TODO(rkn): This is a temporary workaround for
        # https://github.com/ray-project/ray/issues/997. However, it should be
        # fixed in Arrow instead of here.
        if len(ray_waitables) == 0:
            return [], []
        if len(ray_waitables) != len(set(ray_waitables)):
            raise ValueError("Wait requires a list of unique ray_waitables.")
        if num_returns <= 0:
            raise ValueError("Invalid number of objects to return %d." % num_returns)
        if num_returns > len(ray_waitables):
            raise ValueError(
                "num_returns cannot be greater than the number "
                "of ray_waitables provided to ray.wait."
            )
        timeout = timeout if timeout is not None else 10**6
        timeout_milliseconds = int(timeout * 1000)
        ready_ids, remaining_ids = worker.core_worker.wait(
            ray_waitables,
            num_returns,
            timeout_milliseconds,
            fetch_local,
        )
        return ready_ids, remaining_ids 
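# Example (illustrative usage sketch; the remote function and durations are
# made up):
#
#   import time
#   import ray
#
#   @ray.remote
#   def sleep_then_return(t):
#       time.sleep(t)
#       return t
#
#   refs = [sleep_then_return.remote(t) for t in (0.1, 5.0)]
#   ready, not_ready = ray.wait(refs, num_returns=1, timeout=2.0)
#   # `ready` most likely holds the 0.1s task's ref; the 5.0s ref stays unready.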
@PublicAPI
@client_mode_hook
def get_actor(name: str, namespace: Optional[str] = None) -> "ray.actor.ActorHandle":
    """Get a handle to a named actor.
    Gets a handle to an actor with the given name. The actor must
    have been created with Actor.options(name="name").remote(). This
    works for both detached & non-detached actors.
    This method is a sync call and times out after 60s. This can be modified by
    setting the OS environment variable RAY_gcs_server_request_timeout_seconds
    before starting the cluster.
    Args:
        name: The name of the actor.
        namespace: The namespace of the actor, or None to specify the current
            namespace.
    Returns:
        ActorHandle to the actor.
    Raises:
        ValueError: if the named actor does not exist.
    """
    if not name:
        raise ValueError("Please supply a non-empty value to get_actor")
    if namespace is not None:
        ray._private.utils.validate_namespace(namespace)
    worker = global_worker
    worker.check_connected()
    return worker.core_worker.get_named_actor_handle(name, namespace or "") 
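# A minimal sketch of retrieving a named actor with get_actor() (the `Counter` class,
# the "global_counter" name, and this helper are hypothetical, for illustration only):
def _example_get_actor_usage():
    @ray.remote
    class Counter:
        def __init__(self):
            self.n = 0

        def increment(self):
            self.n += 1
            return self.n

    # Keep a reference to the handle so the non-detached actor is not terminated.
    counter = Counter.options(name="global_counter").remote()
    # Anywhere else in the same namespace, the handle can be looked up by name.
    handle = ray.get_actor("global_counter")
    result = ray.get(handle.increment.remote())
    del counter
    return result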
@PublicAPI
@client_mode_hook
def kill(actor: "ray.actor.ActorHandle", *, no_restart: bool = True):
    """Kill an actor forcefully.
    This will interrupt any running tasks on the actor, causing them to fail
    immediately. ``atexit`` handlers installed in the actor will not be run.
    If you want to kill the actor but let pending tasks finish,
    you can call ``actor.__ray_terminate__.remote()`` instead to queue a
    termination task. Any ``atexit`` handlers installed in the actor *will*
    be run in this case.
    If the actor is a detached actor, subsequent calls to get its handle via
    ray.get_actor will fail.
    Args:
        actor: Handle to the actor to kill.
        no_restart: Whether or not this actor should be restarted if
            it's a restartable actor.
    """
    worker = global_worker
    worker.check_connected()
    if not isinstance(actor, ray.actor.ActorHandle):
        raise ValueError(
            "ray.kill() only supported for actors. For tasks, try ray.cancel(). "
            "Got: {}.".format(type(actor))
        )
    worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) 
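# A minimal sketch of forcefully terminating an actor with kill() (the `Worker`
# class and this helper are hypothetical, for illustration only):
def _example_kill_usage():
    @ray.remote
    class Worker:
        def ping(self):
            return "pong"

    w = Worker.remote()
    # Force-kill the actor: running tasks fail immediately and atexit handlers
    # installed inside the actor are skipped. Pass no_restart=False to allow a
    # restartable actor to be restarted instead.
    ray.kill(w, no_restart=True)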
@PublicAPI
@client_mode_hook
def cancel(
    ray_waitable: Union["ObjectRef[R]", "ObjectRefGenerator[R]"],
    *,
    force: bool = False,
    recursive: bool = True,
) -> None:
    """Cancels a task.
    Cancel API has a different behavior depending on if it is a remote function
    (Task) or a remote Actor method (Actor Task).
    Task:
        If the specified Task is pending execution, it is cancelled and not
        executed. If the Task is currently executing, the behavior depends
        on the `force` flag. When `force=False`, a KeyboardInterrupt is
        raised in Python and when `force=True`, the executing Task
        immediately exits. If the Task is already finished, nothing happens.
        Cancelled Tasks aren't retried. `max_task_retries` aren't respected.
        Calling ray.get on a cancelled Task raises a TaskCancelledError
        if the Task has been scheduled or interrupted.
        It raises a WorkerCrashedError if `force=True`.
        If `recursive=True`, all the child Tasks and Actor Tasks
        are cancelled. If `force=True` and `recursive=True`, `force=True`
        is ignored for child Actor Tasks.
    Actor Task:
        If the specified Task is pending execution, it is cancelled and not
        executed. If the Task is currently executing, the behavior depends
        on the execution model of an Actor. If it is a regular Actor
        or a threaded Actor, the execution isn't cancelled.
        Actor Tasks cannot be interrupted because Actors have
        states. If it is an async Actor, Ray cancels a `asyncio.Task`.
        The semantic of cancellation is equivalent to asyncio's cancellation.
        https://docs.python.org/3/library/asyncio-task.html#task-cancellation
        If the Task has finished, nothing happens.
        Only `force=False` is allowed for an Actor Task. Otherwise, it raises
        `ValueError`. Use `ray.kill(actor)` instead to kill an Actor.
        Cancelled Tasks aren't retried. `max_task_retries` aren't respected.
        Calling ray.get on a cancelled Task raises a TaskCancelledError
        if the Task has been scheduled or interrupted. Also note that
        only async actor tasks can be interrupted.
        If `recursive=True`, all the child Tasks and actor Tasks
        are cancelled.
    Args:
        ray_waitable: :class:`~ObjectRef` and
            :class:`~ObjectRefGenerator`
            returned by the task that should be canceled.
        force: Whether to force-kill a running task by killing
            the worker that is running the task.
        recursive: Whether to try to cancel tasks submitted by the
            task specified.
    """
    worker = ray._private.worker.global_worker
    worker.check_connected()
    if isinstance(ray_waitable, ray._raylet.ObjectRefGenerator):
        assert hasattr(ray_waitable, "_generator_ref")
        ray_waitable = ray_waitable._generator_ref
    if not isinstance(ray_waitable, ray.ObjectRef):
        raise TypeError(
            "ray.cancel() only supported for object refs. "
            f"For actors, try ray.kill(). Got: {type(ray_waitable)}."
        )
    return worker.core_worker.cancel_task(ray_waitable, force, recursive) 
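# A minimal sketch of cancelling a task and observing the resulting error (the
# `busy_task` function and this helper are hypothetical, for illustration only):
def _example_cancel_usage():
    @ray.remote
    def busy_task():
        time.sleep(3600)

    ref = busy_task.remote()
    # force=False (the default) raises KeyboardInterrupt inside the running task;
    # force=True would kill the worker process executing it instead.
    ray.cancel(ref)
    try:
        ray.get(ref)
    except (ray.exceptions.TaskCancelledError, ray.exceptions.RayTaskError):
        # Per the docstring above, get() on a cancelled task surfaces a
        # cancellation error.
        pass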
def _mode(worker=global_worker):
    """This is a wrapper around worker.mode.
    We use this wrapper so that in the remote decorator, we can call _mode()
    instead of worker.mode. The difference is that when we attempt to
    serialize remote functions, we don't attempt to serialize the worker
    object, which cannot be serialized.
    """
    return worker.mode
def _make_remote(function_or_class, options):
    if not function_or_class.__module__:
        function_or_class.__module__ = "global"
    if inspect.isfunction(function_or_class) or is_cython(function_or_class):
        ray_option_utils.validate_task_options(options, in_options=False)
        return ray.remote_function.RemoteFunction(
            Language.PYTHON,
            function_or_class,
            None,
            options,
        )
    if inspect.isclass(function_or_class):
        ray_option_utils.validate_actor_options(options, in_options=False)
        return ray.actor._make_actor(function_or_class, options)
    raise TypeError(
        "The @ray.remote decorator must be applied to either a function or a class."
    )
class RemoteDecorator(Protocol):
    @overload
    def __call__(self, __function: Callable[[], R]) -> RemoteFunctionNoArgs[R]:
        ...
    @overload
    def __call__(self, __function: Callable[[T0], R]) -> RemoteFunction0[R, T0]:
        ...
    @overload
    def __call__(self, __function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2], R]
    ) -> RemoteFunction2[R, T0, T1, T2]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3], R]
    ) -> RemoteFunction3[R, T0, T1, T2, T3]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4], R]
    ) -> RemoteFunction4[R, T0, T1, T2, T3, T4]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4, T5], R]
    ) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]
    ) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]
    ) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]
    ) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]:
        ...
    @overload
    def __call__(
        self, __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
    ) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]:
        ...
    # Pass on typing actors for now. The following makes it so no type errors
    # are generated for actors.
    @overload
    def __call__(self, __t: type) -> Any:
        ...
@overload
def remote(__function: Callable[[], R]) -> RemoteFunctionNoArgs[R]:
    ...
@overload
def remote(__function: Callable[[T0], R]) -> RemoteFunction0[R, T0]:
    ...
@overload
def remote(__function: Callable[[T0, T1], R]) -> RemoteFunction1[R, T0, T1]:
    ...
@overload
def remote(__function: Callable[[T0, T1, T2], R]) -> RemoteFunction2[R, T0, T1, T2]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3], R]
) -> RemoteFunction3[R, T0, T1, T2, T3]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4], R]
) -> RemoteFunction4[R, T0, T1, T2, T3, T4]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4, T5], R]
) -> RemoteFunction5[R, T0, T1, T2, T3, T4, T5]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]
) -> RemoteFunction6[R, T0, T1, T2, T3, T4, T5, T6]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]
) -> RemoteFunction7[R, T0, T1, T2, T3, T4, T5, T6, T7]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]
) -> RemoteFunction8[R, T0, T1, T2, T3, T4, T5, T6, T7, T8]:
    ...
@overload
def remote(
    __function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]
) -> RemoteFunction9[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]:
    ...
# Pass on typing actors for now. The following makes it so no type errors
# are generated for actors.
@overload
def remote(__t: type) -> Any:
    ...
# Passing options
@overload
def remote(
    *,
    num_returns: Union[int, Literal["streaming"]] = Undefined,
    num_cpus: Union[int, float] = Undefined,
    num_gpus: Union[int, float] = Undefined,
    resources: Dict[str, float] = Undefined,
    accelerator_type: str = Undefined,
    memory: Union[int, float] = Undefined,
    max_calls: int = Undefined,
    max_restarts: int = Undefined,
    max_task_retries: int = Undefined,
    max_retries: int = Undefined,
    runtime_env: Dict[str, Any] = Undefined,
    retry_exceptions: bool = Undefined,
    scheduling_strategy: Union[
        None, Literal["DEFAULT"], Literal["SPREAD"], PlacementGroupSchedulingStrategy
    ] = Undefined,
) -> RemoteDecorator:
    ...
@PublicAPI
def remote(
    *args, **kwargs
) -> Union[ray.remote_function.RemoteFunction, ray.actor.ActorClass]:
    """Defines a remote function or an actor class.
    This function can be used as a decorator with no arguments
    to define a remote function or actor as follows:
    .. testcode::
        import ray
        @ray.remote
        def f(a, b, c):
            return a + b + c
        object_ref = f.remote(1, 2, 3)
        result = ray.get(object_ref)
        assert result == (1 + 2 + 3)
        @ray.remote
        class Foo:
            def __init__(self, arg):
                self.x = arg
            def method(self, a):
                return self.x + a
        actor_handle = Foo.remote(123)
        object_ref = actor_handle.method.remote(321)
        result = ray.get(object_ref)
        assert result == (123 + 321)
    Equivalently, use a function call to create a remote function or actor.
    .. testcode::
        def g(a, b, c):
            return a + b + c
        remote_g = ray.remote(g)
        object_ref = remote_g.remote(1, 2, 3)
        assert ray.get(object_ref) == (1 + 2 + 3)
        class Bar:
            def __init__(self, arg):
                self.x = arg
            def method(self, a):
                return self.x + a
        RemoteBar = ray.remote(Bar)
        actor_handle = RemoteBar.remote(123)
        object_ref = actor_handle.method.remote(321)
        result = ray.get(object_ref)
        assert result == (123 + 321)
    It can also be used with specific keyword arguments as follows:
    .. testcode::
        @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
        def f():
            return 1, 2
        @ray.remote(num_cpus=2, resources={"CustomResource": 1})
        class Foo:
            def method(self):
                return 1
    Remote task and actor objects returned by @ray.remote can also be
    dynamically modified with the same arguments as above using
    ``.options()`` as follows:
    .. testcode::
        :hide:
        ray.shutdown()
        ray.init(num_cpus=5, num_gpus=5)
    .. testcode::
        @ray.remote(num_gpus=1, max_calls=1, num_returns=2)
        def f():
            return 1, 2
        f_with_2_gpus = f.options(num_gpus=2)
        object_refs = f_with_2_gpus.remote()
        assert ray.get(object_refs) == [1, 2]
        @ray.remote(num_cpus=2, resources={"CustomResource": 1})
        class Foo:
            def method(self):
                return 1
        Foo_with_no_resources = Foo.options(num_cpus=1, resources=None)
        foo_actor = Foo_with_no_resources.remote()
        assert ray.get(foo_actor.method.remote()) == 1
    A remote actor is terminated when all actor handles to it
    in Python are deleted, which causes it to complete any outstanding
    work and then shut down. If you only have one reference to an actor handle,
    calling ``del actor`` *could* trigger actor deletion. Note that your program
    may have multiple references to the same ActorHandle, and actor termination
    does not occur until the reference count goes to 0. See the Python
    documentation for more context about object deletion:
    https://docs.python.org/3.9/reference/datamodel.html#object.__del__
    If you want to kill actors immediately, you can also call ``ray.kill(actor)``.
    .. tip::
        Avoid repeatedly passing in large arguments to remote task or method calls.
        Instead, use ray.put to create a copy of the object in the object store.
        See :ref:`more info here <ray-pass-large-arg-by-value>`.
    Args:
        num_returns: This is only for *remote functions*. It specifies
            the number of object refs returned by the remote function
            invocation. The default value is 1.
            Pass "dynamic" to allow the task to decide how many
            return values to return during execution, and the caller will
            receive an ObjectRef[DynamicObjectRefGenerator].
            See :ref:`dynamic generators <dynamic-generators>` for more details.
        num_cpus: The quantity of CPU resources to reserve
            for this task or for the lifetime of the actor.
            By default, tasks use 1 CPU resource and actors use 1 CPU
            for scheduling and 0 CPU for running.
            (This means that, by default, actors cannot be scheduled on a
            zero-CPU node, but an unlimited number of them can run on any
            non-zero-CPU node. The default value for actors was chosen for
            historical reasons; it's recommended to always explicitly set
            num_cpus for actors to avoid surprises.
            If resources are specified explicitly,
            they are required for both scheduling and running.)
            See :ref:`specifying resource requirements <resource-requirements>`
            for more details.
        num_gpus: The quantity of GPU resources to reserve
            for this task or for the lifetime of the actor.
            The default value is 0.
            See :ref:`Ray GPU support <gpu-support>` for more details.
        resources (Dict[str, float]): The quantity of various
            :ref:`custom resources <custom-resources>`
            to reserve for this task or for the lifetime of the actor.
            This is a dictionary mapping strings (resource names) to floats.
            By default it is empty.
        accelerator_type: If specified, requires that the task or actor run
            on a node with the specified type of accelerator.
            See :ref:`accelerator types <accelerator_types>`.
        memory: The heap memory request in bytes for this task/actor,
            rounded down to the nearest integer.
        max_calls: Only for *remote functions*. This specifies the
            maximum number of times that a given worker can execute
            the given remote function before it must exit
            (this can be used to address :ref:`memory leaks <gpu-leak>` in third-party
            libraries or to reclaim resources that cannot easily be
            released, e.g., GPU memory that was acquired by TensorFlow).
            By default this is infinite for CPU tasks and 1 for GPU tasks
            (to force GPU tasks to release resources after finishing).
        max_restarts: Only for *actors*. This specifies the maximum
            number of times that the actor should be restarted when it dies
            unexpectedly. The minimum valid value is 0 (default),
            which indicates that the actor doesn't need to be restarted.
            A value of -1 indicates that an actor should be restarted
            indefinitely.
            See :ref:`actor fault tolerance <fault-tolerance-actors>` for more details.
        max_task_retries: Only for *actors*. How many times to
            retry an actor task if the task fails due to a system error,
            e.g., the actor has died. If set to -1, the system will
            retry the failed task until the task succeeds, or the actor
            has reached its max_restarts limit. If set to `n > 0`, the
            system will retry the failed task up to n times, after which the
            task will throw a `RayActorError` exception upon :obj:`ray.get`.
            Note that Python exceptions are not considered system errors
            and will not trigger retries.
            The default value is 0.
            See :ref:`actor fault tolerance <fault-tolerance-actors>` for more details.
        max_retries: Only for *remote functions*. This specifies
            the maximum number of times that the remote function
            should be rerun when the worker process executing it
            crashes unexpectedly. The minimum valid value is 0,
            the default value is 3, and a value of -1 indicates
            infinite retries.
            See :ref:`task fault tolerance <fault-tolerance-tasks>` for more details.
        runtime_env (Dict[str, Any]): Specifies the runtime environment for
            this actor or task and its children. See
            :ref:`runtime-environments` for detailed documentation.
        retry_exceptions: Only for *remote functions*. This specifies whether
            application-level errors should be retried up to max_retries times.
            This can be a boolean or a list of exceptions that should be retried.
            See :ref:`task fault tolerance <fault-tolerance-tasks>` for more details.
        scheduling_strategy: Strategy for scheduling a remote function or
            actor. Possible values are:
            None: Ray figures out the scheduling strategy to use; it is
            either the PlacementGroupSchedulingStrategy using the parent's
            placement group, if the parent has one and has
            placement_group_capture_child_tasks set to true,
            or "DEFAULT" otherwise;
            "DEFAULT": default hybrid scheduling;
            "SPREAD": best-effort spread scheduling;
            `PlacementGroupSchedulingStrategy`:
            placement group based scheduling;
            `NodeAffinitySchedulingStrategy`:
            node id based affinity scheduling.
            See :ref:`Ray scheduling strategies <ray-scheduling-strategies>`
            for more details.
        _metadata: Extended options for Ray libraries. For example,
            _metadata={"workflows.io/options": <workflow options>} for Ray workflows.
        _labels: The key-value labels of a task or actor.
    """
    # "callable" returns true for both function and class.
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # This is the case where the decorator is just @ray.remote.
        # "args[0]" is the class or function under the decorator.
        return _make_remote(args[0], {})
    assert len(args) == 0 and len(kwargs) > 0, ray_option_utils.remote_args_error_string
    return functools.partial(_make_remote, options=kwargs)
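# A minimal sketch of what the typing overloads above provide to static type
# checkers (the `add` function and this helper are hypothetical, for illustration
# only): decorating a typed function lets tools infer the element type carried by
# the returned ObjectRef.
def _example_typed_remote_usage():
    @ray.remote
    def add(a: int, b: int) -> int:
        return a + b

    ref = add.remote(1, 2)  # intended to be inferred as ObjectRef[int]
    total: int = ray.get(ref)
    return total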