"""A utility for debugging serialization issues."""
import inspect
from contextlib import contextmanager
from typing import Any, Optional, Set, Tuple
# Import ray first to use the bundled colorama
import ray  # noqa: F401
import colorama
import ray.cloudpickle as cp
from ray.util.annotations import DeveloperAPI
@contextmanager
def _indent(printer):
    printer.level += 1
    yield
    printer.level -= 1
class _Printer:
    def __init__(self, print_file):
        self.level = 0
        self.print_file = print_file
    def indent(self):
        return _indent(self)
    def print(self, msg):
        indent = "    " * self.level
        print(indent + msg, file=self.print_file)
@DeveloperAPI
class FailureTuple:
    """Represents the serialization 'frame' of one failing object.

    Attributes:
        obj: The object that fails serialization.
        name: The variable name of the object.
        parent: The object that references the `obj`.
    """

    def __init__(self, obj: Any, name: str, parent: Any):
        self.obj, self.name, self.parent = obj, name, parent

    def __repr__(self):
        name, obj, parent = self.name, self.obj, self.parent
        return f"FailTuple({name} [obj={obj}, parent={parent}])"
def _inspect_func_serialization(base_obj, depth, parent, failure_set, printer):
    """Adds the first-found non-serializable element to the failure_set."""
    assert inspect.isfunction(base_obj)
    closure = inspect.getclosurevars(base_obj)
    found = False
    if closure.globals:
        printer.print(
            f"Detected {len(closure.globals)} global variables. "
            "Checking serializability..."
        )
        with printer.indent():
            for name, obj in closure.globals.items():
                serializable, _ = _inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    parent=parent,
                    failure_set=failure_set,
                    printer=printer,
                )
                found = found or not serializable
                if found:
                    break
    if closure.nonlocals:
        printer.print(
            f"Detected {len(closure.nonlocals)} nonlocal variables. "
            "Checking serializability..."
        )
        with printer.indent():
            for name, obj in closure.nonlocals.items():
                serializable, _ = _inspect_serializability(
                    obj,
                    name=name,
                    depth=depth - 1,
                    parent=parent,
                    failure_set=failure_set,
                    printer=printer,
                )
                found = found or not serializable
                if found:
                    break
    if not found:
        printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight."
        )
    return found
def _inspect_generic_serialization(base_obj, depth, parent, failure_set, printer):
    """Adds the first-found non-serializable element to the failure_set."""
    assert not inspect.isfunction(base_obj)
    functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
    found = False
    with printer.indent():
        for name, obj in functions:
            serializable, _ = _inspect_serializability(
                obj,
                name=name,
                depth=depth - 1,
                parent=parent,
                failure_set=failure_set,
                printer=printer,
            )
            found = found or not serializable
            if found:
                break
    with printer.indent():
        members = inspect.getmembers(base_obj)
        for name, obj in members:
            if name.startswith("__") and name.endswith("__") or inspect.isbuiltin(obj):
                continue
            serializable, _ = _inspect_serializability(
                obj,
                name=name,
                depth=depth - 1,
                parent=parent,
                failure_set=failure_set,
                printer=printer,
            )
            found = found or not serializable
            if found:
                break
    if not found:
        printer.print(
            f"WARNING: Did not find non-serializable object in {base_obj}. "
            "This may be an oversight."
        )
    return found
@DeveloperAPI
def inspect_serializability(
    base_obj: Any,
    name: Optional[str] = None,
    depth: int = 3,
    print_file: Optional[Any] = None,
) -> Tuple[bool, Set[FailureTuple]]:
    """Identifies what objects are preventing serialization.

    Args:
        base_obj: Object to be serialized.
        name: Optional name used for ``base_obj`` in the report. When
            omitted, ``str(base_obj)`` is used.
        depth: Depth of the scope stack to walk through. Defaults to 3.
        print_file: file argument that will be passed to print().

    Returns:
        bool: True if serializable.
        set[FailureTuple]: Set of unserializable objects.

    .. versionadded:: 1.1.0
    """
    printer = _Printer(print_file)
    # parent=None and failure_set=None mark this as the top-level call.
    return _inspect_serializability(
        base_obj,
        name=name,
        depth=depth,
        parent=None,
        failure_set=None,
        printer=printer,
    )
