import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from ray.air.constants import TRAINING_ITERATION
from ray.tune.logger.logger import LoggerCallback
from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.tune.experiment.trial import Trial

try:
    from aim.sdk import Repo, Run
except ImportError:
    Repo, Run = None, None

logger = logging.getLogger(__name__)

VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]


@PublicAPI
class AimLoggerCallback(LoggerCallback):
    """Aim Logger: logs metrics in Aim format.
    Aim is an open-source, self-hosted ML experiment tracking tool.
    It's good at tracking lots (thousands) of training runs, and it allows you to
    compare them with a performant and well-designed UI.
    Source: https://github.com/aimhubio/aim
    Args:
        repo: Aim repository directory or a `Repo` object that the Run object will
            log results to. If not provided, a default repo will be set up in the
            experiment directory (one level above trial directories).
        experiment_name: Sets the `experiment` property of each Run object,
            which is the experiment name associated with it. Can be used later
            to query runs/sequences.
            If not provided, the default will be the Tune experiment name set
            by `RunConfig(name=...)`.
        metrics: List of metric names (out of the metrics reported by Tune) to
            track in Aim. If no metrics are specified, all reported metrics
            are logged.
        aim_run_kwargs: Additional arguments that will be passed when creating the
            individual `Run` objects for each trial. For the full list of arguments,
            please see the Aim documentation:
            https://aimstack.readthedocs.io/en/latest/refs/sdk.html
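
    Example:

        A minimal usage sketch (the training function, search space, and
        experiment name below are illustrative, and the import locations for
        `RunConfig` and the reporting API may differ between Ray versions):

        .. code-block:: python

            from ray import tune
            from ray.air import RunConfig, session
            from ray.tune.logger.aim import AimLoggerCallback

            def train_fn(config):
                # Illustrative training loop that reports a single metric.
                for step in range(10):
                    session.report({"loss": config["lr"] * step})

            tuner = tune.Tuner(
                train_fn,
                param_space={"lr": tune.grid_search([0.01, 0.1])},
                run_config=RunConfig(
                    name="aim_example",
                    callbacks=[AimLoggerCallback(metrics=["loss"])],
                ),
            )
            tuner.fit()

        The tracked runs can then be browsed with the Aim UI (`aim up`)
        pointed at the repository, which defaults to the experiment directory.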
    """
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def __init__(
        self,
        repo: Optional[Union[str, "Repo"]] = None,
        experiment_name: Optional[str] = None,
        metrics: Optional[List[str]] = None,
        **aim_run_kwargs,
    ):
        """
        See help(AimLoggerCallback) for more information about parameters.
        """
        assert Run is not None, (
            "aim must be installed!. You can install aim with"
            " the command: `pip install aim`."
        )
        self._repo_path = repo
        self._experiment_name = experiment_name
        if metrics is not None and not metrics:
            raise ValueError(
                "`metrics` must either contain at least one metric name, or be None, "
                "in which case all reported metrics will be logged to the aim repo."
            )
        self._metrics = metrics
        self._aim_run_kwargs = aim_run_kwargs
        self._trial_to_run: Dict["Trial", Run] = {}

    def _create_run(self, trial: "Trial") -> Run:
        """Initializes an Aim Run object for a given trial.
        Args:
            trial: The Tune trial that aim will track as a Run.
        Returns:
            Run: The created aim run for a specific trial.
        """
        experiment_dir = trial.local_experiment_path
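        # If no repo was passed to this callback, default to the Tune experiment
        # directory (for example, a hypothetical ~/ray_results/<experiment_name>),
        # where Aim will create or reuse its repository.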
        run = Run(
            repo=self._repo_path or experiment_dir,
            experiment=self._experiment_name or trial.experiment_dir_name,
            **self._aim_run_kwargs,
        )
        # Attach a few useful trial properties
        run["trial_id"] = trial.trial_id
        run["trial_log_dir"] = trial.path
        trial_ip = trial.get_ray_actor_ip()
        if trial_ip:
            run["trial_ip"] = trial_ip
        return run

    def log_trial_start(self, trial: "Trial"):
        if trial in self._trial_to_run:
            # Cleanup an existing run if the trial has been restarted
            self._trial_to_run[trial].close()
        trial.init_local_path()
        self._trial_to_run[trial] = self._create_run(trial)
        if trial.evaluated_params:
            self._log_trial_hparams(trial)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        tmp_result = result.copy()
        step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION]
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp_result.pop(k, None)  # not useful to log these
        # `context` and `epoch` are special keys that users can report,
        # which are treated as special aim metrics/configurations.
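        # For example, a hypothetical reported result of
        # {"loss": 0.1, "epoch": 2, "context": {"subset": "val"}} would have
        # "epoch" passed as the Aim epoch and "context" as the Aim context
        # for every metric tracked below.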
        context = tmp_result.pop("context", None)
        epoch = tmp_result.pop("epoch", None)
        trial_run = self._trial_to_run[trial]
        path = ["ray", "tune"]
        flat_result = flatten_dict(tmp_result, delimiter="/")
        valid_result = {}
        for attr, value in flat_result.items():
            if self._metrics and attr not in self._metrics:
                continue
            full_attr = "/".join(path + [attr])
            if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not (
                np.isnan(value) or np.isinf(value)
            ):
                valid_result[attr] = value
                trial_run.track(
                    value=value,
                    name=full_attr,
                    epoch=epoch,
                    step=step,
                    context=context,
                )
            elif (isinstance(value, (list, tuple, set)) and len(value) > 0) or (
                isinstance(value, np.ndarray) and value.size > 0
            ):
                # Non-empty sequences/arrays are collected but not tracked as
                # Aim metrics.
                valid_result[attr] = value

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        trial_run = self._trial_to_run.pop(trial)
        trial_run.close()

    def _log_trial_hparams(self, trial: "Trial"):
        params = flatten_dict(trial.evaluated_params, delimiter="/")
        flat_params = flatten_dict(params)
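        # Keep plain-Python hyperparameter values as-is, convert NumPy scalars
        # to their Python equivalents, and drop everything else. For example, a
        # hypothetical {"lr": 0.1, "layers": np.int64(3), "model": Net()} keeps
        # "lr", converts "layers" to 3, and drops "model".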
        scrubbed_params = {
            k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
        }
        np_params = {
            k: v.tolist()
            for k, v in flat_params.items()
            if isinstance(v, self.VALID_NP_HPARAMS)
        }
        scrubbed_params.update(np_params)
        removed = {
            k: v
            for k, v in flat_params.items()
            if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
        }
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to aim: %s",
                str(removed),
            )
        run = self._trial_to_run[trial]
        run["hparams"] = scrubbed_params