Source code for ray.dag.input_node
from typing import Any, Dict, List, Union, Optional
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.experimental.gradio_utils import type_to_string
from ray.util.annotations import DeveloperAPI
IN_CONTEXT_MANAGER = "__in_context_manager__"
[docs]
@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.
    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.
    Example:
    .. code-block::
                m1.forward
                /       \
        dag_input     ensemble -> dag_output
                \       /
                m2.forward
    In this pipeline, each user input is broadcasted to both m1.forward and
    m2.forward as first stop of the DAG, and authored like
    .. code-block:: python
        import ray
        @ray.remote
        class Model:
            def __init__(self, val):
                self.val = val
            def forward(self, input):
                return self.val * input
        @ray.remote
        def combine(a, b):
            return a + b
        with InputNode() as dag_input:
            m1 = Model.bind(1)
            m2 = Model.bind(2)
            m1_output = m1.forward.bind(dag_input[0])
            m2_output = m2.forward.bind(dag_input.x)
            ray_dag = combine.bind(m1_output, m2_output)
        # Pass mix of args and kwargs as input.
        ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2
        # Alternatively user can also pass single data object, list or dict
        # and access them via list index, object attribute or dict key str.
        ray_dag.execute(UserDataObject(m1=1, m2=2))
        # dag_input.m1, dag_input.m2
        ray_dag.execute([1, 2])
        # dag_input[0], dag_input[1]
        ray_dag.execute({"m1": 1, "m2": 2})
        # dag_input["m1"], dag_input["m2"]
    """
[docs]
    def __init__(
        self,
        *args,
        input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
        _other_args_to_resolve=None,
        **kwargs,
    ):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.
        Args:
            input_type: Describes the data type of inputs user will be giving.
                - if given through singular InputNode: type of InputNode
                - if given through InputAttributeNodes: map of key -> type
                Used when deciding what Gradio block to represent the input nodes with.
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughput pickling, replacement and serialization.
                User should not use or pass this field.
        """
        if len(args) != 0 or len(kwargs) != 0:
            raise ValueError("InputNode should not take any args or kwargs.")
        self.input_attribute_nodes = {}
        self.input_type = input_type
        if input_type is not None and isinstance(input_type, type):
            if _other_args_to_resolve is None:
                _other_args_to_resolve = {}
            _other_args_to_resolve["result_type_string"] = type_to_string(input_type)
        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)
    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)
    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode."""
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and len(kwargs) == 0:
            return args[0]
        return DAGInputData(*args, **kwargs)
    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        if (
            not self._bound_other_args_to_resolve
            or IN_CONTEXT_MANAGER not in self._bound_other_args_to_resolve
        ):
            return False
        else:
            return self._bound_other_args_to_resolve[IN_CONTEXT_MANAGER]
[docs]
    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val
    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")
    def __getattr__(self, key: str):
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getattr__"
            )
        return self.input_attribute_nodes[key]
    def __getitem__(self, key: Union[int, str]) -> Any:
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )
        input_type = None
        if self.input_type is not None and key in self.input_type:
            input_type = type_to_string(self.input_type[key])
        if key not in self.input_attribute_nodes:
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getitem__", input_type
            )
        return self.input_attribute_nodes[key]
    def __enter__(self):
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self
    def __exit__(self, *args):
        pass
[docs]
    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.
        Generated by ray.experimental.gradio_utils.type_to_string().
        """
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
@DeveloperAPI
class InputAttributeNode(DAGNode):
    """Represents partial access of user input based on an index (int),
     object attribute or dict key (str).
    Examples:
        .. code-block:: python
            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input.x
                ray_dag = add.bind(a, b)
            # This makes a = 1 and b = 2
            ray_dag.execute(1, x=2)
            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input[1]
                ray_dag = add.bind(a, b)
            # This makes a = 2 and b = 3
            ray_dag.execute(2, 3)
            # Alternatively, you can input a single object
            # and the inputs are automatically indexed from the object:
            # This makes a = 2 and b = 3
            ray_dag.execute([2, 3])
    """
    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
        input_type: str = None,
    ):
        self._dag_input_node = dag_input_node
        self._key = key
        self._accessor_method = accessor_method
        super().__init__(
            [],
            {},
            {},
            {
                "dag_input_node": dag_input_node,
                "key": key,
                "accessor_method": accessor_method,
                # Type of the input tied to this node. Used by
                # gradio_visualize_graph.GraphVisualizer to determine which Gradio
                # component should be used for this node.
                "result_type_string": input_type,
            },
        )
    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        return InputAttributeNode(
            new_other_args_to_resolve["dag_input_node"],
            new_other_args_to_resolve["key"],
            new_other_args_to_resolve["accessor_method"],
            new_other_args_to_resolve["result_type_string"],
        )
    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.
        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.
        """
        if isinstance(self._dag_input_node, DAGInputData):
            return self._dag_input_node[self._key]
        else:
            # dag.execute() is called with only one arg, thus when an
            # InputAttributeNode is executed, its dependent InputNode is
            # resolved with original user input python object.
            user_input_python_object = self._dag_input_node
            if isinstance(self._key, str):
                if self._accessor_method == "__getitem__":
                    return user_input_python_object[self._key]
                elif self._accessor_method == "__getattr__":
                    return getattr(user_input_python_object, self._key)
            elif isinstance(self._key, int):
                return user_input_python_object[self._key]
            else:
                raise ValueError(
                    "Please only use int index or str as first-level key to "
                    "access fields of dag input."
                )
    def __str__(self) -> str:
        return get_dag_node_str(self, f'["{self._key}"]')
    def get_result_type(self) -> str:
        """Get type of the output of this DAGNode.
        Generated by ray.experimental.gradio_utils.type_to_string().
        """
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
    @property
    def key(self) -> Union[int, str]:
        return self._key
@DeveloperAPI
class DAGInputData:
    """If user passed multiple args and kwargs directly to dag.execute(), we
    generate this wrapper for all user inputs as one object, accessible via
    list index or object attribute key.
    """
    def __init__(self, *args, **kwargs):
        self._args = list(args)
        self._kwargs = kwargs
    def __getitem__(self, key: Union[int, str]) -> Any:
        if isinstance(key, int):
            # Access list args by index.
            return self._args[key]
        elif isinstance(key, str):
            # Access kwarg by key.
            return self._kwargs[key]
        else:
            raise ValueError(
                "Please only use int index or str as first-level key to "
                "access fields of dag input."
            )