from collections import defaultdict
import copy
import time
from typing import (
    Any,
    Callable,
    Collection,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Union,
)
import uuid
import gymnasium as gym
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.error import MultiAgentEnvError
from ray.rllib.utils.spaces.space_utils import batch
from ray.rllib.utils.typing import AgentID, ModuleID, MultiAgentDict
from ray.util.annotations import PublicAPI
# TODO (simon): Include cases in which the number of agents in an
#  episode is shrinking or growing during the episode itself.
@PublicAPI(stability="alpha")
class MultiAgentEpisode:
    """Stores multi-agent episode data.
    The central attribute of the class is the timestep mapping
    `self.env_t_to_agent_t` that maps AgentIDs to their specific environment steps to
    the agent's own scale/timesteps.
    Each AgentID in the `MultiAgentEpisode` has its own `SingleAgentEpisode` object
    in which this agent's data is stored. Together with the env_t_to_agent_t mapping,
    we can extract information either on any individual agent's time scale or from
    the (global) multi-agent environment time scale.
    Extraction of data from a MultiAgentEpisode happens via the getter APIs, e.g.
    `get_observations()`, which work analogous to the ones implemented in the
    `SingleAgentEpisode` class.
    Note that recorded `terminateds`/`truncateds` come as simple
    `MultiAgentDict`s mapping AgentID to bools and thus have no assignment to a
    certain timestep (analogous to a SingleAgentEpisode's single `terminated/truncated`
    boolean flag). Instead we assign it to the last observation recorded.
    Theoretically, there could occur edge cases in some environments
    where an agent receives partial rewards and then terminates without
    a last observation. In these cases, we duplicate the last observation.
    Also, if no initial observation has been received yet for an agent, but
    some  rewards for this same agent already occurred, we delete the agent's data
    up to here, b/c there is nothing to learn from these "premature" rewards.
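
    Example (a minimal sketch; the observation/action/reward values below are
    arbitrary and only illustrate the API):

    .. testcode::

        from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

        episode = MultiAgentEpisode()
        # Only `a0` is part of the reset observation.
        episode.add_env_reset(observations={"a0": 0})
        # At env step 1, `a0` acts and `a1` receives its first observation.
        episode.add_env_step(
            observations={"a0": 1, "a1": 0},
            actions={"a0": 0},
            rewards={"a0": 1.0},
        )
        # Each agent has its own SingleAgentEpisode in `agent_episodes` ...
        assert len(episode.agent_episodes) == 2
        # ... while `env_t` tracks the global env timestep.
        assert episode.env_t == 1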
    """
    __slots__ = (
        "id_",
        "agent_to_module_mapping_fn",
        "_agent_to_module_mapping",
        "observation_space",
        "action_space",
        "env_t_started",
        "env_t",
        "agent_t_started",
        "env_t_to_agent_t",
        "_hanging_actions_end",
        "_hanging_extra_model_outputs_end",
        "_hanging_rewards_end",
        "_hanging_rewards_begin",
        "is_terminated",
        "is_truncated",
        "agent_episodes",
        "_last_step_time",
        "_len_lookback_buffers",
        "_start_time",
        "_temporary_timestep_data",
    )
    SKIP_ENV_TS_TAG = "S"
    def __init__(
        self,
        id_: Optional[str] = None,
        *,
        observations: Optional[List[MultiAgentDict]] = None,
        observation_space: Optional[gym.Space] = None,
        infos: Optional[List[MultiAgentDict]] = None,
        actions: Optional[List[MultiAgentDict]] = None,
        action_space: Optional[gym.Space] = None,
        rewards: Optional[List[MultiAgentDict]] = None,
        terminateds: Union[MultiAgentDict, bool] = False,
        truncateds: Union[MultiAgentDict, bool] = False,
        extra_model_outputs: Optional[List[MultiAgentDict]] = None,
        env_t_started: Optional[int] = None,
        agent_t_started: Optional[Dict[AgentID, int]] = None,
        len_lookback_buffer: Union[int, str] = "auto",
        agent_episode_ids: Optional[Dict[AgentID, str]] = None,
        agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None,
        agent_to_module_mapping_fn: Optional[
            Callable[[AgentID, "MultiAgentEpisode"], ModuleID]
        ] = None,
    ):
        """Initializes a `MultiAgentEpisode`.
        Args:
            id_: Optional. Either a string to identify an episode or None.
                If None, a hexadecimal id is created. In case of providing
                a string, make sure that it is unique, as episodes get
                concatenated via this string.
            observations: A list of dictionaries mapping agent IDs to observations.
                Can be None. If provided, should match all other episode data
                (actions, rewards, etc.) in terms of list lengths and agent IDs.
            observation_space: An optional gym.spaces.Dict mapping agent IDs to
                individual agents' spaces, which all (individual agents') observations
                should abide by. If not None and this MultiAgentEpisode is numpy'ized
                (via the `self.to_numpy()` method), and data is appended or set, the new
                data will be checked for correctness.
            infos: A list of dictionaries mapping agent IDs to info dicts.
                Can be None. If provided, should match all other episode data
                (observations, rewards, etc.) in terms of list lengths and agent IDs.
            actions: A list of dictionaries mapping agent IDs to actions.
                Can be None. If provided, should match all other episode data
                (observations, rewards, etc.) in terms of list lengths and agent IDs.
            action_space: An optional gym.spaces.Dict mapping agent IDs to
                individual agents' spaces, which all (individual agents') actions
                should abide by. If not None and this MultiAgentEpisode is numpy'ized
                (via the `self.to_numpy()` method), and data is appended or set, the new
                data will be checked for correctness.
            rewards: A list of dictionaries mapping agent IDs to rewards.
                Can be None. If provided, should match all other episode data
                (observations, actions, etc.) in terms of list lengths and agent IDs.
            terminateds: A boolean defining if an environment has
                terminated OR a MultiAgentDict mapping individual agent IDs
                to boolean flags indicating whether individual agents have terminated.
                A special `__all__` key in these dicts indicates whether the episode
                is terminated for all agents.
                The default is `False`, i.e. the episode has not been terminated.
            truncateds: A boolean defining if the environment has been
                truncated OR a MultiAgentDict mapping individual agent IDs
                to boolean flags indicating whether individual agents have been
                truncated. A special `__all__` key in these dicts indicates whether the
                episode is truncated for all agents.
                The default is `False`, i.e. the episode has not been truncated.
            extra_model_outputs: A list of dictionaries mapping agent IDs to their
                corresponding extra model outputs. Each of these "outputs" is a dict
                mapping keys (str) to model output values, for example for
                `key=STATE_OUT`, the values would be the internal state outputs for
                that agent.
            env_t_started: The env timestep (int) that defines the starting point
                of the episode. This is only larger than zero if an already ongoing
                episode chunk is being created, for example by slicing an ongoing
                episode or by calling the `cut()` method on an ongoing episode.
            agent_t_started: A dict mapping AgentIDs to the respective agent's (local)
                timestep at which its SingleAgentEpisode chunk started.
            len_lookback_buffer: The size of the lookback buffers to keep in
                front of this Episode for each type of data (observations, actions,
                etc.). If larger than 0, will interpret the first `len_lookback_buffer`
                items in each type of data as NOT part of this actual
                episode chunk, but as a "historical" record that may be
                viewed and used to derive new data from. For example, a lookback buffer
                might be necessary if you would like to do observation frame stacking
                and your episode has been cut and you are now operating on a new chunk
                (continuing from the cut one). Then, for the first few items of the new
                chunk, you have to be able to look back into the old chunk's data.
                If `len_lookback_buffer` is "auto" (default), will interpret all
                provided data in the constructor as part of the lookback buffers.
            agent_episode_ids: An optional dict mapping AgentIDs
                to the IDs of their corresponding `SingleAgentEpisode` objects. If
                None, each `SingleAgentEpisode` in `MultiAgentEpisode.agent_episodes`
                will generate its own hexadecimal ID. If a dictionary is provided,
                make sure that IDs are unique, because the agents'
                `SingleAgentEpisode` instances are concatenated or recreated by it.
            agent_module_ids: An optional dict mapping AgentIDs to their respective
                ModuleIDs (these mappings are always valid for an entire episode and
                thus won't change during the course of this episode). If a mapping from
                agent to module has already been provided via this dict, the (optional)
                `agent_to_module_mapping_fn` will NOT be used again to map the same
                agent (agents do not change their assigned module in the course of
                one episode).
            agent_to_module_mapping_fn: A callable taking an AgentID and a
                MultiAgentEpisode as args and returning a ModuleID. Used to map agents
                that have not been mapped yet (because they just entered this episode)
                to a ModuleID. The resulting ModuleID is only stored inside the agent's
                SingleAgentEpisode object.
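
        Example (a minimal sketch constructing an episode from pre-collected data;
        the agent ID and values used are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            # Two already collected env steps for a single agent "a0".
            episode = MultiAgentEpisode(
                observations=[{"a0": 0}, {"a0": 1}, {"a0": 2}],
                actions=[{"a0": 0}, {"a0": 1}],
                rewards=[{"a0": 1.0}, {"a0": 1.0}],
                # All provided data is part of this chunk (no lookback buffer).
                len_lookback_buffer=0,
            )
            assert episode.env_t_started == 0 and episode.env_t == 2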
        """
        self.id_: str = id_ or uuid.uuid4().hex
        if agent_to_module_mapping_fn is None:
            from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
            agent_to_module_mapping_fn = (
                AlgorithmConfig.DEFAULT_AGENT_TO_MODULE_MAPPING_FN
            )
        self.agent_to_module_mapping_fn = agent_to_module_mapping_fn
        # In case a user - e.g. via callbacks - already forces a mapping to happen
        # via the `module_for()` API even before the agent has entered the episode
        # (and has its SingleAgentEpisode created), we store all already performed
        # mappings in this dict here.
        self._agent_to_module_mapping: Dict[AgentID, ModuleID] = agent_module_ids or {}
        # If no lookback buffer length is provided ("auto"), interpret all provided
        # data as lookback buffer.
        if len_lookback_buffer == "auto":
            len_lookback_buffer = len(rewards or [])
        self._len_lookback_buffers = len_lookback_buffer
        self.observation_space = observation_space or {}
        self.action_space = action_space or {}
        terminateds = terminateds or {}
        truncateds = truncateds or {}
        # The global last timestep of the episode and the timesteps when this chunk
        # started (excluding a possible lookback buffer).
        self.env_t_started = env_t_started or 0
        self.env_t = (
            (len(rewards) if rewards is not None else 0)
            - self._len_lookback_buffers
            + self.env_t_started
        )
        self.agent_t_started = defaultdict(int, agent_t_started or {})
        # Keeps track of the correspondence between agent steps and environment steps.
        # Under each AgentID as key is a InfiniteLookbackBuffer with the following
        # data in it:
        # The indices of the items in the data represent environment timesteps,
        # starting from index=0 for the `env.reset()` and with each `env.step()` call
        # increase by 1.
        # The values behind these (env timestep) indices represent the agent timesteps
        # happening at these env timesteps and the special value of
        # `self.SKIP_ENV_TS_TAG` means that the agent did NOT step at the given env
        # timestep.
        # Thus, agents that are part of the reset obs, will start their mapping data
        # with a [0 ...], all other agents will start their mapping data with:
        # [self.SKIP_ENV_TS_TAG, ...].
        self.env_t_to_agent_t: DefaultDict[
            AgentID, InfiniteLookbackBuffer
        ] = defaultdict(InfiniteLookbackBuffer)
        # Create caches for hanging actions/rewards/extra_model_outputs.
        # When an agent gets an observation (and then sends an action), but does not
        # receive immediately a next observation, we store the "hanging" action (and
        # related rewards and extra model outputs) in the caches postfixed w/ `_end`
        # until the next observation is received.
        self._hanging_actions_end = {}
        self._hanging_extra_model_outputs_end = defaultdict(dict)
        self._hanging_rewards_end = defaultdict(float)
        # In case of a `cut()` or `slice()`, we also need to store the hanging actions,
        # rewards, and extra model outputs that were already "hanging" in the preceding
        # episode slice.
        self._hanging_rewards_begin = defaultdict(float)
        # If this is an ongoing episode, then the last `__all__` should be `False`.
        self.is_terminated: bool = (
            terminateds
            if isinstance(terminateds, bool)
            else terminateds.get("__all__", False)
        )
        # If this is an ongoing episode, then the last `__all__` should be `False`.
        self.is_truncated: bool = (
            truncateds
            if isinstance(truncateds, bool)
            else truncateds.get("__all__", False)
        )
        # The individual agent SingleAgentEpisode objects.
        self.agent_episodes: Dict[AgentID, SingleAgentEpisode] = {}
        self._init_single_agent_episodes(
            agent_module_ids=agent_module_ids,
            agent_episode_ids=agent_episode_ids,
            observations=observations,
            infos=infos,
            actions=actions,
            rewards=rewards,
            terminateds=terminateds,
            truncateds=truncateds,
            extra_model_outputs=extra_model_outputs,
        )
        # Caches for temporary per-timestep data. May be used to store custom metrics
        # from within a callback for the ongoing episode (e.g. render images).
        self._temporary_timestep_data = defaultdict(list)
        # Keep timer stats on deltas between steps.
        self._start_time = None
        self._last_step_time = None
        # Validate ourselves.
        self.validate() 
    def add_env_reset(
        self,
        *,
        observations: MultiAgentDict,
        infos: Optional[MultiAgentDict] = None,
    ) -> None:
        """Stores initial observation.
        Args:
            observations: A dictionary mapping agent IDs to initial observations.
                Note that some agents may not have an initial observation.
            infos: A dictionary mapping agent IDs to initial info dicts.
                Note that some agents may not have an initial info dict. If not None,
                the agent IDs in `infos` must be a subset of those in `observations`,
                meaning an agent that has an info dict must also have an observation.
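
        Example (a minimal sketch; only `agent_1` is part of the reset here, other
        agents may join the episode at later env steps):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            episode.add_env_reset(
                observations={"agent_1": 0},
                infos={"agent_1": {}},
            )
            assert episode.is_reset
            assert episode.agent_ids == {"agent_1"}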
        """
        assert not self.is_done
        # Assume that this episode is completely empty and has not stepped yet.
        # Leave self.env_t (and self.env_t_started) at 0.
        assert self.env_t == self.env_t_started == 0
        infos = infos or {}
        # Note, all agents in `observations` have an initial observation; some may
        # also have an initial info dict.
        for agent_id, agent_obs in observations.items():
            # Update env_t_to_agent_t mapping (all agents that are part of the reset
            # obs have their first mapping 0 (env_t) -> 0 (agent_t)).
            self.env_t_to_agent_t[agent_id].append(0)
            # Create SingleAgentEpisode, if necessary.
            if agent_id not in self.agent_episodes:
                self.agent_episodes[agent_id] = SingleAgentEpisode(
                    agent_id=agent_id,
                    module_id=self.module_for(agent_id),
                    multi_agent_episode_id=self.id_,
                    observation_space=self.observation_space.get(agent_id),
                    action_space=self.action_space.get(agent_id),
                )
            # Add initial observations (and infos) to the agent's episode.
            self.agent_episodes[agent_id].add_env_reset(
                observation=agent_obs,
                infos=infos.get(agent_id),
            )
        # Validate our data.
        self.validate()
        # Start the timer for this episode.
        self._start_time = time.perf_counter() 
    def add_env_step(
        self,
        observations: MultiAgentDict,
        actions: MultiAgentDict,
        rewards: MultiAgentDict,
        infos: Optional[MultiAgentDict] = None,
        *,
        terminateds: Optional[MultiAgentDict] = None,
        truncateds: Optional[MultiAgentDict] = None,
        extra_model_outputs: Optional[MultiAgentDict] = None,
    ) -> None:
        """Adds a timestep to the episode.
        Args:
            observations: A dictionary mapping agent IDs to their corresponding
                next observations. Note that some agents may not have stepped at this
                timestep.
            actions: Mandatory. A dictionary mapping agent IDs to their
                corresponding actions. Note that some agents may not have stepped at
                this timestep.
            rewards: Mandatory. A dictionary mapping agent IDs to their
                corresponding rewards. Note that some agents may not have stepped
                at this timestep.
            infos: A dictionary mapping agent IDs to their
                corresponding info. Note that some agents may not have stepped at this
                timestep.
            terminateds: A dictionary mapping agent IDs to their `terminated` flags,
                indicating whether the environment has been terminated for them.
                A special `__all__` key indicates that the episode is terminated for
                all agent IDs.
            truncateds: A dictionary mapping agent IDs to their `truncated` flags,
                indicating whether the environment has been truncated for them.
                A special `__all__` key indicates that the episode is truncated for
                all agent IDs.
            extra_model_outputs: A dictionary mapping agent IDs to their
                corresponding specific model outputs (also in a dictionary; e.g.
                `vf_preds` for PPO).
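
        Example (a minimal sketch of an agent acting without receiving a next
        observation right away; agent IDs and values are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            episode.add_env_reset(observations={"a0": 0, "a1": 0})
            # Both agents act, but only `a0` receives a next observation.
            episode.add_env_step(
                observations={"a0": 1},
                actions={"a0": 0, "a1": 0},
                rewards={"a0": 1.0, "a1": 1.0},
            )
            # `a1`'s action (and reward) are cached as "hanging" until its next
            # observation arrives ...
            assert len(episode.agent_episodes["a1"]) == 0
            # ... which happens at the following env step.
            episode.add_env_step(
                observations={"a0": 2, "a1": 1},
                actions={"a0": 1},
                rewards={"a0": 1.0},
            )
            assert len(episode.agent_episodes["a1"]) == 1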
        """
        # Cannot add data to an already done episode.
        if self.is_done:
            raise MultiAgentEnvError(
                "Cannot call `add_env_step` on a MultiAgentEpisode that is already "
                "done!"
            )
        infos = infos or {}
        terminateds = terminateds or {}
        truncateds = truncateds or {}
        extra_model_outputs = extra_model_outputs or {}
        # Increase (global) env step by one.
        self.env_t += 1
        # Find out, whether this episode is terminated/truncated (for all agents).
        # Case 1: all agents are terminated or all are truncated.
        self.is_terminated = terminateds.get("__all__", False)
        self.is_truncated = truncateds.get("__all__", False)
        # Find all agents that were done at prior timesteps and add the agents that are
        # done at the present timestep.
        agents_done = set(
            [aid for aid, sa_eps in self.agent_episodes.items() if sa_eps.is_done]
            + [aid for aid in terminateds if terminateds[aid]]
            + [aid for aid in truncateds if truncateds[aid]]
        )
        # Case 2: Some agents are truncated and the others are terminated -> Declare
        # this episode as terminated.
        if all(aid in agents_done for aid in self.agent_ids):
            self.is_terminated = True
        # For all agents that are not stepping in this env step, but that are not done
        # yet -> Add a skip tag to their env- to agent-step mappings.
        stepped_agent_ids = set(observations.keys())
        for agent_id, env_t_to_agent_t in self.env_t_to_agent_t.items():
            if agent_id not in stepped_agent_ids:
                env_t_to_agent_t.append(self.SKIP_ENV_TS_TAG)
        # Loop through all agent IDs that we received data for in this step:
        # Those found in observations, actions, and rewards.
        agent_ids_with_data = (
            set(observations.keys())
            | set(actions.keys())
            | set(rewards.keys())
            | set(terminateds.keys())
            | set(truncateds.keys())
            | set(
                self.agent_episodes.keys()
                if terminateds.get("__all__") or truncateds.get("__all__")
                else set()
            )
        ) - {"__all__"}
        for agent_id in agent_ids_with_data:
            if agent_id not in self.agent_episodes:
                sa_episode = SingleAgentEpisode(
                    agent_id=agent_id,
                    module_id=self.module_for(agent_id),
                    multi_agent_episode_id=self.id_,
                    observation_space=self.observation_space.get(agent_id),
                    action_space=self.action_space.get(agent_id),
                )
            else:
                sa_episode = self.agent_episodes[agent_id]
            # Collect value to be passed (at end of for-loop) into `add_env_step()`
            # call.
            _observation = observations.get(agent_id)
            _action = actions.get(agent_id)
            _reward = rewards.get(agent_id)
            _infos = infos.get(agent_id)
            _terminated = terminateds.get(agent_id, False) or self.is_terminated
            _truncated = truncateds.get(agent_id, False) or self.is_truncated
            _extra_model_outputs = extra_model_outputs.get(agent_id)
            # The value to place into the env- to agent-step map for this agent ID.
            # _agent_step = self.SKIP_ENV_TS_TAG
            # Agents, whose SingleAgentEpisode had already been done before this
            # step should NOT have received any data in this step.
            if sa_episode.is_done and any(
                v is not None
                for v in [_observation, _action, _reward, _infos, _extra_model_outputs]
            ):
                raise MultiAgentEnvError(
                    f"Agent {agent_id} already had its `SingleAgentEpisode.is_done` "
                    f"set to True, but still received data in a following step! "
                    f"obs={_observation} act={_action} rew={_reward} info={_infos} "
                    f"extra_model_outputs={_extra_model_outputs}."
                )
            _reward = _reward or 0.0
            # CASE 1: A complete agent step is available (in one env step).
            # -------------------------------------------------------------
            # We have an observation and an action for this agent ->
            # Add the agent step to the single agent episode.
            # ... action -> next obs + reward ...
            if _observation is not None and _action is not None:
                if agent_id not in rewards:
                    raise MultiAgentEnvError(
                        f"Agent {agent_id} acted (and received next obs), but did NOT "
                        f"receive any reward from the env!"
                    )
            # CASE 2: Step gets completed with a hanging action OR first observation.
            # ------------------------------------------------------------------------
            # We have an observation, but no action ->
            # a) Action (and extra model outputs) must be hanging already. Also use
            # collected hanging rewards and extra_model_outputs.
            # b) The observation is the first observation for this agent ID.
            elif _observation is not None and _action is None:
                _action = self._hanging_actions_end.pop(agent_id, None)
                # We have a hanging action (the agent had acted after the previous
                # observation, but the env had not responded - until now - with another
                # observation).
                # ...[hanging action] ... ... -> next obs + (reward)? ...
                if _action is not None:
                    # Get the extra model output if available.
                    _extra_model_outputs = self._hanging_extra_model_outputs_end.pop(
                        agent_id, None
                    )
                    _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward
                # First observation for this agent, we have no hanging action.
                # ... [done]? ... -> [1st obs for agent ID]
                else:
                    # The agent is already done -> The agent thus has never stepped once
                    # and we do not have to create a SingleAgentEpisode for it.
                    if _terminated or _truncated:
                        self._del_hanging(agent_id)
                        continue
                    # This must be the agent's initial observation.
                    else:
                        # Prepend n skip tags to this agent's mapping + the initial [0].
                        assert agent_id not in self.env_t_to_agent_t
                        self.env_t_to_agent_t[agent_id].extend(
                            [self.SKIP_ENV_TS_TAG] * self.env_t + [0]
                        )
                        self.env_t_to_agent_t[
                            agent_id
                        ].lookback = self._len_lookback_buffers
                        # Make `add_env_reset` call and continue with next agent.
                        sa_episode.add_env_reset(observation=_observation, infos=_infos)
                        # Add possible reward to begin cache.
                        self._hanging_rewards_begin[agent_id] += _reward
                        # Now that the SAEps is valid, add it to our dict.
                        self.agent_episodes[agent_id] = sa_episode
                        continue
            # CASE 3: Step is started (by an action), but not completed (no next obs).
            # ------------------------------------------------------------------------
            # We have no observation, but we have a hanging action (used when we receive
            # the next obs for this agent in the future).
            elif agent_id not in observations and agent_id in actions:
                # Agent got truncated -> Error b/c we would need a last (truncation)
                # observation for this (otherwise, e.g. bootstrapping would not work).
                # [previous obs] [action] (hanging) ... ... [truncated]
                if _truncated:
                    raise MultiAgentEnvError(
                        f"Agent {agent_id} acted and then got truncated, but did NOT "
                        "receive a last (truncation) observation, required for e.g. "
                        "value function bootstrapping!"
                    )
                # Agent got terminated.
                # [previous obs] [action] (hanging) ... ... [terminated]
                elif _terminated:
                    # If the agent was terminated and no observation is provided,
                    # duplicate the previous one (this is a technical "fix" to properly
                    # complete the single agent episode; this last observation is never
                    # used for learning anyway).
                    _observation = sa_episode._last_added_observation
                    _infos = sa_episode._last_added_infos
                # Agent is still alive.
                # [previous obs] [action] (hanging) ...
                else:
                    # Hanging action, reward, and extra_model_outputs.
                    assert agent_id not in self._hanging_actions_end
                    self._hanging_actions_end[agent_id] = _action
                    self._hanging_rewards_end[agent_id] = _reward
                    self._hanging_extra_model_outputs_end[
                        agent_id
                    ] = _extra_model_outputs
            # CASE 4: Step has started in the past and is still ongoing (no observation,
            # no action).
            # --------------------------------------------------------------------------
            # Record reward and terminated/truncated flags.
            else:
                _action = self._hanging_actions_end.get(agent_id)
                # Agent is done.
                if _terminated or _truncated:
                    # If the agent has NOT stepped, we treat it as not being
                    # part of this episode.
                    # ... ... [other agents doing stuff] ... ... [agent done]
                    if _action is None:
                        self._del_hanging(agent_id)
                        continue
                    # Agent got truncated -> Error b/c we would need a last (truncation)
                    # observation for this (otherwise, e.g. bootstrapping would not
                    # work).
                    if _truncated:
                        raise MultiAgentEnvError(
                            f"Agent {agent_id} acted and then got truncated, but did "
                            "NOT receive a last (truncation) observation, required "
                            "for e.g. value function bootstrapping!"
                        )
                    # [obs] ... ... [hanging action] ... ... [done]
                    # If the agent was terminated and no observation is provided,
                    # duplicate the previous one (this is a technical "fix" to properly
                    # complete the single agent episode; this last observation is never
                    # used for learning anyway).
                    _observation = sa_episode._last_added_observation
                    _infos = sa_episode._last_added_infos
                    # `_action` was already retrieved via `get()` above. We don't need
                    # to pop it from the cache as the cache gets wiped out below anyway
                    # b/c the agent is done.
                    _extra_model_outputs = self._hanging_extra_model_outputs_end.pop(
                        agent_id, None
                    )
                    _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward
                # The agent is still alive, just add current reward to cache.
                else:
                    # But has never stepped in this episode -> add to begin cache.
                    if agent_id not in self.agent_episodes:
                        self._hanging_rewards_begin[agent_id] += _reward
                    # Otherwise, add to end cache.
                    else:
                        self._hanging_rewards_end[agent_id] += _reward
            # If agent is stepping, add timestep to `SingleAgentEpisode`.
            if _observation is not None:
                sa_episode.add_env_step(
                    observation=_observation,
                    action=_action,
                    reward=_reward,
                    infos=_infos,
                    terminated=_terminated,
                    truncated=_truncated,
                    extra_model_outputs=_extra_model_outputs,
                )
                # Update the env- to agent-step mapping.
                self.env_t_to_agent_t[agent_id].append(
                    len(sa_episode) + sa_episode.observations.lookback
                )
            # Agent is also done. -> Erase all hanging values for this agent
            # (they should be empty at this point anyways).
            if _terminated or _truncated:
                self._del_hanging(agent_id)
        # Validate our data.
        self.validate()
        # Step time stats.
        self._last_step_time = time.perf_counter()
        if self._start_time is None:
            self._start_time = self._last_step_time 
    def validate(self) -> None:
        """Validates the episode's data.
        This function ensures that the data stored to a `MultiAgentEpisode` is
        in order (e.g. that the correct number of observations, actions, rewards
        are there).
        """
        for eps in self.agent_episodes.values():
            eps.validate() 
        # TODO (sven): Validate MultiAgentEpisode specifics, like the timestep mappings,
        #  action/reward caches, etc..
    @property
    def is_reset(self) -> bool:
        """Returns True if `self.add_env_reset()` has already been called."""
        return any(
            len(sa_episode.observations) > 0
            for sa_episode in self.agent_episodes.values()
        )
    @property
    def is_numpy(self) -> bool:
        """True, if the data in this episode is already stored as numpy arrays."""
        is_numpy = next(iter(self.agent_episodes.values())).is_numpy
        # Make sure that all single agent's episodes' `is_numpy` flags are the same.
        if not all(eps.is_numpy is is_numpy for eps in self.agent_episodes.values()):
            raise RuntimeError(
                f"Only some SingleAgentEpisode objects in {self} are converted to "
                f"numpy, others are not!"
            )
        return is_numpy
    @property
    def is_done(self):
        """Whether the episode is actually done (terminated or truncated).
        A done episode cannot be continued via `self.add_env_step()`, concatenated on
        its right side with another episode chunk, or succeeded via `self.cut()`.
        Note that in a multi-agent environment this does not necessarily
        correspond to individual agents having terminated or been truncated:
        `self.is_terminated` should be `True` if all agents are terminated, and
        `self.is_truncated` should be `True` if all agents are truncated. If
        only some (but not all) agents are terminated/truncated,
        `MultiAgentEpisode.is_terminated`/`is_truncated` should remain `False`.
        Information about individual agents' terminated/truncated states can always
        be retrieved from the `SingleAgentEpisode` objects inside this
        `MultiAgentEpisode`.
        If all agents are done, but in a mixed fashion (i.e. some are terminated
        and others are truncated), the behavior is currently undefined and could
        potentially be a problem (if a user really implemented such a multi-agent
        env that behaves this way).
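
        Example (a minimal sketch; both agents terminate at the same env step,
        agent IDs and values are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            episode.add_env_reset(observations={"a0": 0, "a1": 0})
            episode.add_env_step(
                observations={"a0": 1, "a1": 1},
                actions={"a0": 0, "a1": 0},
                rewards={"a0": 1.0, "a1": 1.0},
                terminateds={"__all__": True, "a0": True, "a1": True},
            )
            assert episode.is_terminated and episode.is_done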
        Returns:
            Boolean defining if an episode has either terminated or truncated.
        """
        return self.is_terminated or self.is_truncated
    def to_numpy(self) -> "MultiAgentEpisode":
        """Converts this Episode's list attributes to numpy arrays.
        This means in particular that this episode's lists (per single agent) of
        (possibly complex) data (e.g. an agent having a dict obs space) will be
        converted to (possibly complex) structs, whose leaves are now numpy arrays.
        Each of these leaf numpy arrays will have the same length (batch dimension)
        as the length of the original lists.
        Note that Columns.INFOS are NEVER numpy'ized and will remain a list
        (normally, a list of the original, env-returned dicts). This is due to the
        heterogeneous nature of INFOS returned by envs, which would make it unwieldy to
        convert this information to numpy arrays.
        After calling this method, no further data may be added to this episode via
        the `self.add_env_step()` method.
        Examples:
        .. testcode::
            import numpy as np
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.env.tests.test_multi_agent_episode import (
                TestMultiAgentEpisode
            )
            # Create some multi-agent episode data.
            (
                observations,
                actions,
                rewards,
                terminateds,
                truncateds,
                infos,
            ) = TestMultiAgentEpisode._mock_multi_agent_records()
            # Define the agent ids.
            agent_ids = ["agent_1", "agent_2", "agent_3", "agent_4", "agent_5"]
            episode = MultiAgentEpisode(
                observations=observations,
                infos=infos,
                actions=actions,
                rewards=rewards,
                # Note: terminated/truncated have nothing to do with an episode
                # being converted `to_numpy` or not (via the `self.to_numpy()` method)!
                terminateds=terminateds,
                truncateds=truncateds,
                len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
            )
            # Episode has not been numpy'ized yet.
            assert not episode.is_numpy
            # We are still operating on lists.
            assert (
                episode.get_observations(
                    indices=[1],
                    agent_ids="agent_1",
                ) == {"agent_1": [1]}
            )
            # Numpy'ize the episode.
            episode.to_numpy()
            assert episode.is_numpy
            # Everything is now numpy arrays (with 0-axis of size
            # B=[len of requested slice]).
            assert (
                isinstance(episode.get_observations(
                    indices=[1],
                    agent_ids="agent_1",
                )["agent_1"], np.ndarray)
            )
        Returns:
             This `MultiAgentEpisode` object with the converted numpy data.
        """
        for agent_id, agent_eps in self.agent_episodes.copy().items():
            agent_eps.to_numpy()
        return self 
    def concat_episode(self, other: "MultiAgentEpisode") -> None:
        """Adds the given `other` MultiAgentEpisode to the right side of self.
        In order for this to work, both chunks (`self` and `other`) must fit
        together. This is checked via the IDs (must be identical), the time step
        counters (`self.env_t` must be the same as `other.env_t_started`), as well as
        the observations/infos of the individual agents at the concatenation
        boundaries.
        Also, `self.is_done` must not be True, meaning `self.is_terminated` and
        `self.is_truncated` are both False.
        Args:
            other: The other `MultiAgentEpisode` to be concatenated to this one.
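
        Example (a minimal sketch of the typical `cut()` -> continue sampling ->
        concatenate-back workflow; agent ID and values are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            episode.add_env_reset(observations={"a0": 0})
            episode.add_env_step(
                observations={"a0": 1}, actions={"a0": 0}, rewards={"a0": 1.0}
            )
            # Cut the episode and continue sampling on the successor chunk.
            successor = episode.cut()
            successor.add_env_step(
                observations={"a0": 2}, actions={"a0": 1}, rewards={"a0": 1.0}
            )
            # Concatenate the successor back onto the original chunk.
            episode.concat_episode(successor)
            assert episode.env_t == 2
            assert len(episode.agent_episodes["a0"]) == 2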
        Returns:
            None. `self` is modified in place and afterwards contains the concatenated
            data from both episodes (`self` and `other`).
        """
        # Make sure the IDs match.
        assert other.id_ == self.id_
        # NOTE (sven): This is what we agreed on. As the replay buffers must be
        # able to concatenate.
        assert not self.is_done
        # Make sure the timesteps match.
        assert self.env_t == other.env_t_started
        # Validate `other`.
        other.validate()
        # Concatenate the individual SingleAgentEpisodes from both chunks.
        all_agent_ids = set(self.agent_ids) | set(other.agent_ids)
        for agent_id in all_agent_ids:
            sa_episode = self.agent_episodes.get(agent_id)
            # If agent is only in the new episode chunk -> Store all the data of `other`
            # wrt agent in `self`.
            if sa_episode is None:
                self.agent_episodes[agent_id] = other.agent_episodes[agent_id]
                self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
                self.agent_t_started[agent_id] = other.agent_t_started[agent_id]
                self._copy_hanging(agent_id, other)
            # If the agent was done in `self`, ignore and continue. There should not be
            # any data of that agent in `other`.
            elif sa_episode.is_done:
                continue
            # If the agent has data in both chunks, concatenate on the single-agent
            # level, thereby making sure the hanging values (begin and end) match.
            elif agent_id in other.agent_episodes:
                # If `other` has hanging (end) values -> Add these to `self`'s agent
                # SingleAgentEpisode (as a new timestep) and only then concatenate.
                # Otherwise, the concatenation would fail b/c of missing data.
                if agent_id in self._hanging_actions_end:
                    assert agent_id in self._hanging_extra_model_outputs_end
                    sa_episode.add_env_step(
                        observation=other.agent_episodes[agent_id].get_observations(0),
                        infos=other.agent_episodes[agent_id].get_infos(0),
                        action=self._hanging_actions_end[agent_id],
                        reward=(
                            self._hanging_rewards_end[agent_id]
                            + other._hanging_rewards_begin[agent_id]
                        ),
                        extra_model_outputs=(
                            self._hanging_extra_model_outputs_end[agent_id]
                        ),
                    )
                sa_episode.concat_episode(other.agent_episodes[agent_id])
                # Override `self`'s hanging (end) values with `other`'s hanging (end).
                if agent_id in other._hanging_actions_end:
                    self._hanging_actions_end[agent_id] = copy.deepcopy(
                        other._hanging_actions_end[agent_id]
                    )
                    self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[
                        agent_id
                    ]
                    self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy(
                        other._hanging_extra_model_outputs_end[agent_id]
                    )
                # Concatenate the env- to agent-timestep mappings.
                j = self.env_t
                for i, val in enumerate(other.env_t_to_agent_t[agent_id][1:]):
                    if val == self.SKIP_ENV_TS_TAG:
                        self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
                    else:
                        self.env_t_to_agent_t[agent_id].append(i + 1 + j)
            # Otherwise, the agent is only in `self` and not done. All data is stored
            # already -> skip
            # else: pass
        # Update all timestep counters.
        self.env_t = other.env_t
        # Check, if the episode is terminated or truncated.
        if other.is_terminated:
            self.is_terminated = True
        elif other.is_truncated:
            self.is_truncated = True
        # Erase all temporary timestep data caches.
        self._temporary_timestep_data.clear()
        # Validate.
        self.validate() 
    def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode":
        """Returns a successor episode chunk (of len=0) continuing from this Episode.
        The successor will have the same ID as `self`.
        If no lookback buffer is requested (len_lookback_buffer=0), the successor's
        observations will be the last observation(s) of `self` and its length will
        therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0,
        the returned successor will have `len_lookback_buffer` observations (and
        actions, rewards, etc.) taken from the right side (end) of `self`. For example,
        if `len_lookback_buffer=2`, the returned successor's lookback buffer actions
        will be identical to the results of `self.get_actions([-2, -1])`.
        This method is useful if you would like to discontinue building an episode
        chunk (e.g. because you have to return it from somewhere), but would like to
        have a new episode instance to continue building the actual gym.Env episode at
        a later time. Via the `len_lookback_buffer` argument, the continuing chunk
        (successor) will still be able to "look back" into this predecessor episode's
        data (at least to some extent, depending on the value of
        `len_lookback_buffer`).
        Args:
            len_lookback_buffer: The number of environment timesteps to take along into
                the new chunk as "lookback buffer". A lookback buffer is additional data
                on the left side of the actual episode data for visibility purposes
                (but without actually being part of the new chunk). For example, if
                `self` ends in actions: agent_1=5,6,7 and agent_2=6,7, and we call
                `self.cut(len_lookback_buffer=2)`, the returned chunk will have
                actions 6 and 7 for both agents already in it, but still
                `t_started`==t==8 (not 7!) and a length of 0. If there is not enough
                data in `self` yet to fulfil the `len_lookback_buffer` request, the
                value of `len_lookback_buffer` is automatically adjusted (lowered).
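
        Example (a minimal sketch; agent ID and values are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            episode.add_env_reset(observations={"a0": 0})
            episode.add_env_step(
                observations={"a0": 1}, actions={"a0": 0}, rewards={"a0": 1.0}
            )
            successor = episode.cut()
            # The successor has the same ID and continues at `self.env_t`, but has
            # no steps of its own yet (its only observation is `self`'s last one).
            assert successor.id_ == episode.id_
            assert successor.env_t_started == successor.env_t == episode.env_t == 1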
        Returns:
            The successor Episode chunk of this one with the same ID and state and the
            only observation being the last observation in self.
        """
        assert len_lookback_buffer >= 0
        if self.is_done:
            raise RuntimeError(
                "Can't call `MultiAgentEpisode.cut()` when the episode is already done!"
            )
        # If there is hanging data (e.g. actions) in the agents' caches, we might have
        # to re-adjust the lookback len further into the past to make sure that these
        # agents have at least one observation to look back to. Otherwise, the timestep
        # that got cut into will be "lost" for learning from it.
        orig_len_lb = len_lookback_buffer
        for agent_id, agent_actions in self._hanging_actions_end.items():
            assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG
            for i in range(orig_len_lb, len(self.env_t_to_agent_t[agent_id].data) + 1):
                if self.env_t_to_agent_t[agent_id].get(-i) != self.SKIP_ENV_TS_TAG:
                    len_lookback_buffer = max(len_lookback_buffer, i - 1)
                    break
        # Initialize this episode chunk with the most recent observations
        # and infos (even if lookback is zero). Similar to an initial `env.reset()`
        indices_obs_and_infos = slice(-len_lookback_buffer - 1, None)
        indices_rest = (
            slice(-len_lookback_buffer, None)
            if len_lookback_buffer > 0
            else slice(None, 0)  # -> empty slice
        )
        observations = self.get_observations(
            indices=indices_obs_and_infos, return_list=True
        )
        infos = self.get_infos(indices=indices_obs_and_infos, return_list=True)
        actions = self.get_actions(indices=indices_rest, return_list=True)
        rewards = self.get_rewards(indices=indices_rest, return_list=True)
        extra_model_outputs = self.get_extra_model_outputs(
            key=None,  # all keys
            indices=indices_rest,
            return_list=True,
        )
        successor = MultiAgentEpisode(
            # Same ID.
            id_=self.id_,
            observations=observations,
            observation_space=self.observation_space,
            infos=infos,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            # List of MADicts, mapping agent IDs to their respective extra model output
            # dicts.
            extra_model_outputs=extra_model_outputs,
            terminateds=self.get_terminateds(),
            truncateds=self.get_truncateds(),
            # Continue with `self`'s current timesteps.
            env_t_started=self.env_t,
            agent_t_started={
                aid: self.agent_episodes[aid].t
                for aid in self.agent_ids
                if not self.agent_episodes[aid].is_done
            },
            # Same AgentIDs and SingleAgentEpisode IDs.
            agent_episode_ids=self.agent_episode_ids,
            agent_module_ids={
                aid: self.agent_episodes[aid].module_id for aid in self.agent_ids
            },
            agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
            # All data we provided to the c'tor goes into the lookback buffer.
            len_lookback_buffer="auto",
        )
        # Copy over the hanging (end) values into the hanging (begin) caches of the
        # successor.
        successor._hanging_rewards_begin = self._hanging_rewards_end.copy()
        return successor 
    @property
    def agent_ids(self) -> Set[AgentID]:
        """Returns the agent ids."""
        return set(self.agent_episodes.keys())
    @property
    def agent_episode_ids(self) -> MultiAgentDict:
        """Returns ids from each agent's `SingleAgentEpisode`."""
        return {
            agent_id: agent_eps.id_
            for agent_id, agent_eps in self.agent_episodes.items()
        }
    def module_for(self, agent_id: AgentID) -> Optional[ModuleID]:
        """Returns the ModuleID for a given AgentID.
        Forces the agent-to-module mapping to be performed (via
        `self.agent_to_module_mapping_fn`), if this has not been done yet.
        Note that all such mappings are stored in the `self._agent_to_module_mapping`
        dict.
        Args:
            agent_id: The AgentID to get a mapped ModuleID for.
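
        Example (a minimal sketch; without a custom `agent_to_module_mapping_fn`,
        the default agent-to-module mapping is used and the agent ID here is
        arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode()
            # The first call performs (and caches) the mapping; subsequent calls
            # return the cached ModuleID.
            module_id = episode.module_for("agent_0")
            assert episode.module_for("agent_0") == module_id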
        Returns:
            The ModuleID mapped to from the given `agent_id`.
        """
        if agent_id not in self._agent_to_module_mapping:
            module_id = self._agent_to_module_mapping[
                agent_id
            ] = self.agent_to_module_mapping_fn(agent_id, self)
            return module_id
        else:
            return self._agent_to_module_mapping[agent_id] 
    def get_observations(
        self,
        indices: Optional[Union[int, List[int], slice]] = None,
        agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
        *,
        env_steps: bool = True,
        # global_indices: bool = False,
        neg_index_as_lookback: bool = False,
        fill: Optional[Any] = None,
        one_hot_discrete: bool = False,
        return_list: bool = False,
    ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
        """Returns agents' observations or batched ranges thereof from this episode.
        Args:
            indices: A single int is interpreted as an index, from which to return the
                individual observation stored at this index.
                A list of ints is interpreted as a list of indices from which to gather
                individual observations in a batch of size len(indices).
                A slice object is interpreted as a range of observations to be returned.
                Thereby, negative indices by default are interpreted as "before the end"
                unless the `neg_index_as_lookback=True` option is used, in which case
                negative indices are interpreted as "before ts=0", meaning going back
                into the lookback buffer.
                If None, will return all observations (from ts=0 to the end).
            agent_ids: An optional collection of AgentIDs or a single AgentID to get
                observations for. If None, will return observations for all agents in
                this episode.
            env_steps: Whether `indices` should be interpreted as environment time steps
                (True) or per-agent timesteps (False).
            neg_index_as_lookback: If True, negative values in `indices` are
                interpreted as "before ts=0", meaning going back into the lookback
                buffer. For example, an episode with agent A's observations
                [4, 5, 6,  7, 8, 9], where [4, 5, 6] is the lookback buffer range
                (ts=0 item is 7), will respond to `get_observations(-1, agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `6`} and to
                `get_observations(slice(-2, 1), agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `[5, 6,  7]`}.
            fill: An optional value to use for filling up the returned results at
                the boundaries. This filling only happens if the requested index range's
                start/stop boundaries exceed the episode's boundaries (including the
                lookback buffer on the left side). This comes in very handy, if users
                don't want to worry about reaching such boundaries and want to zero-pad.
                For example, an episode with agent A's observations [10, 11,  12, 13, 14]
                and lookback buffer size of 2 (meaning observations `10` and `11` are
                part of the lookback buffer) will respond to
                `get_observations(slice(-7, -2), agent_ids=[A], fill=0.0)` with
                `{A: [0.0, 0.0, 10, 11, 12]}`.
            one_hot_discrete: If True, will return one-hot vectors (instead of
                int-values) for those sub-components of a (possibly complex) observation
                space that are Discrete or MultiDiscrete.  Note that if `fill=0` and the
                requested `indices` are out of the range of our data, the returned
                one-hot vectors will actually be zero-hot (all slots zero).
            return_list: Whether to return a list of multi-agent dicts (instead of
                a single multi-agent dict of lists/structs). False by default. This
                option can only be used when `env_steps` is True, due to the fact that
                such a list can only be interpreted as one env step per list item
                (which would not work with agent steps).
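
        Example (a minimal sketch mirroring the `to_numpy()` example above; agent ID
        and observation values are arbitrary):

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

            episode = MultiAgentEpisode(
                observations=[{"a0": 0}, {"a0": 1}, {"a0": 2}],
                actions=[{"a0": 0}, {"a0": 1}],
                rewards=[{"a0": 1.0}, {"a0": 1.0}],
                len_lookback_buffer=0,
            )
            # Batched request (list of env step indices) -> list per agent.
            assert episode.get_observations(
                indices=[-1], agent_ids="a0"
            ) == {"a0": [2]}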
        Returns:
            A dictionary mapping agent IDs to observations (at the given
            `indices`). If `env_steps` is True, only agents that have stepped
            (were ready) at the given env step `indices` are returned (i.e. not all
            agent IDs are necessarily in the keys).
            If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
            IDs to observations) instead.
        """
        return self._get(
            what="observations",
            indices=indices,
            agent_ids=agent_ids,
            env_steps=env_steps,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
            one_hot_discrete=one_hot_discrete,
            return_list=return_list,
        ) 
    def get_infos(
        self,
        indices: Optional[Union[int, List[int], slice]] = None,
        agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
        *,
        env_steps: bool = True,
        neg_index_as_lookback: bool = False,
        fill: Optional[Any] = None,
        return_list: bool = False,
    ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
        """Returns agents' info dicts or list (ranges) thereof from this episode.
        Args:
            indices: A single int is interpreted as an index, from which to return the
                individual info dict stored at this index.
                A list of ints is interpreted as a list of indices from which to gather
                individual info dicts in a list of size len(indices).
                A slice object is interpreted as a range of info dicts to be returned.
                Thereby, negative indices by default are interpreted as "before the end"
                unless the `neg_index_as_lookback=True` option is used, in which case
                negative indices are interpreted as "before ts=0", meaning going back
                into the lookback buffer.
                If None, will return all infos (from ts=0 to the end).
            agent_ids: An optional collection of AgentIDs or a single AgentID to get
                info dicts for. If None, will return info dicts for all agents in
                this episode.
            env_steps: Whether `indices` should be interpreted as environment time steps
                (True) or per-agent timesteps (False).
            neg_index_as_lookback: If True, negative values in `indices` are
                interpreted as "before ts=0", meaning going back into the lookback
                buffer. For example, an episode with agent A's info dicts
                [{"l":4}, {"l":5}, {"l":6},  {"a":7}, {"b":8}, {"c":9}], where the
                first 3 items are the lookback buffer (ts=0 item is {"a": 7}), will
                respond to `get_infos(-1, agent_ids=A, neg_index_as_lookback=True)`
                with `{A: {"l":6}}` and to
                `get_infos(slice(-2, 1), agent_ids=A, neg_index_as_lookback=True)`
                with `{A: [{"l":5}, {"l":6},  {"a":7}]}`.
            fill: An optional value to use for filling up the returned results at
                the boundaries. This filling only happens if the requested index range's
                start/stop boundaries exceed the episode's boundaries (including the
                lookback buffer on the left side). This comes in very handy, if users
                don't want to worry about reaching such boundaries and want to
                auto-fill. For example, an episode with agent A's infos being
                [{"l":10}, {"l":11},  {"a":12}, {"b":13}, {"c":14}] and lookback buffer
                size of 2 (meaning infos {"l":10}, {"l":11} are part of the lookback
                buffer) will respond to `get_infos(slice(-7, -2), agent_ids=A,
                fill={"o": 0.0})` with
                `{A: [{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]}`.
            return_list: Whether to return a list of multi-agent dicts (instead of
                a single multi-agent dict of lists/structs). False by default. This
                option can only be used when `env_steps` is True because such a
                list can only be interpreted as one env step per list item (it
                would not work with agent steps).
        Returns:
            A dictionary mapping agent IDs to info dicts (at the given
            `indices`). If `env_steps` is True, only agents that have stepped
            (were ready) at the given env step `indices` are returned (i.e. not all
            agent IDs are necessarily in the keys).
            If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
            IDs to infos) instead.
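        Example (a minimal sketch; agent IDs "a0"/"a1" and the info dicts below
        are made up for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}]
            infos = [
                {"a0": {"t": 0}, "a1": {"t": 0}},
                {"a0": {"t": 1}, "a1": {"t": 1}},
            ]
            episode = MultiAgentEpisode(
                observations=observations,
                infos=infos,
                actions=observations[:-1],
                rewards=[{"a0": 0.1, "a1": 0.1}],
                len_lookback_buffer=0,
            )
            # Most recent info dict of each agent.
            check(episode.get_infos(-1), {"a0": {"t": 1}, "a1": {"t": 1}})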
        """
        return self._get(
            what="infos",
            indices=indices,
            agent_ids=agent_ids,
            env_steps=env_steps,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
            return_list=return_list,
        ) 
[docs]
    def get_actions(
        self,
        indices: Optional[Union[int, List[int], slice]] = None,
        agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
        *,
        env_steps: bool = True,
        neg_index_as_lookback: bool = False,
        fill: Optional[Any] = None,
        one_hot_discrete: bool = False,
        return_list: bool = False,
    ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
        """Returns agents' actions or batched ranges thereof from this episode.
        Args:
            indices: A single int is interpreted as an index, from which to return the
                individual actions stored at this index.
                A list of ints is interpreted as a list of indices from which to gather
                individual actions in a batch of size len(indices).
                A slice object is interpreted as a range of actions to be returned.
                Thereby, negative indices by default are interpreted as "before the end"
                unless the `neg_index_as_lookback=True` option is used, in which case
                negative indices are interpreted as "before ts=0", meaning going back
                into the lookback buffer.
                If None, will return all actions (from ts=0 to the end).
            agent_ids: An optional collection of AgentIDs or a single AgentID to get
                actions for. If None, will return actions for all agents in
                this episode.
            env_steps: Whether `indices` should be interpreted as environment time steps
                (True) or per-agent timesteps (False).
            neg_index_as_lookback: If True, negative values in `indices` are
                interpreted as "before ts=0", meaning going back into the lookback
                buffer. For example, an episode with agent A's actions
                [4, 5, 6,  7, 8, 9], where [4, 5, 6] is the lookback buffer range
                (ts=0 item is 7), will respond to `get_actions(-1, agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `6`} and to
                `get_actions(slice(-2, 1), agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `[5, 6,  7]`}.
            fill: An optional value to use for filling up the returned results at
                the boundaries. This filling only happens if the requested index range's
                start/stop boundaries exceed the episode's boundaries (including the
                lookback buffer on the left side). This comes in very handy, if users
                don't want to worry about reaching such boundaries and want to zero-pad.
                For example, an episode with agent A's actions [10, 11,  12, 13, 14]
                and lookback buffer size of 2 (meaning actions `10` and `11` are
                part of the lookback buffer) will respond to
                `get_actions(slice(-7, -2), agent_ids=[A], fill=0.0)` with
                `{A: [0.0, 0.0, 10, 11, 12]}`.
            one_hot_discrete: If True, will return one-hot vectors (instead of
                int-values) for those sub-components of a (possibly complex) action
                space that are Discrete or MultiDiscrete. Note that if `fill=0` and the
                requested `indices` are out of the range of our data, the returned
                one-hot vectors will actually be zero-hot (all slots zero).
            return_list: Whether to return a list of multi-agent dicts (instead of
                a single multi-agent dict of lists/structs). False by default. This
                option can only be used when `env_steps` is True because such a
                list can only be interpreted as one env step per list item (it
                would not work with agent steps).
        Returns:
            A dictionary mapping agent IDs to actions (at the given
            `indices`). If `env_steps` is True, only agents that have stepped
            (were ready) at the given env step `indices` are returned (i.e. not all
            agent IDs are necessarily in the keys).
            If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
            IDs to actions) instead.
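        Example (a minimal sketch; agent IDs "a0"/"a1" and all values below are
        made up for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2, "a1": 2},
            ]
            actions = observations[:-1]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=[{aid: 1.0 for aid in a} for a in actions],
                len_lookback_buffer=0,
            )
            # Most recent action of each agent (taken at env step 1).
            check(episode.get_actions(-1), {"a0": 1, "a1": 1})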
        """
        return self._get(
            what="actions",
            indices=indices,
            agent_ids=agent_ids,
            env_steps=env_steps,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
            one_hot_discrete=one_hot_discrete,
            return_list=return_list,
        ) 
[docs]
    def get_rewards(
        self,
        indices: Optional[Union[int, List[int], slice]] = None,
        agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
        *,
        env_steps: bool = True,
        neg_index_as_lookback: bool = False,
        fill: Optional[float] = None,
        return_list: bool = False,
    ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
        """Returns agents' rewards or batched ranges thereof from this episode.
        Args:
            indices: A single int is interpreted as an index, from which to return the
                individual rewards stored at this index.
                A list of ints is interpreted as a list of indices from which to gather
                individual rewards in a batch of size len(indices).
                A slice object is interpreted as a range of rewards to be returned.
                Thereby, negative indices by default are interpreted as "before the end"
                unless the `neg_index_as_lookback=True` option is used, in which case
                negative indices are interpreted as "before ts=0", meaning going back
                into the lookback buffer.
                If None, will return all rewards (from ts=0 to the end).
            agent_ids: An optional collection of AgentIDs or a single AgentID to get
                rewards for. If None, will return rewards for all agents in
                this episode.
            env_steps: Whether `indices` should be interpreted as environment time steps
                (True) or per-agent timesteps (False).
            neg_index_as_lookback: If True, negative values in `indices` are
                interpreted as "before ts=0", meaning going back into the lookback
                buffer. For example, an episode with agent A's rewards
                [4, 5, 6,  7, 8, 9], where [4, 5, 6] is the lookback buffer range
                (ts=0 item is 7), will respond to `get_rewards(-1, agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `6`} and to
                `get_rewards(slice(-2, 1), agent_ids=[A],
                neg_index_as_lookback=True)` with {A: `[5, 6,  7]`}.
            fill: An optional float value to use for filling up the returned results at
                the boundaries. This filling only happens if the requested index range's
                start/stop boundaries exceed the episode's boundaries (including the
                lookback buffer on the left side). This comes in very handy, if users
                don't want to worry about reaching such boundaries and want to zero-pad.
                For example, an episode with agent A's rewards [10, 11,  12, 13, 14]
                and lookback buffer size of 2 (meaning rewards `10` and `11` are
                part of the lookback buffer) will respond to
                `get_rewards(slice(-7, -2), agent_ids=[A], fill=0.0)` with
                `{A: [0.0, 0.0, 10, 11, 12]}`.
            return_list: Whether to return a list of multi-agent dicts (instead of
                a single multi-agent dict of lists/structs). False by default. This
                option can only be used when `env_steps` is True because such a
                list can only be interpreted as one env step per list item (it
                would not work with agent steps).
        Returns:
            A dictionary mapping agent IDs to rewards (at the given
            `indices`). If `env_steps` is True, only agents that have stepped
            (were ready) at the given env step `indices` are returned (i.e. not all
            agent IDs are necessarily in the keys).
            If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
            IDs to rewards) instead.
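        Example (a minimal sketch; agent IDs "a0"/"a1" and all reward values
        below are made up for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2, "a1": 2},
            ]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=observations[:-1],
                rewards=[{"a0": 0.5, "a1": 1.5}, {"a0": 1.0, "a1": 2.0}],
                len_lookback_buffer=0,
            )
            # Most recent reward of each agent.
            check(episode.get_rewards(-1), {"a0": 1.0, "a1": 2.0})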
        """
        return self._get(
            what="rewards",
            indices=indices,
            agent_ids=agent_ids,
            env_steps=env_steps,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
            return_list=return_list,
        ) 
[docs]
    def get_terminateds(self) -> MultiAgentDict:
        """Gets the terminateds at given indices."""
        terminateds = {
            agent_id: self.agent_episodes[agent_id].is_terminated
            for agent_id in self.agent_ids
        }
        terminateds.update({"__all__": self.is_terminated})
        return terminateds 
[docs]
    def get_truncateds(self) -> MultiAgentDict:
        """Returns the truncated flags of all agents (and the `__all__` flag)."""
        truncateds = {
            agent_id: self.agent_episodes[agent_id].is_truncated
            for agent_id in self.agent_ids
        }
        truncateds.update({"__all__": self.is_truncated})
        return truncateds 
[docs]
    def add_temporary_timestep_data(self, key: str, data: Any) -> None:
        """Temporarily adds (until `to_numpy()` called) per-timestep data to self.
        The given `data` is appended to a list (`self._temporary_timestep_data`), which
        is cleared upon calling `self.to_numpy()`. To get the thus-far accumulated
        temporary timestep data for a certain key, use the `get_temporary_timestep_data`
        API.
        Note that the size of the per timestep list is NOT checked or validated against
        the other, non-temporary data in this episode (like observations).
        Args:
            key: The key under which to find the list to append `data` to. If `data` is
                the first data to be added for this key, start a new list.
            data: The data item (representing a single timestep) to be stored.
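        Example (a minimal sketch; the key "render_img" and the stored strings
        are made up for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=observations[:-1],
                rewards=[{"a0": 1.0, "a1": 1.0}],
                len_lookback_buffer=0,
            )
            # Append one temporary data item per timestep under the same key.
            episode.add_temporary_timestep_data("render_img", "img_ts0")
            episode.add_temporary_timestep_data("render_img", "img_ts1")
            # Retrieve the accumulated list (cleared again upon `to_numpy()`).
            check(
                episode.get_temporary_timestep_data("render_img"),
                ["img_ts0", "img_ts1"],
            )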
        """
        if self.is_numpy:
            raise ValueError(
                "Cannot use the `add_temporary_timestep_data` API on an already "
                f"numpy'ized {type(self).__name__}!"
            )
        self._temporary_timestep_data[key].append(data) 
[docs]
    def get_temporary_timestep_data(self, key: str) -> List[Any]:
        """Returns all temporarily stored data items (list) under the given key.
        Note that all temporary timestep data is erased/cleared when calling
        `self.to_numpy()`.
        Args:
            key: The key under which the temporary timestep data items were
                stored via `add_temporary_timestep_data()`.
        Returns:
            The current list storing temporary timestep data under `key`.
        """
        if self.is_numpy:
            raise ValueError(
                "Cannot use the `get_temporary_timestep_data` API on an already "
                f"numpy'ized {type(self).__name__}! All temporary data has been erased "
                f"upon `{type(self).__name__}.to_numpy()`."
            )
        try:
            return self._temporary_timestep_data[key]
        except KeyError:
            raise KeyError(f"Key {key} not found in temporary timestep data!") 
[docs]
    def slice(
        self,
        slice_: slice,
        *,
        len_lookback_buffer: Optional[int] = None,
    ) -> "MultiAgentEpisode":
        """Returns a slice of this episode with the given slice object.
        Works analogously to
        :py:meth:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode.slice`.
        However, there are important differences:
        - `slice_` is provided in (global) env steps, not agent steps.
        - In case `slice_` ends, for a certain agent, in an env step in which that
        particular agent does not have an observation, the previous observation is
        included, but the next action and the sum of rewards up to this point are
        stored in that agent's hanging value caches of the returned
        MultiAgentEpisode slice.
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            # Generate a simple multi-agent episode.
            observations = [
                {"a0": 0, "a1": 0},  # 0
                {         "a1": 1},  # 1
                {         "a1": 2},  # 2
                {"a0": 3, "a1": 3},  # 3
                {"a0": 4},           # 4
            ]
            # Actions are the same as observations (except for last obs, which doesn't
            # have an action).
            actions = observations[:-1]
            # Make up a reward for each action.
            rewards = [
                {aid: r / 10 + 0.1 for aid, r in o.items()}
                for o in observations
            ]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=rewards,
                len_lookback_buffer=0,
            )
            # Slice the episode and check results.
            slice = episode[1:3]
            a0 = slice.agent_episodes["a0"]
            a1 = slice.agent_episodes["a1"]
            check((a0.observations, a1.observations), ([3], [1, 2, 3]))
            check((a0.actions, a1.actions), ([], [1, 2]))
            check((a0.rewards, a1.rewards), ([], [0.2, 0.3]))
            check((a0.is_done, a1.is_done), (False, False))
            # If a slice ends in a "gap" for an agent, expect actions and rewards to be
            # cached for this agent.
            slice = episode[:2]
            a0 = slice.agent_episodes["a0"]
            check(a0.observations, [0])
            check(a0.actions, [])
            check(a0.rewards, [])
            check(slice._hanging_actions_end["a0"], 0)
            check(slice._hanging_rewards_end["a0"], 0.1)
        Args:
            slice_: The slice object to use for slicing. This should exclude the
                lookback buffer, which will be prepended automatically to the returned
                slice.
            len_lookback_buffer: If not None, forces the returned slice to try to have
                this number of timesteps in its lookback buffer (if available). If None
                (default), tries to make the returned slice's lookback as large as the
                current lookback buffer of this episode (`self`).
        Returns:
            The new MultiAgentEpisode representing the requested slice.
        """
        if slice_.step not in [1, None]:
            raise NotImplementedError(
                "Slicing MultiAgentEnv with a step other than 1 (you used"
                f" {slice_.step}) is not supported!"
            )
        # Translate `slice_` into one that only contains 0-or-positive ints and will
        # NOT contain any None.
        start = slice_.start
        stop = slice_.stop
        # Start is None -> 0.
        if start is None:
            start = 0
        # Start is negative -> Interpret index as counting "from end".
        elif start < 0:
            start = max(len(self) + start, 0)
        # Start is larger than len(self) -> Clip to len(self).
        elif start > len(self):
            start = len(self)
        # Stop is None -> Set stop to our len (one ts past last valid index).
        if stop is None:
            stop = len(self)
        # Stop is negative -> Interpret index as counting "from end".
        elif stop < 0:
            stop = max(len(self) + stop, 0)
        # Stop is larger than len(self) -> Clip to len(self).
        elif stop > len(self):
            stop = len(self)
        ref_lookback = None
        try:
            for aid, sa_episode in self.agent_episodes.items():
                if ref_lookback is None:
                    ref_lookback = sa_episode.observations.lookback
                assert sa_episode.observations.lookback == ref_lookback
                assert sa_episode.actions.lookback == ref_lookback
                assert sa_episode.rewards.lookback == ref_lookback
                assert all(
                    ilb.lookback == ref_lookback
                    for ilb in sa_episode.extra_model_outputs.values()
                )
        except AssertionError:
            raise ValueError(
                "Can only slice a MultiAgentEpisode if all lookback buffers in this "
                "episode have the exact same size!"
            )
        # Determine terminateds/truncateds and when (in agent timesteps) the
        # single-agent episode slices start.
        terminateds = {}
        truncateds = {}
        agent_t_started = {}
        for aid, sa_episode in self.agent_episodes.items():
            mapping = self.env_t_to_agent_t[aid]
            # If the (agent) timestep directly at the slice stop boundary is equal to
            # the length of the single-agent episode of this agent -> Use the
            # single-agent episode's terminated/truncated flags.
            # If `stop` is already beyond this agent's single-agent episode, then we
            # don't have to keep track of this: The MultiAgentEpisode initializer will
            # automatically determine that this agent must be done (b/c it has no action
            # following its final observation).
            if (
                stop < len(mapping)
                and mapping[stop] != self.SKIP_ENV_TS_TAG
                and len(sa_episode) == mapping[stop]
            ):
                terminateds[aid] = sa_episode.is_terminated
                truncateds[aid] = sa_episode.is_truncated
            # Determine this agent's t_started.
            if start < len(mapping):
                for i in range(start, len(mapping)):
                    if mapping[i] != self.SKIP_ENV_TS_TAG:
                        agent_t_started[aid] = sa_episode.t_started + mapping[i]
                        break
        terminateds["__all__"] = all(
            terminateds.get(aid) for aid in self.agent_episodes
        )
        truncateds["__all__"] = all(truncateds.get(aid) for aid in self.agent_episodes)
        # Determine all other slice contents.
        _lb = len_lookback_buffer if len_lookback_buffer is not None else ref_lookback
        if start - _lb < 0 and ref_lookback < (_lb - start):
            _lb = ref_lookback + start
        observations = self.get_observations(
            slice(start - _lb, stop + 1),
            neg_index_as_lookback=True,
            return_list=True,
        )
        actions = self.get_actions(
            slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )
        rewards = self.get_rewards(
            slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )
        extra_model_outputs = self.get_extra_model_outputs(
            indices=slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )
        # Create the actual slice to be returned.
        ma_episode = MultiAgentEpisode(
            id_=self.id_,
            # In the following, offset `start`s automatically by lookbacks.
            observations=observations,
            observation_space=self.observation_space,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            extra_model_outputs=extra_model_outputs,
            terminateds=terminateds,
            truncateds=truncateds,
            len_lookback_buffer=_lb,
            env_t_started=self.env_t_started + start,
            agent_episode_ids={
                aid: eid.id_ for aid, eid in self.agent_episodes.items()
            },
            agent_t_started=agent_t_started,
            agent_module_ids=self._agent_to_module_mapping,
            agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
        )
        # Numpy'ize the slice if `self` is already numpy'ized as well.
        if self.is_numpy:
            ma_episode.to_numpy()
        return ma_episode 
[docs]
    def __len__(self):
        """Returns the length of an `MultiAgentEpisode`.
        Note that the length of an episode is defined by the difference
        between its actual timestep and the starting point.
        Returns: An integer defining the length of the episode or an
            error if the episode has not yet started.
        """
        assert (
            sum(len(agent_map) for agent_map in self.env_t_to_agent_t.values()) > 0
        ), (
            "ERROR: Cannot determine length of episode that hasn't started, yet!"
            "Call `MultiAgentEpisode.add_env_reset(observations=)` "
            "first (after which `len(MultiAgentEpisode)` will be 0)."
        )
        return self.env_t - self.env_t_started 
    def __repr__(self):
        sa_eps_returns = {
            aid: sa_eps.get_return() for aid, sa_eps in self.agent_episodes.items()
        }
        return (
            f"MAEps(len={len(self)} done={self.is_done} "
            f"Rs={sa_eps_returns} id_={self.id_})"
        )
[docs]
    def print(self) -> None:
        """Prints this MultiAgentEpisode as a table of observations for the agents."""
        # Find the maximum timestep across all agents to determine the grid width.
        max_ts = max(ts.len_incl_lookback() for ts in self.env_t_to_agent_t.values())
        lookback = next(iter(self.env_t_to_agent_t.values())).lookback
        longest_agent = max(len(aid) for aid in self.agent_ids)
        # Construct the header.
        header = (
            "ts"
            + (" " * longest_agent)
            + "   ".join(str(i) for i in range(-lookback, max_ts - lookback))
            + "\n"
        )
        # Construct each agent's row.
        rows = []
        for agent, inf_buffer in self.env_t_to_agent_t.items():
            row = f"{agent}  " + (" " * (longest_agent - len(agent)))
            for t in inf_buffer.data:
                # Two spaces for alignment.
                if t == "S":
                    row += "    "
                # Mark the step with an x.
                else:
                    row += " x  "
            # Remove trailing space for alignment.
            rows.append(row.rstrip())
        # Join all components into a final string
        print(header + "\n".join(rows)) 
[docs]
    def get_state(self) -> Dict[str, Any]:
        """Returns the state of a multi-agent episode.
        Note that the episode itself can be recreated from this state.
        Returns: A dictionary containing picklable data for a
            `MultiAgentEpisode`.
        """
        return {
            "id_": self.id_,
            "agent_to_module_mapping_fn": self.agent_to_module_mapping_fn,
            "_agent_to_module_mapping": self._agent_to_module_mapping,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "env_t_started": self.env_t_started,
            "env_t": self.env_t,
            "agent_t_started": self.agent_t_started,
            # TODO (simon): Check, if we can store the `InfiniteLookbackBuffer`
            "env_t_to_agent_t": self.env_t_to_agent_t,
            "_hanging_actions_end": self._hanging_actions_end,
            "_hanging_extra_model_outputs_end": self._hanging_extra_model_outputs_end,
            "_hanging_rewards_end": self._hanging_rewards_end,
            "_hanging_rewards_begin": self._hanging_rewards_begin,
            "is_terminated": self.is_terminated,
            "is_truncated": self.is_truncated,
            "agent_episodes": list(
                {
                    agent_id: agent_eps.get_state()
                    for agent_id, agent_eps in self.agent_episodes.items()
                }.items()
            ),
            "_start_time": self._start_time,
            "_last_step_time": self._last_step_time,
        } 
[docs]
    @staticmethod
    def from_state(state: Dict[str, Any]) -> "MultiAgentEpisode":
        """Creates a multi-agent episode from a state dictionary.
        See `MultiAgentEpisode.get_state()` for how to create such a picklable
        state for a `MultiAgentEpisode`. To recreate a `MultiAgentEpisode` from
        a state, the state has to be complete, i.e. all data must have been
        stored in it.
        Args:
            state: A dict containing all data required to recreate a
                `MultiAgentEpisode`. See `MultiAgentEpisode.get_state()`.
        Returns:
            A `MultiAgentEpisode` instance created from the state data.
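        Example (a minimal round-trip sketch; agent IDs and values are made up
        for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=observations[:-1],
                rewards=[{"a0": 1.0, "a1": 1.0}],
                len_lookback_buffer=0,
            )
            # Serialize to a state dict, then restore an equivalent episode from it.
            restored = MultiAgentEpisode.from_state(episode.get_state())
            check(restored.id_, episode.id_)
            check(len(restored), len(episode))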
        """
        # Create an empty `MultiAgentEpisode` instance.
        episode = MultiAgentEpisode(id_=state["id_"])
        # Fill the instance with the state data.
        episode.agent_to_module_mapping_fn = state["agent_to_module_mapping_fn"]
        episode._agent_to_module_mapping = state["_agent_to_module_mapping"]
        episode.observation_space = state["observation_space"]
        episode.action_space = state["action_space"]
        episode.env_t_started = state["env_t_started"]
        episode.env_t = state["env_t"]
        episode.agent_t_started = state["agent_t_started"]
        episode.env_t_to_agent_t = state["env_t_to_agent_t"]
        episode._hanging_actions_end = state["_hanging_actions_end"]
        episode._hanging_extra_model_outputs_end = state[
            "_hanging_extra_model_outputs_end"
        ]
        episode._hanging_rewards_end = state["_hanging_rewards_end"]
        episode._hanging_rewards_begin = state["_hanging_rewards_begin"]
        episode.is_terminated = state["is_terminated"]
        episode.is_truncated = state["is_truncated"]
        episode.agent_episodes = {
            agent_id: SingleAgentEpisode.from_state(agent_state)
            for agent_id, agent_state in state["agent_episodes"]
        }
        episode._start_time = state["_start_time"]
        episode._last_step_time = state["_last_step_time"]
        # Validate the episode.
        episode.validate()
        return episode 
[docs]
    def get_sample_batch(self) -> MultiAgentBatch:
        """Converts this `MultiAgentEpisode` into a `MultiAgentBatch`.
        Each `SingleAgentEpisode` instance in `MultiAgentEpisode.agent_episodes`
        is converted into a `SampleBatch` and the environment timestep count is
        passed as the returned MultiAgentBatch's `env_steps`.
        Returns:
            A MultiAgentBatch containing all of this episode's data.
        """
        # TODO (simon): Check, if timesteps should be converted into global
        # timesteps instead of agent steps.
        # Note that only agents that have actually stepped are included in the batch.
        return MultiAgentBatch(
            policy_batches={
                agent_id: agent_eps.get_sample_batch()
                for agent_id, agent_eps in self.agent_episodes.items()
                if agent_eps.t - agent_eps.t_started > 0
            },
            env_steps=self.env_t - self.env_t_started,
        ) 
[docs]
    def get_return(
        self,
        include_hanging_rewards: bool = False,
    ) -> float:
        """Returns all-agent return.
        Args:
            include_hanging_rewards: Whether to also consider hanging
                rewards when calculating the overall return. Agents might
                have received partial rewards, i.e. rewards without an
                observation. These are stored in the "hanging" caches (begin and end)
                for each agent and added up until the next observation is received by
                that agent.
        Returns:
            The sum of all single-agents' returns (maybe including the hanging
            rewards per agent).
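        Example (a minimal sketch; agent IDs and reward values are made up for
        illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2, "a1": 2},
            ]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=observations[:-1],
                rewards=[{"a0": 1.0, "a1": 2.0}, {"a0": 3.0, "a1": 4.0}],
                len_lookback_buffer=0,
            )
            # Sum of the individual agents' returns: (1.0 + 3.0) + (2.0 + 4.0).
            check(episode.get_return(), 10.0)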
        """
        env_return = sum(
            agent_eps.get_return() for agent_eps in self.agent_episodes.values()
        )
        if include_hanging_rewards:
            for hanging_r in self._hanging_rewards_begin.values():
                env_return += hanging_r
            for hanging_r in self._hanging_rewards_end.values():
                env_return += hanging_r
        return env_return 
[docs]
    def get_agents_to_act(self) -> Set[AgentID]:
        """Returns a set of agent IDs required to send an action to `env.step()` next.
        Those are generally the agents that received an observation in the most recent
        `env.step()` call.
        Returns:
            A set of AgentIDs that are supposed to send actions to the next `env.step()`
            call.
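        Example (a minimal sketch; in the last env step, only agent "a0" received
        an observation, so only "a0" needs to act next; all IDs/values are made up):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2},  # <- a1 did not receive an obs in the last step.
            ]
            actions = observations[:-1]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=[{aid: 1.0 for aid in a} for a in actions],
                len_lookback_buffer=0,
            )
            check(episode.get_agents_to_act(), {"a0"})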
        """
        return {
            aid
            for aid in self.get_observations(-1).keys()
            if not self.agent_episodes[aid].is_done
        } 
[docs]
    def get_agents_that_stepped(self) -> Set[AgentID]:
        """Returns a set of agent IDs of those agents that just finished stepping.
        These are all the agents that have an observation logged at the last env
        timestep, which may include agents whose single-agent episode just
        terminated or was truncated.
        Returns:
            A set of AgentIDs of those agents that just finished stepping (that have a
            most recent observation on the env timestep scale), regardless of whether
            their single agent episodes are done or not.
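        Example (a minimal sketch; only agent "a0" has an observation at the last
        env timestep; all IDs/values are made up for illustration):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2},  # <- a1 did not step here.
            ]
            actions = observations[:-1]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=[{aid: 1.0 for aid in a} for a in actions],
                len_lookback_buffer=0,
            )
            check(episode.get_agents_that_stepped(), {"a0"})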
        """
        return set(self.get_observations(-1).keys()) 
[docs]
    def get_duration_s(self) -> float:
        """Returns the duration of this Episode (chunk) in seconds."""
        if self._last_step_time is None:
            return 0.0
        return self._last_step_time - self._start_time 
[docs]
    def env_steps(self) -> int:
        """Returns the number of environment steps.
        Note, this episode instance could be a chunk of an actual episode.
        Returns:
            An integer that counts the number of environment steps this episode instance
            has seen.
        """
        return len(self) 
[docs]
    def agent_steps(self) -> int:
        """Number of agent steps.
        Note, there are >= 1 agent steps per environment step.
        Returns:
            An integer counting the number of agent steps executed during the time this
            episode instance records.
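        Example (a minimal sketch; agent "a1" misses the last env step, so its
        latest action is still hanging; all IDs/values are made up):
        .. testcode::
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check
            observations = [
                {"a0": 0, "a1": 0},
                {"a0": 1, "a1": 1},
                {"a0": 2},
            ]
            actions = observations[:-1]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=[{aid: 1.0 for aid in a} for a in actions],
                len_lookback_buffer=0,
            )
            # 2 env steps, but 3 completed agent steps (a0: 2, a1: 1; a1's second
            # action is still hanging and not yet part of its SingleAgentEpisode).
            check(episode.env_steps(), 2)
            check(episode.agent_steps(), 3)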
        """
        return sum(len(eps) for eps in self.agent_episodes.values()) 
    def __getitem__(self, item: slice) -> "MultiAgentEpisode":
        """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:]."""
        if isinstance(item, slice):
            return self.slice(slice_=item)
        else:
            raise NotImplementedError(
                f"MultiAgentEpisode does not support getting item '{item}'! "
                "Only slice objects allowed with the syntax: `episode[a:b]`."
            )
    def _init_single_agent_episodes(
        self,
        *,
        agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None,
        agent_episode_ids: Optional[Dict[AgentID, str]] = None,
        observations: Optional[List[MultiAgentDict]] = None,
        actions: Optional[List[MultiAgentDict]] = None,
        rewards: Optional[List[MultiAgentDict]] = None,
        infos: Optional[List[MultiAgentDict]] = None,
        terminateds: Union[MultiAgentDict, bool] = False,
        truncateds: Union[MultiAgentDict, bool] = False,
        extra_model_outputs: Optional[List[MultiAgentDict]] = None,
    ):
        if observations is None:
            return
        if actions is None:
            assert not rewards
            assert not extra_model_outputs
            actions = []
            rewards = []
            extra_model_outputs = []
        # Infos and `extra_model_outputs` are allowed to be None -> Fill them with
        # proper dummy values, if so.
        if infos is None:
            infos = [{} for _ in range(len(observations))]
        if extra_model_outputs is None:
            extra_model_outputs = [{} for _ in range(len(actions))]
        observations_per_agent = defaultdict(list)
        infos_per_agent = defaultdict(list)
        actions_per_agent = defaultdict(list)
        rewards_per_agent = defaultdict(list)
        extra_model_outputs_per_agent = defaultdict(list)
        done_per_agent = defaultdict(bool)
        len_lookback_buffer_per_agent = defaultdict(lambda: self._len_lookback_buffers)
        all_agent_ids = set(
            agent_episode_ids.keys() if agent_episode_ids is not None else []
        )
        agent_module_ids = agent_module_ids or {}
        # Step through all observations and interpret these as the (global) env steps.
        for data_idx, (obs, inf) in enumerate(zip(observations, infos)):
            # If we do have actions/extra outs/rewards for this timestep, use the data.
            # It may be that these lists have the same length as the observations list,
            # in which case the data will be cached (agent did step/send an action,
            # but the step has not been concluded yet by the env).
            act = actions[data_idx] if len(actions) > data_idx else {}
            extra_outs = (
                extra_model_outputs[data_idx]
                if len(extra_model_outputs) > data_idx
                else {}
            )
            rew = rewards[data_idx] if len(rewards) > data_idx else {}
            for agent_id, agent_obs in obs.items():
                all_agent_ids.add(agent_id)
                observations_per_agent[agent_id].append(agent_obs)
                infos_per_agent[agent_id].append(inf.get(agent_id, {}))
                # Pull out hanging action (if not first obs for this agent) and
                # complete step for agent.
                if len(observations_per_agent[agent_id]) > 1:
                    actions_per_agent[agent_id].append(
                        self._hanging_actions_end.pop(agent_id)
                    )
                    extra_model_outputs_per_agent[agent_id].append(
                        self._hanging_extra_model_outputs_end.pop(agent_id)
                    )
                    rewards_per_agent[agent_id].append(
                        self._hanging_rewards_end.pop(agent_id)
                    )
                # First obs for this agent. Make sure the agent's mapping is
                # appropriately prepended with self.SKIP_ENV_TS_TAG tags.
                else:
                    if agent_id not in self.env_t_to_agent_t:
                        self.env_t_to_agent_t[agent_id].extend(
                            [self.SKIP_ENV_TS_TAG] * data_idx
                        )
                        len_lookback_buffer_per_agent[agent_id] -= data_idx
                # Agent is still continuing (has an action for the next step).
                if agent_id in act:
                    # Always push actions/extra outputs into cache, then remove them
                    # from there, once the next observation comes in. Same for rewards.
                    self._hanging_actions_end[agent_id] = act[agent_id]
                    self._hanging_extra_model_outputs_end[agent_id] = extra_outs.get(
                        agent_id, {}
                    )
                    self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0)
                # Agent is done (has no action for the next step).
                elif terminateds.get(agent_id) or truncateds.get(agent_id):
                    done_per_agent[agent_id] = True
                # There is more (global) action/reward data. This agent must therefore
                # be done. Automatically add it to `done_per_agent` and `terminateds`.
                elif data_idx < len(observations) - 1:
                    done_per_agent[agent_id] = terminateds[agent_id] = True
                # Update env_t_to_agent_t mapping.
                self.env_t_to_agent_t[agent_id].append(
                    len(observations_per_agent[agent_id]) - 1
                )
            # Those agents that did NOT step:
            # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping.
            # - Get their reward (if any) added up.
            for agent_id in all_agent_ids:
                if agent_id not in obs and agent_id not in done_per_agent:
                    self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
                    # If we are still in the global lookback buffer segment, deduct 1
                    # from this agent's lookback buffer, b/c we don't want the agent
                    # to use this (missing) obs/data in its single-agent lookback.
                    if (
                        len(self.env_t_to_agent_t[agent_id])
                        - self._len_lookback_buffers
                        <= 0
                    ):
                        len_lookback_buffer_per_agent[agent_id] -= 1
                    self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0)
        # - Validate per-agent data.
        # - Fix lookback buffers of env_t_to_agent_t mappings.
        for agent_id in list(self.env_t_to_agent_t.keys()):
            # Skip agent if it doesn't seem to have any data.
            if agent_id not in observations_per_agent:
                del self.env_t_to_agent_t[agent_id]
                continue
            assert (
                len(observations_per_agent[agent_id])
                == len(infos_per_agent[agent_id])
                == len(actions_per_agent[agent_id]) + 1
                == len(extra_model_outputs_per_agent[agent_id]) + 1
                == len(rewards_per_agent[agent_id]) + 1
            )
            self.env_t_to_agent_t[agent_id].lookback = self._len_lookback_buffers
        # Now create the individual episodes from the collected per-agent data.
        for agent_id, agent_obs in observations_per_agent.items():
            # If agent only has a single obs AND is already done, remove all its traces
            # from this MultiAgentEpisode.
            if len(agent_obs) == 1 and done_per_agent.get(agent_id):
                self._del_agent(agent_id)
                continue
            # Try to figure out the module ID for this agent.
            # If not provided explicitly by the user that initializes this episode
            # object, try our mapping function.
            module_id = agent_module_ids.get(
                agent_id, self.agent_to_module_mapping_fn(agent_id, self)
            )
            # Create this agent's SingleAgentEpisode.
            sa_episode = SingleAgentEpisode(
                id_=(
                    agent_episode_ids.get(agent_id)
                    if agent_episode_ids is not None
                    else None
                ),
                agent_id=agent_id,
                module_id=module_id,
                multi_agent_episode_id=self.id_,
                observations=agent_obs,
                observation_space=self.observation_space.get(agent_id),
                infos=infos_per_agent[agent_id],
                actions=actions_per_agent[agent_id],
                action_space=self.action_space.get(agent_id),
                rewards=rewards_per_agent[agent_id],
                extra_model_outputs=(
                    {
                        k: [i[k] for i in extra_model_outputs_per_agent[agent_id]]
                        for k in extra_model_outputs_per_agent[agent_id][0].keys()
                    }
                    if extra_model_outputs_per_agent[agent_id]
                    else None
                ),
                terminated=terminateds.get(agent_id, False),
                truncated=truncateds.get(agent_id, False),
                t_started=self.agent_t_started[agent_id],
                len_lookback_buffer=max(len_lookback_buffer_per_agent[agent_id], 0),
            )
            # .. and store it.
            self.agent_episodes[agent_id] = sa_episode
    def _get(
        self,
        *,
        what,
        indices,
        agent_ids=None,
        env_steps=True,
        neg_index_as_lookback=False,
        fill=None,
        one_hot_discrete=False,
        return_list=False,
        extra_model_outputs_key=None,
    ):
        agent_ids = set(force_list(agent_ids)) or self.agent_ids
        kwargs = dict(
            what=what,
            indices=indices,
            agent_ids=agent_ids,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
            # Rewards and infos do not support one_hot_discrete option.
            one_hot_discrete=dict(
                {} if not one_hot_discrete else {"one_hot_discrete": one_hot_discrete}
            ),
            extra_model_outputs_key=extra_model_outputs_key,
        )
        # User specified agent timesteps (indices) -> Simply delegate everything
        # to the individual agents' SingleAgentEpisodes.
        if env_steps is False:
            if return_list:
                raise ValueError(
                    f"`MultiAgentEpisode.get_{what}()` can't be called with both "
                    "`env_steps=False` and `return_list=True`!"
                )
            return self._get_data_by_agent_steps(**kwargs)
        # User specified env timesteps (indices) -> We need to translate them for each
        # agent into agent-timesteps.
        # Return a list of individual per-env-timestep multi-agent dicts.
        elif return_list:
            return self._get_data_by_env_steps_as_list(**kwargs)
        # Return a single multi-agent dict with lists/arrays as leafs.
        else:
            return self._get_data_by_env_steps(**kwargs)
    def _get_data_by_agent_steps(
        self,
        *,
        what,
        indices,
        agent_ids,
        neg_index_as_lookback,
        fill,
        one_hot_discrete,
        extra_model_outputs_key,
    ):
        # Return requested data by agent-steps.
        ret = {}
        # For each agent, we retrieve the data through passing the given indices into
        # the SingleAgentEpisode of that agent.
        for agent_id, sa_episode in self.agent_episodes.items():
            if agent_id not in agent_ids:
                continue
            inf_lookback_buffer = getattr(sa_episode, what)
            hanging_val = self._get_hanging_value(what, agent_id)
            # User wants a specific `extra_model_outputs` key.
            if extra_model_outputs_key is not None:
                inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key]
                hanging_val = hanging_val[extra_model_outputs_key]
            agent_value = inf_lookback_buffer.get(
                indices=indices,
                neg_index_as_lookback=neg_index_as_lookback,
                fill=fill,
                _add_last_ts_value=hanging_val,
                **one_hot_discrete,
            )
            if agent_value is None or agent_value == []:
                continue
            ret[agent_id] = agent_value
        return ret
    def _get_data_by_env_steps_as_list(
        self,
        *,
        what: str,
        indices: Union[int, slice, List[int]],
        agent_ids: Collection[AgentID],
        neg_index_as_lookback: bool,
        fill: Any,
        one_hot_discrete,
        extra_model_outputs_key: str,
    ) -> List[MultiAgentDict]:
        # Collect indices for each agent first, so we can construct the list in
        # the next step.
        agent_indices = {}
        for agent_id in self.agent_episodes.keys():
            if agent_id not in agent_ids:
                continue
            agent_indices[agent_id] = self.env_t_to_agent_t[agent_id].get(
                indices,
                neg_index_as_lookback=neg_index_as_lookback,
                fill=self.SKIP_ENV_TS_TAG,
                # For those records where there is no "hanging" last timestep (all
                # other than obs and infos), we have to ignore the last entry in
                # the env_t_to_agent_t mappings.
                _ignore_last_ts=what not in ["observations", "infos"],
            )
        if not agent_indices:
            return []
        ret = []
        for i in range(len(next(iter(agent_indices.values())))):
            ret2 = {}
            for agent_id, idxes in agent_indices.items():
                hanging_val = self._get_hanging_value(what, agent_id)
                (
                    inf_lookback_buffer,
                    indices_to_use,
                ) = self._get_inf_lookback_buffer_or_dict(
                    agent_id,
                    what,
                    extra_model_outputs_key,
                    hanging_val,
                    filter_for_skip_indices=idxes[i],
                )
                if (
                    what == "extra_model_outputs"
                    and not inf_lookback_buffer
                    and not hanging_val
                ):
                    continue
                agent_value = self._get_single_agent_data_by_index(
                    what=what,
                    inf_lookback_buffer=inf_lookback_buffer,
                    agent_id=agent_id,
                    index_incl_lookback=indices_to_use,
                    fill=fill,
                    one_hot_discrete=one_hot_discrete,
                    extra_model_outputs_key=extra_model_outputs_key,
                    hanging_val=hanging_val,
                )
                if agent_value is not None:
                    ret2[agent_id] = agent_value
            ret.append(ret2)
        return ret
    def _get_data_by_env_steps(
        self,
        *,
        what: str,
        indices: Union[int, slice, List[int]],
        agent_ids: Collection[AgentID],
        neg_index_as_lookback: bool,
        fill: Any,
        one_hot_discrete: bool,
        extra_model_outputs_key: str,
    ) -> MultiAgentDict:
        ignore_last_ts = what not in ["observations", "infos"]
        ret = {}
        for agent_id, sa_episode in self.agent_episodes.items():
            if agent_id not in agent_ids:
                continue
            hanging_val = self._get_hanging_value(what, agent_id)
            agent_indices = self.env_t_to_agent_t[agent_id].get(
                indices,
                neg_index_as_lookback=neg_index_as_lookback,
                fill=self.SKIP_ENV_TS_TAG if fill is not None else None,
                # For those records where there is no "hanging" last timestep (all
                # other than obs and infos), we have to ignore the last entry in
                # the env_t_to_agent_t mappings.
                _ignore_last_ts=ignore_last_ts,
            )
            inf_lookback_buffer, agent_indices = self._get_inf_lookback_buffer_or_dict(
                agent_id,
                what,
                extra_model_outputs_key,
                hanging_val,
                filter_for_skip_indices=agent_indices,
            )
            if isinstance(agent_indices, list):
                agent_values = self._get_single_agent_data_by_env_step_indices(
                    what=what,
                    agent_id=agent_id,
                    indices_incl_lookback=agent_indices,
                    fill=fill,
                    one_hot_discrete=one_hot_discrete,
                    hanging_val=hanging_val,
                    extra_model_outputs_key=extra_model_outputs_key,
                )
                if len(agent_values) > 0:
                    ret[agent_id] = agent_values
            else:
                agent_values = self._get_single_agent_data_by_index(
                    what=what,
                    inf_lookback_buffer=inf_lookback_buffer,
                    agent_id=agent_id,
                    index_incl_lookback=agent_indices,
                    fill=fill,
                    one_hot_discrete=one_hot_discrete,
                    extra_model_outputs_key=extra_model_outputs_key,
                    hanging_val=hanging_val,
                )
                if agent_values is not None:
                    ret[agent_id] = agent_values
        return ret
    def _get_single_agent_data_by_index(
        self,
        *,
        what: str,
        inf_lookback_buffer: InfiniteLookbackBuffer,
        agent_id: AgentID,
        index_incl_lookback: Union[int, str],
        fill: Any,
        one_hot_discrete: dict,
        extra_model_outputs_key: str,
        hanging_val: Any,
    ) -> Any:
        sa_episode = self.agent_episodes[agent_id]
        if index_incl_lookback == self.SKIP_ENV_TS_TAG:
            # We don't want to fill -> Skip this agent.
            if fill is None:
                return
            # Provide filled value for this agent.
            return getattr(sa_episode, f"get_{what}")(
                indices=1000000000000,
                neg_index_as_lookback=False,
                fill=fill,
                **dict(
                    {}
                    if extra_model_outputs_key is None
                    else {"key": extra_model_outputs_key}
                ),
                **one_hot_discrete,
            )
        # No skip timestep -> Provide value at given index for this agent.
        # Special case: extra_model_outputs and key=None (return all keys as
        # a dict). Note that `inf_lookback_buffer` is NOT an infinite lookback
        # buffer, but a dict mapping keys to individual infinite lookback
        # buffers.
        elif what == "extra_model_outputs" and extra_model_outputs_key is None:
            assert hanging_val is None or isinstance(hanging_val, dict)
            ret = {}
            if inf_lookback_buffer:
                for key, sub_buffer in inf_lookback_buffer.items():
                    ret[key] = sub_buffer.get(
                        indices=index_incl_lookback - sub_buffer.lookback,
                        neg_index_as_lookback=True,
                        fill=fill,
                        _add_last_ts_value=(
                            None if hanging_val is None else hanging_val[key]
                        ),
                        **one_hot_discrete,
                    )
            else:
                for key in hanging_val.keys():
                    ret[key] = InfiniteLookbackBuffer().get(
                        indices=index_incl_lookback,
                        neg_index_as_lookback=True,
                        fill=fill,
                        _add_last_ts_value=hanging_val[key],
                        **one_hot_discrete,
                    )
            return ret
        # Extract data directly from the infinite lookback buffer object.
        else:
            return inf_lookback_buffer.get(
                indices=index_incl_lookback - inf_lookback_buffer.lookback,
                neg_index_as_lookback=True,
                fill=fill,
                _add_last_ts_value=hanging_val,
                **one_hot_discrete,
            )
    def _get_single_agent_data_by_env_step_indices(
        self,
        *,
        what: str,
        agent_id: AgentID,
        indices_incl_lookback: Union[int, str],
        fill: Optional[Any] = None,
        one_hot_discrete: bool = False,
        extra_model_outputs_key: Optional[str] = None,
        hanging_val: Optional[Any] = None,
    ) -> Any:
        """Returns single data item from the episode based on given (env step) indices.
        The returned data item will have a batch size that matches the env timesteps
        defined via `indices_incl_lookback`.
        Args:
            what: A (str) descriptor of what data to collect. Must be one of
                "observations", "infos", "actions", "rewards", or "extra_model_outputs".
            indices_incl_lookback: A list of ints specifying which indices
                to pull from the InfiniteLookbackBuffer defined by `agent_id` and `what`
                (and maybe `extra_model_outputs_key`). Note that these indices
                disregard the special logic of the lookback buffer. Meaning if one
                index in `indices_incl_lookback` is 0, then the first value in the
                lookback buffer should be returned, not the first value after the
                lookback buffer (which would be normal behavior for pulling items from
                an `InfiniteLookbackBuffer` object).
            agent_id: The individual agent ID to pull data for. Used to lookup the
                `SingleAgentEpisode` object for this agent in `self`.
            fill: An optional float value to use for filling up the returned results at
                the boundaries. This filling only happens if the requested index range's
                start/stop boundaries exceed the buffer's boundaries (including the
                lookback buffer on the left side). This comes in very handy, if users
                don't want to worry about reaching such boundaries and want to zero-pad.
                For example, a buffer with data [10, 11,  12, 13, 14] and lookback
                buffer size of 2 (meaning `10` and `11` are part of the lookback buffer)
                will respond to `indices_incl_lookback=[-1, -2, 0]` and `fill=0.0`
                with `[0.0, 0.0, 10]`.
            one_hot_discrete: If True, will return one-hot vectors (instead of
                int-values) for those sub-components of a (possibly complex) space
                that are Discrete or MultiDiscrete. Note that if `fill=0` and the
                requested `indices_incl_lookback` are out of the range of our data, the
                returned one-hot vectors will actually be zero-hot (all slots zero).
            extra_model_outputs_key: Only if what is "extra_model_outputs", this
                specifies the sub-key (str) inside the extra_model_outputs dict, e.g.
                STATE_OUT or ACTION_DIST_INPUTS.
            hanging_val: In case we are pulling actions, rewards, or extra_model_outputs
                data, there might be information "hanging" (cached). For example,
                if an agent receives an observation o0 and then immediately sends an
                action a0 back, but then does NOT immediately receive a next
                observation, a0 is now cached (not fully logged yet with this
                episode). The currently cached value must be provided here to be able
                to return it in case the index is -1 (most recent timestep).
        Returns:
            A data item corresponding to the provided args.
        """
        sa_episode = self.agent_episodes[agent_id]
        inf_lookback_buffer = getattr(sa_episode, what)
        if extra_model_outputs_key is not None:
            inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key]
        # If there are self.SKIP_ENV_TS_TAG items in `indices_incl_lookback` and user
        # wants to fill these (together with outside-episode-bounds indices) ->
        # Provide these skipped timesteps as filled values.
        if self.SKIP_ENV_TS_TAG in indices_incl_lookback and fill is not None:
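            # Request a guaranteed out-of-range index from the buffer to obtain the
            # `fill` value once, in the correct format (e.g. zero-hot, if
            # `one_hot_discrete` is set).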
            single_fill_value = inf_lookback_buffer.get(
                indices=1000000000000,
                neg_index_as_lookback=False,
                fill=fill,
                **one_hot_discrete,
            )
            ret = []
            for i in indices_incl_lookback:
                if i == self.SKIP_ENV_TS_TAG:
                    ret.append(single_fill_value)
                else:
                    ret.append(
                        inf_lookback_buffer.get(
                            indices=i - getattr(sa_episode, what).lookback,
                            neg_index_as_lookback=True,
                            fill=fill,
                            _add_last_ts_value=hanging_val,
                            **one_hot_discrete,
                        )
                    )
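            # If this episode's data has already been converted to numpy arrays,
            # batch the collected per-index items into (possibly nested) arrays.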
            if self.is_numpy:
                ret = batch(ret)
        else:
            # No fill requested: simply drop any skip-tags and shift the remaining
            # indices by the lookback length (negative results then address the
            # lookback buffer).
            indices = [
                i - inf_lookback_buffer.lookback
                for i in indices_incl_lookback
                if i != self.SKIP_ENV_TS_TAG
            ]
            ret = inf_lookback_buffer.get(
                indices=indices,
                neg_index_as_lookback=True,
                fill=fill,
                _add_last_ts_value=hanging_val,
                **one_hot_discrete,
            )
        return ret
    def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any:
        """Returns the hanging action/reward/extra_model_outputs for given agent."""
        if what == "actions":
            return self._hanging_actions_end.get(agent_id)
        elif what == "extra_model_outputs":
            return self._hanging_extra_model_outputs_end.get(agent_id)
        elif what == "rewards":
            return self._hanging_rewards_end.get(agent_id)
    def _copy_hanging(self, agent_id: AgentID, other: "MultiAgentEpisode") -> None:
        """Copies hanging action, reward, extra_model_outputs from `other` to `self."""
        if agent_id in other._hanging_rewards_begin:
            self._hanging_rewards_begin[agent_id] = other._hanging_rewards_begin[
                agent_id
            ]
        if agent_id in other._hanging_rewards_end:
            self._hanging_actions_end[agent_id] = copy.deepcopy(
                other._hanging_actions_end[agent_id]
            )
            self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[agent_id]
            self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy(
                other._hanging_extra_model_outputs_end[agent_id]
            )
    def _del_hanging(self, agent_id: AgentID) -> None:
        """Deletes all hanging action, reward, extra_model_outputs of given agent."""
        self._hanging_rewards_begin.pop(agent_id, None)
        self._hanging_actions_end.pop(agent_id, None)
        self._hanging_extra_model_outputs_end.pop(agent_id, None)
        self._hanging_rewards_end.pop(agent_id, None)
    def _del_agent(self, agent_id: AgentID) -> None:
        """Deletes all data of given agent from this episode."""
        self._del_hanging(agent_id)
        self.agent_episodes.pop(agent_id, None)
        self.agent_ids.discard(agent_id)
        self.env_t_to_agent_t.pop(agent_id, None)
        self._agent_to_module_mapping.pop(agent_id, None)
        self.agent_t_started.pop(agent_id, None)
    def _get_inf_lookback_buffer_or_dict(
        self,
        agent_id: AgentID,
        what: str,
        extra_model_outputs_key: Optional[str] = None,
        hanging_val: Optional[Any] = None,
        filter_for_skip_indices=None,
    ):
        """Returns a single InfiniteLookbackBuffer or a dict of such.
        In case `what` is "extra_model_outputs" AND `extra_model_outputs_key` is None,
        a dict is returned. In all other cases, a single InfiniteLookbackBuffer is
        returned.
        """
        inf_lookback_buffer_or_dict = inf_lookback_buffer = getattr(
            self.agent_episodes[agent_id], what
        )
        if what == "extra_model_outputs":
            if extra_model_outputs_key is not None:
                inf_lookback_buffer = inf_lookback_buffer_or_dict[
                    extra_model_outputs_key
                ]
            elif inf_lookback_buffer_or_dict:
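                # No sub-key given: pick any one of the sub-buffers. It is only used
                # for the skip-index computation below; the entire dict is what gets
                # returned.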
                inf_lookback_buffer = next(iter(inf_lookback_buffer_or_dict.values()))
            elif filter_for_skip_indices is not None:
                return inf_lookback_buffer_or_dict, filter_for_skip_indices
            else:
                return inf_lookback_buffer_or_dict
        if filter_for_skip_indices is not None:
            inf_lookback_buffer_len = (
                len(inf_lookback_buffer)
                + inf_lookback_buffer.lookback
                + (hanging_val is not None)
            )
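            # Observations/infos lead actions/rewards/extra_model_outputs by one
            # entry on an ongoing episode. For the latter, an index pointing right
            # past the end of the buffer (incl. lookback and a possibly hanging
            # value) refers to an env ts for which no data exists yet and is thus
            # tagged to be skipped.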
            ignore_last_ts = what not in ["observations", "infos"]
            if isinstance(filter_for_skip_indices, list):
                filter_for_skip_indices = [
                    self.SKIP_ENV_TS_TAG
                    if ignore_last_ts and i == inf_lookback_buffer_len
                    else i
                    for i in filter_for_skip_indices
                ]
            elif ignore_last_ts and filter_for_skip_indices == inf_lookback_buffer_len:
                filter_for_skip_indices = self.SKIP_ENV_TS_TAG
            return inf_lookback_buffer_or_dict, filter_for_skip_indices
        else:
            return inf_lookback_buffer_or_dict
    @Deprecated(new="MultiAgentEpisode.is_numpy()", error=True)
    def is_finalized(self):
        pass
    @Deprecated(new="MultiAgentEpisode.to_numpy()", error=True)
    def finalize(self):
        pass
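# A minimal usage sketch (not part of the class above). It shows how per-env-step
# MultiAgentDicts map onto per-agent `SingleAgentEpisode`s and how the getter APIs
# are called. The exact data-shape requirements and the `len_lookback_buffer`
# argument are assumptions made for illustration only.
if __name__ == "__main__":
    episode = MultiAgentEpisode(
        # One MultiAgentDict per env step; agents may be absent at individual steps.
        observations=[
            {"agent_0": 0, "agent_1": 0},
            {"agent_0": 1, "agent_1": 1},
            {"agent_1": 2},
        ],
        # Actions (and rewards) have one entry less than observations.
        actions=[
            {"agent_0": 0, "agent_1": 0},
            {"agent_1": 1},
        ],
        rewards=[
            {"agent_0": 0.1, "agent_1": 0.1},
            {"agent_1": 0.2},
        ],
        # Assumed here: none of the provided data should go into the lookback buffer.
        len_lookback_buffer=0,
    )
    # Per-agent data is stored in `SingleAgentEpisode` objects.
    print(episode.agent_episodes)
    # Getters operate on the (global) env time scale, e.g. the most recent
    # observations per agent.
    print(episode.get_observations(-1))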