def _inspect_serializability(
    base_obj, name, depth, parent, failure_set, printer
) -> Tuple[bool, Set[FailureTuple]]:
    """Recursively look for what prevents ``base_obj`` from being serialized.

    A call with ``failure_set=None`` is treated as the top-level call: it
    creates the result set, prints a header and, when inspection recursed
    into ``base_obj``, prints a closing summary.

    Args:
        base_obj: Object to try to serialize with ``cp.dumps``.
        name: Label for ``base_obj`` in the output. At the top level, ``None``
            is replaced by ``str(base_obj)``.
        depth: Remaining levels to descend. At ``depth <= 0`` a failing
            object is recorded itself instead of being inspected further.
        parent: Object referencing ``base_obj``; stored in the FailureTuple.
        failure_set: Set collecting FailureTuple entries across the
            recursion, or ``None`` for the top-level call.
        printer: Object providing ``print(msg)`` and ``indent()``.

    Returns:
        ``(True, failure_set)`` if ``base_obj`` serialized, otherwise
        ``(False, failure_set)``.
    """
    colorama.init()
    top_level = failure_set is None
    separator = ""
    if top_level:
        failure_set = set()
        declaration = f"Checking Serializability of {base_obj}"
        separator = "=" * min(len(declaration), 80)
        printer.print(separator)
        printer.print(declaration)
        printer.print(separator)
        if name is None:
            name = str(base_obj)
    else:
        printer.print(f"Serializing '{name}' {base_obj}...")

    try:
        cp.dumps(base_obj)
    except Exception as e:
        printer.print(
            f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} serialization: {e}"
        )
    else:
        return True, failure_set

    if depth <= 0:
        # Depth exhausted: record this object as the failure. The same
        # ``<= 0`` test guards both recording and returning, so a negative
        # depth no longer reports failure with an empty failure set.
        try:
            failure_set.add(FailureTuple(base_obj, name, parent))
        # Some objects may not be hashable, so we skip adding this to the set.
        except Exception:
            pass
        return False, failure_set

    # TODO: we only differentiate between 'function' and 'object'
    # but we should do a better job of diving into something
    # more specific like a Type, Object, etc.
    if inspect.isfunction(base_obj):
        _inspect_func_serialization(
            base_obj,
            depth=depth,
            parent=base_obj,
            failure_set=failure_set,
            printer=printer,
        )
    else:
        _inspect_generic_serialization(
            base_obj,
            depth=depth,
            parent=base_obj,
            failure_set=failure_set,
            printer=printer,
        )

    # Nothing has been recorded so far, so blame this object itself.
    if not failure_set:
        failure_set.add(FailureTuple(base_obj, name, parent))

    if top_level:
        printer.print(separator)
        # failure_set is guaranteed non-empty here, so the summary always
        # lists at least one variable.
        fail_vars = (
            f"\n\n\t{colorama.Style.BRIGHT}"
            + "\n".join(str(k) for k in failure_set)
            + f"{colorama.Style.RESET_ALL}\n\n"
        )
        printer.print(
            f"Variable: {fail_vars}was found to be non-serializable. "
            "There may be multiple other undetected variables that were "
            "non-serializable. "
        )
        printer.print(
            "Consider either removing the "
            "instantiation/imports of these variables or moving the "
            "instantiation into the scope of the function/class. "
        )
        printer.print(separator)
        printer.print(
            "Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information."  # noqa
        )
        printer.print(
            "If you have any suggestions on how to improve "
            "this error message, please reach out to the "
            "Ray developers on github.com/ray-project/ray/issues/"
        )
        printer.print(separator)
    # Reaching this point means cp.dumps failed above.
    return False, failure_set