import logging
import ray
from pathlib import Path
from typing import List, Optional
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.annotations import (
    override,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    OverrideToImplementCustomLogic,
)
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.spaces.space_utils import to_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType
from ray.util.debug import log_once
from ray.util.annotations import PublicAPI
logger = logging.getLogger(__name__)
# TODO (simon): This class can be agnostic to the episode type as it
#  calls only get_state.
@PublicAPI(stability="alpha")
class OfflineSingleAgentEnvRunner(SingleAgentEnvRunner):
    """The environment runner to record the single agent case."""
    @override(SingleAgentEnvRunner)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(self, *, config: AlgorithmConfig, **kwargs):
        # Initialize the parent.
        super().__init__(config=config, **kwargs)
        # Get the data context for this `EnvRunner`.
        data_context = ray.data.DataContext.get_current()
        # Limit the resources for Ray Data to the CPUs given to this `EnvRunner`.
        data_context.execution_options.resource_limits.cpu = (
            config.num_cpus_per_env_runner
        )
        # Set the output write method.
        self.output_write_method = self.config.output_write_method
        self.output_write_method_kwargs = self.config.output_write_method_kwargs
        # Set the filesystem.
        self.filesystem = self.config.output_filesystem
        self.filesystem_kwargs = self.config.output_filesystem_kwargs
        self.filesystem_object = None
        # Set the output base path.
        self.output_path = self.config.output
        # Set the subdir (environment specific).
        self.subdir_path = self.config.env.lower()
        # Set the worker-specific path name. Note, this is
        # specifically to enable multi-threaded writing into
        # the same directory.
        self.worker_path = "run-" + f"{self.worker_index}".zfill(6)
        # If a specific filesystem is given, set it up. Note, this could
        # be `gcsfs` for GCS, `pyarrow` for S3, or `adlfs` for Azure Blob Storage.
        # Such a filesystem is specifically needed if a session has to be
        # created with the cloud provider.
        if self.filesystem == "gcs":
            import gcsfs
            self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "s3":
            from pyarrow import fs
            self.filesystem_object = fs.S3FileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "abs":
            import adlfs
            self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
        elif self.filesystem is not None:
            raise ValueError(
                f"Unknown filesystem: {self.filesystem}. Filesystems can be "
                "'gcs' for GCS, 's3' for S3, or 'abs'"
            )
        # Add the filesystem object to the write method kwargs.
        self.output_write_method_kwargs.update(
            {
                "filesystem": self.filesystem_object,
            }
        )
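        # Example (a sketch, assuming standard client arguments of these
        # libraries): setting `output_filesystem="s3"` together with
        # `output_filesystem_kwargs={"region": "us-west-2"}` results in
        # `pyarrow.fs.S3FileSystem(region="us-west-2")` above; similarly,
        # `gcsfs.GCSFileSystem(project="my-project")` or
        # `adlfs.AzureBlobFileSystem(account_name="my-account")` are typical
        # for GCS and Azure Blob Storage, respectively.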
        # Whether to store `SingleAgentEpisode` objects or column data.
        self.output_write_episodes = self.config.output_write_episodes
        # Which columns should be compressed in the output data.
        self.output_compress_columns = self.config.output_compress_columns
        # Buffer this many rows before writing them to a file.
        self.output_max_rows_per_file = self.config.output_max_rows_per_file
        # If the user defines a maximum number of rows per file, set the
        # flag to `False` and check during sampling whether enough rows
        # have accumulated.
        if self.output_max_rows_per_file:
            self.write_data_this_iter = False
        # Otherwise, the flag is always `True` and sampled data is written
        # to disk immediately after each `sample()` call.
        else:
            self.write_data_this_iter = True
        # Whether the remaining data should be stored. Note, this is only
        # relevant in case `output_max_rows_per_file` is defined.
        self.write_remaining_data = self.config.output_write_remaining_data
        # Counts how often `sample` is called to define the output path for
        # each file.
        self._sample_counter = 0
        # Define the buffer for experiences stored until written to disk.
        self._samples = []
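    # Example usage (a sketch, assuming the recording options are exposed via
    # `AlgorithmConfig.offline_data()` and that, depending on the RLlib
    # version, this runner is either picked automatically when `output` is
    # set or passed explicitly via `env_runners()`):
    #
    #   config = (
    #       AlgorithmConfig()
    #       .environment("CartPole-v1")
    #       .env_runners(env_runner_cls=OfflineSingleAgentEnvRunner)
    #       .offline_data(
    #           output="/tmp/cartpole-records",
    #           output_write_episodes=True,
    #           output_max_rows_per_file=1000,
    #       )
    #   )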
    @override(SingleAgentEnvRunner)
    @OverrideToImplementCustomLogic
    def sample(
        self,
        *,
        num_timesteps: Optional[int] = None,
        num_episodes: Optional[int] = None,
        explore: Optional[bool] = None,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[SingleAgentEpisode]:
        """Samples from environments and writes data to disk."""
        # Call the super sample method.
        samples = super().sample(
            num_timesteps=num_timesteps,
            num_episodes=num_episodes,
            explore=explore,
            random_actions=random_actions,
            force_reset=force_reset,
        )
        self._sample_counter += 1
        # Add data to the buffers.
        if self.output_write_episodes:
            import msgpack
            import msgpack_numpy as mnp
            if log_once("msgpack"):
                logger.info(
                    "Packing episodes with `msgpack` and encode array with "
                    "`msgpack_numpy` for serialization. This is needed for "
                    "recording episodes."
                )
            # Note, we serialize episodes with `msgpack` and `msgpack_numpy` to
            # ensure version compatibility.
            self._samples.extend(
                [msgpack.packb(eps.get_state(), default=mnp.encode) for eps in samples]
            )
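            # Each buffered row is now a raw `bytes` blob. A reader can later
            # restore the episode roughly via (a sketch, assuming
            # `SingleAgentEpisode.from_state()` is available):
            #   state = msgpack.unpackb(blob, object_hook=mnp.decode)
            #   episode = SingleAgentEpisode.from_state(state)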
        else:
            self._map_episodes_to_data(samples)
        # If the user defined a maximum number of rows per file.
        if self.output_max_rows_per_file:
            # Check whether this number has been reached.
            if len(self._samples) >= self.output_max_rows_per_file:
                # Start writing data in this iteration.
                self.write_data_this_iter = True
        if self.write_data_this_iter:
            # If the user wants a maximum number of experiences per file,
            # cut the samples to write to disk from the buffer.
            if self.output_max_rows_per_file:
                # Reset the event.
                self.write_data_this_iter = False
                # Ensure that all data ready to be written is released from
                # the buffer. Note, this is important in case we have many
                # episodes sampled and a relatively small `output_max_rows_per_file`.
                while len(self._samples) >= self.output_max_rows_per_file:
                    # Extract the number of samples to be written to disk this
                    # iteration.
                    samples_to_write = self._samples[: self.output_max_rows_per_file]
                    # Reset the buffer to the remaining data. Note, this only
                    # makes sense if `rollout_fragment_length` is smaller than
                    # `output_max_rows_per_file` or at most 2 x
                    # `output_max_rows_per_file`, because only the most
                    # recently built chunk gets written below.
                    self._samples = self._samples[self.output_max_rows_per_file :]
                    samples_ds = ray.data.from_items(samples_to_write)
            # Otherwise, write the complete buffered data and clear the buffer.
            else:
                samples_ds = ray.data.from_items(self._samples)
                self._samples = []
            try:
                # Set up the path for writing data. Each run is written to
                # its own file, where a run is one writing event. The path
                # looks like 'base_path/env-name/run-<WorkerID>-<RunID>'
                # with zero-padded IDs.
                path = (
                    Path(self.output_path)
                    .joinpath(self.subdir_path)
                    .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
                )
                getattr(samples_ds, self.output_write_method)(
                    path.as_posix(), **self.output_write_method_kwargs
                )
                logger.info(f"Wrote samples to storage at {path}.")
            except Exception as e:
                logger.error(e)
        self.metrics.log_value(
            key="recording_buffer_size",
            value=len(self._samples),
        )
        # Finally return the samples as usual.
        return samples 
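    # Note (a sketch, assuming the default `output_write_method="write_parquet"`):
    # the files recorded above can be read back with Ray Data, e.g. via
    # `ray.data.read_parquet("base_path/env-name/")`, or consumed by RLlib's
    # offline API through the `input_` config setting.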
    @override(EnvRunner)
    @OverrideToImplementCustomLogic
    def stop(self) -> None:
        """Writes the reamining samples to disk
        Note, if the user defined `max_rows_per_file` the
        number of rows for the remaining samples could be
        less than the defined maximum row number by the user.
        """
        # If there are samples left over we have to write htem to disk. them
        # to a dataset.
        if self._samples and self.write_remaining_data:
            # Convert them to a `ray.data.Dataset`.
            samples_ds = ray.data.from_items(self._samples)
            # Increase the sample counter for the folder/file name.
            self._sample_counter += 1
            # Try to write the dataset to disk/cloud storage.
            try:
                # Set up the path for writing data. Each run is written to
                # its own file, where a run is one writing event. The path
                # looks like 'base_path/env-name/run-<WorkerID>-<RunID>'
                # with zero-padded IDs.
                path = (
                    Path(self.output_path)
                    .joinpath(self.subdir_path)
                    .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
                )
                getattr(samples_ds, self.output_write_method)(
                    path.as_posix(), **self.output_write_method_kwargs
                )
                logger.info(
                    f"Wrote final samples to storage at {path}. Note, the "
                    "final samples could be smaller in size than "
                    "`output_max_rows_per_file`, if defined."
                )
            except Exception as e:
                logger.error(e)
        logger.debug(f"Experience buffer length: {len(self._samples)}") 
    @OverrideToImplementCustomLogic
    def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
        """Converts list of episodes to list of single dict experiences.
        Note, this method also appends all sampled experiences to the
        buffer.
        Args:
            samples: List of episodes to be converted.
        """
        obs_space = self.env.observation_space
        action_space = self.env.action_space
        # Loop through all sampled episodes.
        for sample in samples:
            # Loop through all items of the episode.
            for i in range(len(sample)):
                sample_data = {
                    Columns.EPS_ID: sample.id_,
                    Columns.AGENT_ID: sample.agent_id,
                    Columns.MODULE_ID: sample.module_id,
                    # Compress observations, if requested.
                    Columns.OBS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_observations(i), obs_space)
                    )
                    if Columns.OBS in self.output_compress_columns
                    else to_jsonable_if_needed(sample.get_observations(i), obs_space),
                    # Compress actions, if requested.
                    Columns.ACTIONS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_actions(i), action_space)
                    )
                    if Columns.ACTIONS in self.output_compress_columns
                    else to_jsonable_if_needed(sample.get_actions(i), action_space),
                    Columns.REWARDS: sample.get_rewards(i),
                    # Compress next observations, if requested. Note, this
                    # follows the `Columns.OBS` compression setting.
                    Columns.NEXT_OBS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
                    )
                    if Columns.OBS in self.output_compress_columns
                    else to_jsonable_if_needed(
                        sample.get_observations(i + 1), obs_space
                    ),
                    Columns.TERMINATEDS: False
                    if i < len(sample) - 1
                    else sample.is_terminated,
                    Columns.TRUNCATEDS: False
                    if i < len(sample) - 1
                    else sample.is_truncated,
                    **{
                        # Compress any extra model output, if requested.
                        k: pack_if_needed(sample.get_extra_model_outputs(k, i))
                        if k in self.output_compress_columns
                        else sample.get_extra_model_outputs(k, i)
                        for k in sample.extra_model_outputs.keys()
                    },
                }
                # Finally append to the data buffer.
                self._samples.append(sample_data)
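        # Note (a sketch): columns packed above via `pack_if_needed` can be
        # restored on the read side with its counterpart,
        # `ray.rllib.utils.compression.unpack_if_needed`.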