import logging
import os
import warnings
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
from packaging import version
import tree  # pip install dm_tree
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI, OldAPIStack
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER
from ray.rllib.utils.typing import (
    LocalOptimizer,
    NetworkType,
    SpaceStruct,
    TensorStructType,
    TensorType,
)
if TYPE_CHECKING:
    from ray.rllib.core.learner.learner import ParamDict, ParamList
    from ray.rllib.policy.torch_policy import TorchPolicy
    from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
logger = logging.getLogger(__name__)
torch, nn = try_import_torch()
# Limit values suitable for use in place of -inf/inf logits. These are useful
# since -inf / inf can cause NaNs during backprop.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38
if torch:
    TORCH_COMPILE_REQUIRED_VERSION = version.parse("2.0.0")
else:
    TORCH_COMPILE_REQUIRED_VERSION = ValueError(
        "torch is not installed. TORCH_COMPILE_REQUIRED_VERSION is not defined."
    )
@OldAPIStack
def apply_grad_clipping(
    policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType
) -> Dict[str, TensorType]:
    """Applies gradient clipping to already computed grads inside `optimizer`.
    Note: This function does NOT perform an analogous operation as
    tf.clip_by_global_norm. It merely clips by norm (per gradient tensor) and
    then computes the global norm across all given tensors (but without clipping
    by that global norm).
    Args:
        policy: The TorchPolicy, which calculated `loss`.
        optimizer: A local torch optimizer object.
        loss: The torch loss tensor.
    Returns:
        An info dict containing the "grad_norm" key and the resulting clipped
        gradients.
    """
    grad_gnorm = 0
    if policy.config["grad_clip"] is not None:
        clip_value = policy.config["grad_clip"]
    else:
        clip_value = np.inf
    num_none_grads = 0
    for param_group in optimizer.param_groups:
        # Make sure we only pass params with grad != None into torch
        # clip_grad_norm_. Would fail otherwise.
        params = list(filter(lambda p: p.grad is not None, param_group["params"]))
        if params:
            # PyTorch clips gradients inplace and returns the norm before clipping
            # We therefore need to compute grad_gnorm further down (fixes #4965)
            global_norm = nn.utils.clip_grad_norm_(params, clip_value)
            if isinstance(global_norm, torch.Tensor):
                global_norm = global_norm.cpu().numpy()
            grad_gnorm += min(global_norm, clip_value)
        else:
            num_none_grads += 1
    # Note (Kourosh): grads could indeed be zero. This method should still return
    # grad_gnorm in that case.
    if num_none_grads == len(optimizer.param_groups):
        # No grads available
        return {}
    return {"grad_gnorm": grad_gnorm}
@PublicAPI
def clip_gradients(
    gradients_dict: "ParamDict",
    *,
    grad_clip: Optional[float] = None,
    grad_clip_by: str = "value",
) -> TensorType:
    """Performs gradient clipping on a grad-dict based on a clip value and clip mode.
    Changes the provided gradient dict in place.
    Args:
        gradients_dict: The gradients dict, mapping str to gradient tensors.
        grad_clip: The value to clip with. The way gradients are clipped is defined
            by the `grad_clip_by` arg (see below).
        grad_clip_by: One of 'value', 'norm', or 'global_norm'.
    Returns:
        If `grad_clip_by`="global_norm" and `grad_clip` is not None, returns the global
        norm of all tensors, otherwise returns None.
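
    A minimal usage sketch (the plain str-keyed dict below stands in for a
    `ParamDict`; values are illustrative):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import clip_gradients

        grads = {"w": torch.tensor([3.0, 4.0])}
        # Global L2 norm is 5.0 -> all grads get scaled by ~1/5 in place.
        total_norm = clip_gradients(
            grads, grad_clip=1.0, grad_clip_by="global_norm"
        )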
    """
    # No clipping, return.
    if grad_clip is None:
        return
    # Clip by value (each gradient individually).
    if grad_clip_by == "value":
        for k, v in gradients_dict.copy().items():
            gradients_dict[k] = (
                None if v is None else torch.clip(v, -grad_clip, grad_clip)
            )
    # Clip by L2-norm (per gradient tensor).
    elif grad_clip_by == "norm":
        for k, v in gradients_dict.copy().items():
            if v is not None:
                # Compute the L2-norm of the gradient tensor.
                norm = v.norm(2).nan_to_num(neginf=-10e8, posinf=10e8)
                # Clip all the gradients.
                if norm > grad_clip:
                    v.mul_(grad_clip / norm)
    # Clip by global L2-norm (across all gradient tensors).
    else:
        assert (
            grad_clip_by == "global_norm"
        ), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
        gradients_list = list(gradients_dict.values())
        total_norm = compute_global_norm(gradients_list)
        if len(gradients_list) == 0:
            return total_norm
        # We want the coefficient to be between 0.0 and 1.0, so if the global
        # norm is smaller than the clip value, we use the clip value as the
        # normalization constant.
        device = gradients_list[0].device
        clip_coef = grad_clip / torch.maximum(
            torch.tensor(grad_clip).to(device), total_norm + 1e-6
        )
        # Note: multiplying by the clamped coef is redundant when the coef is clamped to
        # 1, but doing so avoids a `if clip_coef < 1:` conditional which can require a
        # CPU <=> device synchronization when the gradients do not reside in CPU memory.
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        for g in gradients_list:
            if g is not None:
                g.detach().mul_(clip_coef_clamped.to(g.device))
        return total_norm 
@PublicAPI
def compute_global_norm(gradients_list: "ParamList") -> TensorType:
    """Computes the global norm for a gradients dict.
    Args:
        gradients_list: The gradients list containing parameters.
    Returns:
        Returns the global norm of all tensors in `gradients_list`.
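
    A small, self-contained example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import compute_global_norm

        grads = [torch.tensor([3.0, 4.0])]
        print(compute_global_norm(grads))  # tensor(5.)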
    """
    # Define the norm type to be L2.
    norm_type = 2.0
    # If we have no grads, return zero.
    if len(gradients_list) == 0:
        return torch.tensor(0.0)
    device = gradients_list[0].device
    # Compute the global norm.
    total_norm = torch.norm(
        torch.stack(
            [
                torch.norm(g.detach(), norm_type)
                # Note: `nan_to_num` guards the norm computation against
                # overflow; it does not modify the gradient tensors
                # themselves, only the computed norm value.
                .nan_to_num(neginf=-10e8, posinf=10e8).to(device)
                for g in gradients_list
                if g is not None
            ]
        ),
        norm_type,
    ).nan_to_num(neginf=-10e8, posinf=10e8)
    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. "
        )
    # Return the global norm.
    return total_norm 
@OldAPIStack
def concat_multi_gpu_td_errors(
    policy: Union["TorchPolicy", "TorchPolicyV2"]
) -> Dict[str, TensorType]:
    """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy.
    TD-errors are extracted from the TorchPolicy via its tower_stats property.
    Args:
        policy: The TorchPolicy to extract the TD-error values from.
    Returns:
        A dict mapping strings "td_error" and "mean_td_error" to the
        corresponding concatenated and mean-reduced values.
    """
    td_error = torch.cat(
        [
            t.tower_stats.get("td_error", torch.tensor([0.0])).to(policy.device)
            for t in policy.model_gpu_towers
        ],
        dim=0,
    )
    policy.td_error = td_error
    return {
        "td_error": td_error,
        "mean_td_error": torch.mean(td_error),
    }
@PublicAPI
def convert_to_torch_tensor(
    x: TensorStructType,
    device: Optional[str] = None,
    pin_memory: bool = False,
):
    """Converts any struct to torch.Tensors.
    Args:
        x: Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors.
        device: The device to create the tensor on.
        pin_memory: If True, will call the `pin_memory()` method on the created tensors.
    Returns:
        Any: A new struct with the same structure as `x`, but with all
        values converted to torch Tensor types. This does not convert possibly
        nested elements that are None because torch has no representation for that.
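
    A short sketch of typical usage (the batch dict below is illustrative):

    .. testcode::

        import numpy as np
        from ray.rllib.utils.torch_utils import convert_to_torch_tensor

        batch = {"obs": np.array([[1.0, 2.0]]), "done": np.array([False])}
        out = convert_to_torch_tensor(batch)
        # out["obs"] is a torch.float32 tensor (float64 is downcast);
        # out["done"] is a torch.bool tensor.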
    """
    def mapping(item):
        if item is None:
            # Torch has no representation for `None`, so we return None
            return item
        # Special handling of "Repeated" values.
        if isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values), item.lengths, item.max_len
            )
        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            tensor = item
        # Numpy arrays.
        elif isinstance(item, np.ndarray):
            # Object type (e.g. info dicts in train batch): leave as-is.
            # str type (e.g. agent_id in train batch): leave as-is.
            if item.dtype == object or item.dtype.type is np.str_:
                return item
            # Non-writable numpy-arrays will cause PyTorch warning.
            elif item.flags.writeable is False:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    tensor = torch.from_numpy(item)
            # Already numpy: Wrap as torch tensor.
            else:
                tensor = torch.from_numpy(item)
        # Everything else: Convert to numpy, then wrap as torch tensor.
        else:
            tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors (but leave float16 as-is).
        if tensor.is_floating_point() and tensor.dtype != torch.float16:
            tensor = tensor.float()
        # Pin the tensor's memory (for faster transfer to GPU later).
        if pin_memory and torch.cuda.is_available():
            # Note: `pin_memory()` is not in-place; it returns a pinned copy.
            tensor = tensor.pin_memory()
        return tensor if device is None else tensor.to(device)
    return tree.map_structure(mapping, x) 
@PublicAPI
def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
    """Creates a copy of `x` and makes deep copies torch.Tensors in x.
    Also moves the copied tensors to the specified device (if not None).
    Note if an object in x is not a torch.Tensor, it will be shallow-copied.
    Args:
        x : Any (possibly nested) struct possibly containing torch.Tensors.
        device : The device to move the tensors to.
    Returns:
        Any: A new struct with the same structure as `x`, but with all
            torch.Tensors deep-copied and moved to the specified device.
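
    A minimal sketch (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import copy_torch_tensors

        src = {"a": torch.tensor([1.0])}
        dst = copy_torch_tensors(src)
        dst["a"] += 1.0  # Does not affect src["a"] (deep copy).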
    """
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return (
                torch.clone(item.detach())
                if device is None
                else item.detach().to(device)
            )
        else:
            return item
    return tree.map_structure(mapping, x)
@PublicAPI
def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
    """Computes the explained variance for a pair of labels and predictions.
    The formula used is:
    max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
    Args:
        y: The labels.
        pred: The predictions.
    Returns:
        The explained variance given a pair of labels and predictions.
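
    A quick sanity-check example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import explained_variance

        y = torch.tensor([1.0, 2.0, 3.0])
        # Perfect predictions -> explained variance is (close to) 1.0.
        print(explained_variance(y, y.clone()))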
    """
    y_var = torch.var(y, dim=[0])
    diff_var = torch.var(y - pred, dim=[0])
    min_ = torch.tensor([-1.0]).to(pred.device)
    return torch.max(min_, 1 - (diff_var / (y_var + SMALL_NUMBER)))[0] 
@PublicAPI
def global_norm(tensors: List[TensorType]) -> TensorType:
    """Returns the global L2 norm over a list of tensors.
    output = sqrt(SUM(t ** 2 for t in tensors)),
        where SUM reduces over all tensors and over all elements in tensors.
    Args:
        tensors: The list of tensors to calculate the global norm over.
    Returns:
        The global L2 norm over the given tensor list.
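
    A minimal example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import global_norm

        # sqrt(3^2 + 4^2 + 0^2) = 5.0
        print(global_norm([torch.tensor([3.0, 4.0]), torch.tensor([0.0])]))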
    """
    # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor.
    single_l2s = [torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors]
    # Compute global norm from all single tensors' L2 norms.
    return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) 
@OldAPIStack
def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
    """Computes the huber loss for a given term and delta parameter.
    Reference: https://en.wikipedia.org/wiki/Huber_loss
    Note that the factor of 0.5 is implicitly included in the calculation.
    Formula:
        L = 0.5 * x^2  for small abs x (delta threshold)
        L = delta * (abs(x) - 0.5*delta)  for larger abs x (delta threshold)
    Args:
        x: The input term, e.g. a TD error.
        delta: The delta parameter in the above formula.
    Returns:
        The Huber loss resulting from `x` and `delta`.
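
    A small worked example (illustrative values, default delta=1.0):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import huber_loss

        x = torch.tensor([0.5, 2.0])
        # 0.5 * 0.5^2 = 0.125 (quadratic branch);
        # 1.0 * (2.0 - 0.5) = 1.5 (linear branch).
        print(huber_loss(x))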
    """
    return torch.where(
        torch.abs(x) < delta,
        torch.pow(x, 2.0) * 0.5,
        delta * (torch.abs(x) - 0.5 * delta),
    )
@OldAPIStack
def l2_loss(x: TensorType) -> TensorType:
    """Computes half the L2 norm over a tensor's values without the sqrt.
    output = 0.5 * sum(x ** 2)
    Args:
        x: The input tensor.
    Returns:
        0.5 times the L2 norm over the given tensor's values (w/o sqrt).
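
    A one-line example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import l2_loss

        # 0.5 * (1^2 + 2^2) = 2.5
        print(l2_loss(torch.tensor([1.0, 2.0])))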
    """
    return 0.5 * torch.sum(torch.pow(x, 2.0))
@PublicAPI
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
    """Returns a one-hot tensor, given and int tensor and a space.
    Handles the MultiDiscrete case as well.
    Args:
        x: The input tensor.
        space: The space to use for generating the one-hot tensor.
    Returns:
        The resulting one-hot tensor.
    Raises:
        ValueError: If the given space is not a discrete one.
    .. testcode::

        import torch
        import gymnasium as gym
        from ray.rllib.utils.torch_utils import one_hot

        x = torch.IntTensor([0, 3])  # batch-dim=2
        # Discrete space with 4 (one-hot) slots per batch item.
        s = gym.spaces.Discrete(4)
        print(one_hot(x, s))

        x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
        # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
        # per batch item.
        s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
        print(one_hot(x, s))

    .. testoutput::

        tensor([[1, 0, 0, 0],
                [0, 0, 0, 1]])
        tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
    """
    if isinstance(space, Discrete):
        return nn.functional.one_hot(x.long(), space.n)
    elif isinstance(space, MultiDiscrete):
        if isinstance(space.nvec[0], np.ndarray):
            nvec = np.ravel(space.nvec)
            x = x.reshape(x.shape[0], -1)
        else:
            nvec = space.nvec
        return torch.cat(
            [nn.functional.one_hot(x[:, i].long(), n) for i, n in enumerate(nvec)],
            dim=-1,
        )
    else:
        raise ValueError("Unsupported space for `one_hot`: {}".format(space)) 
@PublicAPI
def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType:
    """Same as torch.mean() but ignores -inf values.
    Args:
        x: The input tensor to reduce mean over.
        axis: The axis over which to reduce. None for all axes.
    Returns:
        The mean reduced inputs, ignoring inf values.
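
    A minimal example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import reduce_mean_ignore_inf

        x = torch.tensor([1.0, float("-inf"), 3.0])
        # The -inf entry is masked out -> mean of [1.0, 3.0] = 2.0.
        print(reduce_mean_ignore_inf(x, axis=0))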
    """
    mask = torch.ne(x, float("-inf"))
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) 
@PublicAPI
def sequence_mask(
    lengths: TensorType,
    maxlen: Optional[int] = None,
    dtype=None,
    time_major: bool = False,
) -> TensorType:
    """Offers same behavior as tf.sequence_mask for torch.
    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
    39036).
    Args:
        lengths: The tensor of individual lengths to mask by.
        maxlen: The maximum length to use for the time axis. If None, use
            the max of `lengths`.
        dtype: The torch dtype to use for the resulting mask.
        time_major: Whether to return the mask as [B, T] (False; default) or
            as [T, B] (True).
    Returns:
         The sequence mask resulting from the given input and parameters.
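
    A small usage sketch (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import sequence_mask

        lengths = torch.tensor([1, 3])
        # Row 0 has 1 valid timestep, row 1 has 3 (out of maxlen=4):
        # [[True, False, False, False], [True, True, True, False]]
        print(sequence_mask(lengths, maxlen=4))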
    """
    # If maxlen not given, use the longest lengths in the `lengths` tensor.
    if maxlen is None:
        maxlen = lengths.max()
    mask = torch.ones(tuple(lengths.shape) + (int(maxlen),))
    mask = ~(mask.to(lengths.device).cumsum(dim=1).t() > lengths)
    # Time major transformation.
    if not time_major:
        mask = mask.t()
    # By default, make the mask a boolean tensor (`.type()` is not in-place,
    # so the result must be re-assigned).
    mask = mask.type(dtype or torch.bool)
    return mask 
@PublicAPI
def update_target_network(
    main_net: NetworkType,
    target_net: NetworkType,
    tau: float,
) -> None:
    """Updates a torch.nn.Module target network using Polyak averaging.
    .. code-block:: text

        new_target_net_weight = (
            tau * main_net_weight + (1.0 - tau) * current_target_net_weight
        )

    Args:
        main_net: The nn.Module to update from.
        target_net: The target network to update.
        tau: The tau value to use in the Polyak averaging formula.
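
    A minimal sketch using two small (hypothetical) networks:

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import update_target_network

        main_net = torch.nn.Linear(2, 2)
        target_net = torch.nn.Linear(2, 2)
        # With tau=0.5, the target's weights become the elementwise average
        # of both networks' weights.
        update_target_network(main_net, target_net, tau=0.5)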
    """
    # Get the current parameters from the main network.
    state_dict = main_net.state_dict()
    # Polyak-average them with the target network's current parameters.
    new_state_dict = {
        k: tau * state_dict[k] + (1 - tau) * v
        for k, v in target_net.state_dict().items()
    }
    # Apply the new parameters to the target network.
    target_net.load_state_dict(new_state_dict) 
@DeveloperAPI
def warn_if_infinite_kl_divergence(
    policy: "TorchPolicy",
    kl_divergence: TensorType,
) -> None:
    """Logs a warning if the given KL divergence tensor is non-finite."""
    if policy.loss_initialized() and kl_divergence.isinf():
        logger.warning(
            "KL divergence is non-finite, this will likely destabilize your model and"
            " the training process. Action(s) in a specific state have near-zero"
            " probability. This can happen naturally in deterministic environments"
            " where the optimal policy has zero mass for a specific action. To fix this"
            " issue, consider setting the coefficient for the KL loss term to zero or"
            " increasing policy entropy."
        )
@PublicAPI
def set_torch_seed(seed: Optional[int] = None) -> None:
    """Sets the torch random seed to the given value.
    Args:
        seed: The seed to use or None for no seeding.
    """
    if seed is not None and torch:
        torch.manual_seed(seed)
        # See https://github.com/pytorch/pytorch/issues/47672.
        cuda_version = torch.version.cuda
        if cuda_version is not None and float(cuda_version) >= 10.2:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
        else:
            # Note: Not all torch operations support deterministic mode.
            torch.use_deterministic_algorithms(True)
        # This only affects cuDNN convolutions; enabling it is unproblematic.
        torch.backends.cudnn.deterministic = True
@PublicAPI
def softmax_cross_entropy_with_logits(
    logits: TensorType,
    labels: TensorType,
) -> TensorType:
    """Same behavior as tf.nn.softmax_cross_entropy_with_logits.
    Args:
        logits: The input predictions (logits).
        labels: The labels corresponding to `logits`.
    Returns:
        The resulting softmax cross-entropy given predictions and labels.
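
    A small worked example (illustrative values):

    .. testcode::

        import torch
        from ray.rllib.utils.torch_utils import softmax_cross_entropy_with_logits

        logits = torch.tensor([[2.0, 1.0, 0.0]])
        labels = torch.tensor([[1.0, 0.0, 0.0]])  # one-hot target
        # -log(softmax(logits)[0]) ~= 0.408
        print(softmax_cross_entropy_with_logits(logits=logits, labels=labels))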
    """
    return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) 
def _dynamo_is_available():
    # Returns True, if `torch._dynamo` (required for `torch.compile`) can be
    # imported.
    try:
        # TODO(Artur): Remove this once torch._dynamo is available on CI
        import torch._dynamo as dynamo  # noqa: F401
        return True
    except ImportError:
        return False