from __future__ import print_function
import collections
import datetime
import numbers
import sys
import textwrap
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import ray
from ray._private.dict import flatten_dict
from ray._private.thirdparty.tabulate.tabulate import tabulate
from ray.air.constants import EXPR_ERROR_FILE, TRAINING_ITERATION
from ray.air.util.node import _force_on_current_node
from ray.experimental.tqdm_ray import safe_print
from ray.tune.callback import Callback
from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
from ray.tune.logger import pretty_print
from ray.tune.result import (
    AUTO_RESULT_KEYS,
    DEFAULT_METRIC,
    DONE,
    EPISODE_REWARD_MEAN,
    EXPERIMENT_TAG,
    MEAN_ACCURACY,
    MEAN_LOSS,
    NODE_IP,
    PID,
    TIME_TOTAL_S,
    TIMESTEPS_TOTAL,
    TRIAL_ID,
)
from ray.tune.trainable import Trainable
from ray.tune.utils import unflattened_lookup
from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.queue import Empty, Queue
from ray.widgets import Template
try:
    from collections.abc import Mapping, MutableMapping
except ImportError:
    from collections import Mapping, MutableMapping
IS_NOTEBOOK = ray.widgets.util.in_notebook()
SKIP_RESULTS_IN_REPORT = {"config", TRIAL_ID, EXPERIMENT_TAG, DONE}
@PublicAPI
class ProgressReporter:
    """Abstract class for experiment progress reporting.
    `should_report()` is called to determine whether or not `report()` should
    be called. Tune will call these functions after trial state transitions,
    after receiving training results, and so on.
    """
    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs,
    ):
        """Setup progress reporter for a new Ray Tune run.
        This function is used to initialize parameters that are set on runtime.
        It will be called before any of the other methods.
        Defaults to no-op.
        Args:
            start_time: Timestamp when the Ray Tune run is started.
            total_samples: Number of samples the Ray Tune run will run.
            metric: Metric to optimize.
            mode: Must be one of [min, max]. Determines whether objective is
                minimizing or maximizing the metric attribute.
            **kwargs: Keyword arguments for forward-compatibility.
        """
        pass 
    def should_report(self, trials: List[Trial], done: bool = False):
        """Returns whether or not progress should be reported.
        Args:
            trials: Trials to report on.
            done: Whether this is the last progress report attempt.
        """
        raise NotImplementedError 
    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        """Reports progress across trials.
        Args:
            trials: Trials to report on.
            done: Whether this is the last progress report attempt.
            sys_info: System info.
        """
        raise NotImplementedError 
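# A minimal sketch of a custom reporter, assuming only the `ProgressReporter`
# contract documented above (`should_report()` gates `report()`); the class
# below is purely illustrative and not part of this module:
#
#     class EveryCallReporter(ProgressReporter):
#         def should_report(self, trials, done=False):
#             # A real reporter would throttle here; this one always reports.
#             return True
#
#         def report(self, trials, done, *sys_info):
#             print(f"{len(trials)} trials, done={done}")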
 
@DeveloperAPI
class TuneReporterBase(ProgressReporter):
    """Abstract base class for the default Tune reporters.
    If metric_columns is not overridden, Tune will attempt to automatically
    infer the metrics being reported, up to `infer_limit` metrics.
    Args:
        metric_columns: Names of metrics to
            include in progress table. If this is a dict, the keys should
            be metric names and the values should be the displayed names.
            If this is a list, the metric name is used directly.
        parameter_columns: Names of parameters to
            include in progress table. If this is a dict, the keys should
            be parameter names and the values should be the displayed names.
            If this is a list, the parameter name is used directly. If empty,
            defaults to all available parameters.
        max_progress_rows: Maximum number of rows to print
            in the progress table. The progress table describes the
            progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows to print in the
            error table. The error table lists the error file, if any,
            corresponding to each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Column
            headers and values longer than this will be abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to automatically infer
            from tune results.
        print_intermediate_tables: Print intermediate result
            tables. If None (default), will be set to True for verbosity
            levels above 3, otherwise False. If True, intermediate tables
            will be printed with experiment progress. If False, tables
            will only be printed at the end of the tuning run for verbosity
            levels greater than 2.
        metric: Metric used to determine best current trial.
        mode: One of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """
    # Truncated representations of column names (to accommodate small screens).
    DEFAULT_COLUMNS = collections.OrderedDict(
        {
            MEAN_ACCURACY: "acc",
            MEAN_LOSS: "loss",
            TRAINING_ITERATION: "iter",
            TIME_TOTAL_S: "total time (s)",
            TIMESTEPS_TOTAL: "ts",
            EPISODE_REWARD_MEAN: "reward",
        }
    )
    VALID_SUMMARY_TYPES = {
        int,
        float,
        np.float32,
        np.float64,
        np.int32,
        np.int64,
        type(None),
    }
    def __init__(
        self,
        *,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        self._total_samples = total_samples
        self._metrics_override = metric_columns is not None
        self._inferred_metrics = {}
        self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
        self._parameter_columns = parameter_columns or []
        self._max_progress_rows = max_progress_rows
        self._max_error_rows = max_error_rows
        self._max_column_length = max_column_length
        self._infer_limit = infer_limit
        if print_intermediate_tables is None:
            self._print_intermediate_tables = has_verbosity(Verbosity.V3_TRIAL_DETAILS)
        else:
            self._print_intermediate_tables = print_intermediate_tables
        self._max_report_frequency = max_report_frequency
        self._last_report_time = 0
        self._start_time = time.time()
        self._metric = metric
        self._mode = mode
        self._sort_by_metric = sort_by_metric
    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs,
    ):
        self.set_start_time(start_time)
        self.set_total_samples(total_samples)
        self.set_search_properties(metric=metric, mode=mode)
    def set_search_properties(self, metric: Optional[str], mode: Optional[str]):
        if (self._metric and metric) or (self._mode and mode):
            raise ValueError(
                "You passed a `metric` or `mode` argument to `tune.TuneConfig()`, but "
                "the reporter you are using was already instantiated with their "
                "own `metric` and `mode` parameters. Either remove the arguments "
                "from your reporter or from your call to `tune.TuneConfig()`"
            )
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC
        return True
    def set_total_samples(self, total_samples: int):
        self._total_samples = total_samples
    def set_start_time(self, timestamp: Optional[float] = None):
        if timestamp is not None:
            self._start_time = timestamp
        else:
            self._start_time = time.time()
    def should_report(self, trials: List[Trial], done: bool = False):
        if time.time() - self._last_report_time > self._max_report_frequency:
            self._last_report_time = time.time()
            return True
        return done
    def add_metric_column(self, metric: str, representation: Optional[str] = None):
        """Adds a metric to the existing columns.
        Args:
            metric: Metric to add. This must be a metric being returned
                in training step results.
            representation: Representation to use in table. Defaults to
                `metric`.
        """
        self._metrics_override = True
        if metric in self._metric_columns:
            raise ValueError("Column {} already exists.".format(metric))
        if isinstance(self._metric_columns, MutableMapping):
            representation = representation or metric
            self._metric_columns[metric] = representation
        else:
            if representation is not None and representation != metric:
                raise ValueError(
                    "`representation` cannot differ from `metric` "
                    "if this reporter was initialized with a list "
                    "of metric columns."
                )
            self._metric_columns.append(metric)
    def add_parameter_column(
        self, parameter: str, representation: Optional[str] = None
    ):
        """Adds a parameter to the existing columns.
        Args:
            parameter: Parameter to add. This must be a parameter
                specified in the configuration.
            representation: Representation to use in table. Defaults to
                `parameter`.
        """
        if parameter in self._parameter_columns:
            raise ValueError("Column {} already exists.".format(parameter))
        if isinstance(self._parameter_columns, MutableMapping):
            representation = representation or parameter
            self._parameter_columns[parameter] = representation
        else:
            if representation is not None and representation != parameter:
                raise ValueError(
                    "`representation` cannot differ from `parameter` "
                    "if this reporter was initialized with a list "
                    "of metric columns."
                )
            self._parameter_columns.append(parameter)
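    # Usage sketch for the two helpers above, assuming a reporter constructed
    # with dict-style metric columns (the column names are illustrative):
    #
    #     reporter = CLIReporter(metric_columns={"mean_loss": "loss"})
    #     reporter.add_metric_column("custom_score", "score")
    #     reporter.add_parameter_column("lr")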
    def _progress_str(
        self,
        trials: List[Trial],
        done: bool,
        *sys_info: Dict,
        fmt: str = "psql",
        delim: str = "\n",
    ):
        """Returns full progress string.
        This string contains a progress table and error table. The progress
        table describes the progress of each trial. The error table lists
        the error file, if any, corresponding to each trial. The latter only
        exists if errors have occurred.
        Args:
            trials: Trials to report on.
            done: Whether this is the last progress report attempt.
            fmt: Table format. See `tablefmt` in tabulate API.
            delim: Delimiter between messages.
        """
        if self._sort_by_metric and (self._metric is None or self._mode is None):
            self._sort_by_metric = False
            warnings.warn(
                "Both 'metric' and 'mode' must be set to be able "
                "to sort by metric. No sorting is performed."
            )
        if not self._metrics_override:
            user_metrics = self._infer_user_metrics(trials, self._infer_limit)
            self._metric_columns.update(user_metrics)
        messages = [
            "== Status ==",
            _time_passed_str(self._start_time, time.time()),
            *sys_info,
        ]
        if done:
            max_progress = None
            max_error = None
        else:
            max_progress = self._max_progress_rows
            max_error = self._max_error_rows
        current_best_trial, metric = self._current_best_trial(trials)
        if current_best_trial:
            messages.append(
                _best_trial_str(current_best_trial, metric, self._parameter_columns)
            )
        if has_verbosity(Verbosity.V1_EXPERIMENT):
            # Will filter the table in `trial_progress_str`
            messages.append(
                _trial_progress_str(
                    trials,
                    metric_columns=self._metric_columns,
                    parameter_columns=self._parameter_columns,
                    total_samples=self._total_samples,
                    force_table=self._print_intermediate_tables,
                    fmt=fmt,
                    max_rows=max_progress,
                    max_column_length=self._max_column_length,
                    done=done,
                    metric=self._metric,
                    mode=self._mode,
                    sort_by_metric=self._sort_by_metric,
                )
            )
            messages.append(_trial_errors_str(trials, fmt=fmt, max_rows=max_error))
        return delim.join(messages) + delim
    def _infer_user_metrics(self, trials: List[Trial], limit: int = 4):
        """Try to infer the metrics to print out."""
        if len(self._inferred_metrics) >= limit:
            return self._inferred_metrics
        self._inferred_metrics = {}
        for t in trials:
            if not t.last_result:
                continue
            for metric, value in t.last_result.items():
                if metric not in self.DEFAULT_COLUMNS:
                    if metric not in AUTO_RESULT_KEYS:
                        if type(value) in self.VALID_SUMMARY_TYPES:
                            self._inferred_metrics[metric] = metric
                if len(self._inferred_metrics) >= limit:
                    return self._inferred_metrics
        return self._inferred_metrics
    def _current_best_trial(self, trials: List[Trial]):
        if not trials:
            return None, None
        metric, mode = self._metric, self._mode
        # If no metric has been set, see if exactly one has been reported
        # and use that one. `mode` must still be set.
        if not metric:
            if len(self._inferred_metrics) == 1:
                metric = list(self._inferred_metrics.keys())[0]
        if not metric or not mode:
            return None, metric
        metric_op = 1.0 if mode == "max" else -1.0
        best_metric = float("-inf")
        best_trial = None
        for t in trials:
            if not t.last_result:
                continue
            metric_value = unflattened_lookup(metric, t.last_result, default=None)
            if pd.isnull(metric_value):
                continue
            if not best_trial or metric_value * metric_op > best_metric:
                best_metric = metric_value * metric_op
                best_trial = t
        return best_trial, metric
@DeveloperAPI
class RemoteReporterMixin:
    """Remote reporter abstract mixin class.
    Subclasses of this class will use a Ray Queue to display output
    on the driver side when running Ray Client."""
    @property
    def output_queue(self) -> Queue:
        return getattr(self, "_output_queue", None)
    @output_queue.setter
    def output_queue(self, value: Queue):
        self._output_queue = value
    def display(self, string: str) -> None:
        """Display the progress string.
        Args:
            string: String to display.
        """
        raise NotImplementedError
@PublicAPI
class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
    """Jupyter notebook-friendly Reporter that can update display in-place.
    Args:
        overwrite: Flag for overwriting the cell contents before initialization.
        metric_columns: Names of metrics to
            include in progress table. If this is a dict, the keys should
            be metric names and the values should be the displayed names.
            If this is a list, the metric name is used directly.
        parameter_columns: Names of parameters to
            include in progress table. If this is a dict, the keys should
            be parameter names and the values should be the displayed names.
            If this is a list, the parameter name is used directly. If empty,
            defaults to all available parameters.
        max_progress_rows: Maximum number of rows to print
            in the progress table. The progress table describes the
            progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows to print in the
            error table. The error table lists the error file, if any,
            corresponding to each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Column
            headers and values longer than this will be abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to automatically infer
            from tune results.
        print_intermediate_tables: Print intermediate result
            tables. If None (default), will be set to True for verbosity
            levels above 3, otherwise False. If True, intermediate tables
            will be printed with experiment progress. If False, tables
            will only be printed at the end of the tuning run for verbosity
            levels greater than 2.
        metric: Metric used to determine best current trial.
        mode: One of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """
    def __init__(
        self,
        *,
        overwrite: bool = True,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        super(JupyterNotebookReporter, self).__init__(
            metric_columns=metric_columns,
            parameter_columns=parameter_columns,
            total_samples=total_samples,
            max_progress_rows=max_progress_rows,
            max_error_rows=max_error_rows,
            max_column_length=max_column_length,
            max_report_frequency=max_report_frequency,
            infer_limit=infer_limit,
            print_intermediate_tables=print_intermediate_tables,
            metric=metric,
            mode=mode,
            sort_by_metric=sort_by_metric,
        )
        if not IS_NOTEBOOK:
            warnings.warn(
                "You are using the `JupyterNotebookReporter`, but not "
                "IPython/Jupyter-compatible environment was detected. "
                "If this leads to unformatted output (e.g. like "
                "<IPython.core.display.HTML object>), consider passing "
                "a `CLIReporter` as the `progress_reporter` argument "
                "to `train.RunConfig()` instead."
            )
        self._overwrite = overwrite
        self._display_handle = None
        self.display("")  # initialize empty display to update later
    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        progress = self._progress_html(trials, done, *sys_info)
        if self.output_queue is not None:
            # If an output queue is set, send string
            self.output_queue.put(progress)
        else:
            # Else, output directly
            self.display(progress)
    def display(self, string: str) -> None:
        from IPython.display import HTML, clear_output, display
        if not self._display_handle:
            if self._overwrite:
                clear_output(wait=True)
            self._display_handle = display(HTML(string), display_id=True)
        else:
            self._display_handle.update(HTML(string))
    def _progress_html(self, trials: List[Trial], done: bool, *sys_info) -> str:
        """Generate an HTML-formatted progress update.
        Args:
            trials: List of trials for which progress should be
                displayed
            done: True if the trials are finished, False otherwise
            *sys_info: System information to be displayed
        Returns:
            Progress update to be rendered in a notebook, including HTML
                tables and formatted error messages. Includes
                - Duration of the tune job
                - Memory consumption
                - Trial progress table, with information about each experiment
        """
        if not self._metrics_override:
            user_metrics = self._infer_user_metrics(trials, self._infer_limit)
            self._metric_columns.update(user_metrics)
        current_time, running_for = _get_time_str(self._start_time, time.time())
        used_gb, total_gb, memory_message = _get_memory_usage()
        status_table = tabulate(
            [
                ("Current time:", current_time),
                ("Running for:", running_for),
                ("Memory:", f"{used_gb}/{total_gb} GiB"),
            ],
            tablefmt="html",
        )
        trial_progress_data = _trial_progress_table(
            trials=trials,
            metric_columns=self._metric_columns,
            parameter_columns=self._parameter_columns,
            fmt="html",
            max_rows=None if done else self._max_progress_rows,
            metric=self._metric,
            mode=self._mode,
            sort_by_metric=self._sort_by_metric,
            max_column_length=self._max_column_length,
        )
        trial_progress = trial_progress_data[0]
        trial_progress_messages = trial_progress_data[1:]
        trial_errors = _trial_errors_str(
            trials, fmt="html", max_rows=None if done else self._max_error_rows
        )
        if any([memory_message, trial_progress_messages, trial_errors]):
            msg = Template("tune_status_messages.html.j2").render(
                memory_message=memory_message,
                trial_progress_messages=trial_progress_messages,
                trial_errors=trial_errors,
            )
        else:
            msg = None
        return Template("tune_status.html.j2").render(
            status_table=status_table,
            sys_info_message=_generate_sys_info_str(*sys_info),
            trial_progress=trial_progress,
            messages=msg,
        ) 
@PublicAPI
class CLIReporter(TuneReporterBase):
    """Command-line reporter
    Args:
        metric_columns: Names of metrics to
            include in progress table. If this is a dict, the keys should
            be metric names and the values should be the displayed names.
            If this is a list, the metric name is used directly.
        parameter_columns: Names of parameters to
            include in progress table. If this is a dict, the keys should
            be parameter names and the values should be the displayed names.
            If this is a list, the parameter name is used directly. If empty,
            defaults to all available parameters.
        max_progress_rows: Maximum number of rows to print
            in the progress table. The progress table describes the
            progress of each trial. Defaults to 20.
        max_error_rows: Maximum number of rows to print in the
            error table. The error table lists the error file, if any,
            corresponding to each trial. Defaults to 20.
        max_column_length: Maximum column length (in characters). Column
            headers and values longer than this will be abbreviated.
        max_report_frequency: Maximum report frequency in seconds.
            Defaults to 5s.
        infer_limit: Maximum number of metrics to automatically infer
            from tune results.
        print_intermediate_tables: Print intermediate result
            tables. If None (default), will be set to True for verbosity
            levels above 3, otherwise False. If True, intermediate tables
            will be printed with experiment progress. If False, tables
            will only be printed at the end of the tuning run for verbosity
            levels greater than 2.
        metric: Metric used to determine best current trial.
        mode: One of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """
    def __init__(
        self,
        *,
        metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
        total_samples: Optional[int] = None,
        max_progress_rows: int = 20,
        max_error_rows: int = 20,
        max_column_length: int = 20,
        max_report_frequency: int = 5,
        infer_limit: int = 3,
        print_intermediate_tables: Optional[bool] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        sort_by_metric: bool = False,
    ):
        super(CLIReporter, self).__init__(
            metric_columns=metric_columns,
            parameter_columns=parameter_columns,
            total_samples=total_samples,
            max_progress_rows=max_progress_rows,
            max_error_rows=max_error_rows,
            max_column_length=max_column_length,
            max_report_frequency=max_report_frequency,
            infer_limit=infer_limit,
            print_intermediate_tables=print_intermediate_tables,
            metric=metric,
            mode=mode,
            sort_by_metric=sort_by_metric,
        )
    def _print(self, msg: str):
        safe_print(msg)
    def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
        self._print(self._progress_str(trials, done, *sys_info)) 
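# A hedged usage sketch for `CLIReporter`: the metric/parameter names and the
# surrounding Tuner/RunConfig wiring are illustrative assumptions, not part of
# this module:
#
#     reporter = CLIReporter(
#         metric_columns={"mean_loss": "loss"},
#         parameter_columns=["lr"],
#         max_report_frequency=30,
#         metric="mean_loss",
#         mode="min",
#         sort_by_metric=True,
#     )
#     # tuner = tune.Tuner(trainable, run_config=train.RunConfig(progress_reporter=reporter))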
def _get_memory_usage() -> Tuple[float, float, Optional[str]]:
    """Get the current memory consumption.
    Returns:
        Memory used, total memory, and optionally a warning
            message to be shown to the user when memory consumption is higher
            than 90% or if `psutil` is not installed.
    """
    try:
        import ray  # noqa F401
        import psutil
        total_gb = psutil.virtual_memory().total / (1024**3)
        used_gb = total_gb - psutil.virtual_memory().available / (1024**3)
        if used_gb > total_gb * 0.9:
            message = (
                ": ***LOW MEMORY*** less than 10% of the memory on "
                "this node is available for use. This can cause "
                "unexpected crashes. Consider "
                "reducing the memory used by your application "
                "or reducing the Ray object store size by setting "
                "`object_store_memory` when calling `ray.init`."
            )
        else:
            message = None
        return round(used_gb, 1), round(total_gb, 1), message
    except ImportError:
        return (
            np.nan,
            np.nan,
            "Unknown memory usage. Please run `pip install psutil` to resolve",
        )
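# `_get_memory_usage` example (sketch): on a 64 GiB node with 12.3 GiB in use,
# this returns roughly (12.3, 64.0, None); above 90% utilization the third
# element is the LOW MEMORY warning string instead of None.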
def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
    """Get strings representing the current and elapsed time.
    Args:
        start_time: POSIX timestamp of the start of the tune run
        current_time: POSIX timestamp giving the current time
    Returns:
        Current time and elapsed time for the current run
    """
    current_time_dt = datetime.datetime.fromtimestamp(current_time)
    start_time_dt = datetime.datetime.fromtimestamp(start_time)
    delta: datetime.timedelta = current_time_dt - start_time_dt
    rest = delta.total_seconds()
    days = rest // (60 * 60 * 24)
    rest -= days * (60 * 60 * 24)
    hours = rest // (60 * 60)
    rest -= hours * (60 * 60)
    minutes = rest // 60
    seconds = rest - minutes * 60
    if days > 0:
        running_for_str = f"{days:.0f} days, "
    else:
        running_for_str = ""
    running_for_str += f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}"
    return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
def _time_passed_str(start_time: float, current_time: float) -> str:
    """Generate a message describing the current and elapsed time in the run.
    Args:
        start_time: POSIX timestamp of the start of the tune run
        current_time: POSIX timestamp giving the current time
    Returns:
        Message with the current and elapsed time for the current tune run,
            formatted to be displayed to the user
    """
    current_time_str, running_for_str = _get_time_str(start_time, current_time)
    return f"Current time: {current_time_str} " f"(running for {running_for_str})"
def _get_trials_by_state(trials: List[Trial]):
    trials_by_state = collections.defaultdict(list)
    for t in trials:
        trials_by_state[t.status].append(t)
    return trials_by_state
def _trial_progress_str(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    total_samples: int = 0,
    force_table: bool = False,
    fmt: str = "psql",
    max_rows: Optional[int] = None,
    max_column_length: int = 20,
    done: bool = False,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
):
    """Returns a human readable message for printing to the console.
    This contains a table where each row represents a trial, its parameters
    and the current values of its metrics.
    Args:
        trials: List of trials to get progress string for.
        metric_columns: Names of metrics to include.
            If this is a dict, the keys are metric names and the values are
            the names to use in the message. If this is a list, the metric
            name is used in the message directly.
        parameter_columns: Names of parameters to
            include. If this is a dict, the keys are parameter names and the
            values are the names to use in the message. If this is a list,
            the parameter name is used in the message directly. If this is
            empty, all parameters are used in the message.
        total_samples: Total number of trials that will be generated.
        force_table: Force printing a table. If False, a table will
            be printed only at the end of the training for verbosity levels
            above `Verbosity.V2_TRIAL_NORM`.
        fmt: Output format (see tablefmt in tabulate API).
        max_rows: Maximum number of rows in the trial table. Defaults to
            unlimited.
        max_column_length: Maximum column length (in characters).
        done: True indicates that the tuning run finished.
        metric: Metric used to sort trials.
        mode: One of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute.
        sort_by_metric: Sort terminated trials by metric in the
            intermediate table. Defaults to False.
    """
    messages = []
    delim = "<br>" if fmt == "html" else "\n"
    if len(trials) < 1:
        return delim.join(messages)
    num_trials = len(trials)
    trials_by_state = _get_trials_by_state(trials)
    for local_dir in sorted({t.local_experiment_path for t in trials}):
        messages.append("Result logdir: {}".format(local_dir))
    num_trials_strs = [
        "{} {}".format(len(trials_by_state[state]), state)
        for state in sorted(trials_by_state)
    ]
    if total_samples and total_samples >= sys.maxsize:
        total_samples = "infinite"
    messages.append(
        "Number of trials: {}{} ({})".format(
            num_trials,
            f"/{total_samples}" if total_samples else "",
            ", ".join(num_trials_strs),
        )
    )
    if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done):
        messages += _trial_progress_table(
            trials=trials,
            metric_columns=metric_columns,
            parameter_columns=parameter_columns,
            fmt=fmt,
            max_rows=max_rows,
            metric=metric,
            mode=mode,
            sort_by_metric=sort_by_metric,
            max_column_length=max_column_length,
        )
    return delim.join(messages)
def _max_len(
    value: Any, max_len: int = 20, add_addr: bool = False, wrap: bool = False
) -> Any:
    """Abbreviate a string representation of an object to `max_len` characters.
    For numbers, booleans and None, the original value will be returned for
    correct rendering in the table formatting tool.
    Args:
        value: Object to be represented as a string.
        max_len: Maximum return string length.
        add_addr: If True, will add part of the object address to the end of the
            string, e.g. to identify different instances of the same class. If
            False, three dots (``...``) will be used instead.
        wrap: If True, the abbreviated string is wrapped over up to two lines
            of ``max_len`` characters each instead of being truncated to one line.
    """
    if value is None or isinstance(value, (int, float, numbers.Number, bool)):
        return value
    string = str(value)
    if len(string) <= max_len:
        return string
    if wrap:
        # Maximum two rows.
        # Todo: Make this configurable in the refactor
        if len(string) > max_len * 2:
            string = "..." + string[(3 - (max_len * 2)) :]
        wrapped = textwrap.wrap(string, width=max_len)
        return "\n".join(wrapped)
    if add_addr and not isinstance(value, (int, float, bool)):
        result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}"
        return result
    result = "..." + string[(3 - max_len) :]
    return result
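# `_max_len` examples:
#     _max_len("a" * 30, max_len=10)                 -> "...aaaaaaa" (10 chars)
#     _max_len("a" * 30, max_len=10, add_addr=True)  -> "aaaaa_" + last 4 hex chars of id()
#     _max_len(3.14159, max_len=3)                   -> 3.14159 (numbers pass through)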
def _get_progress_table_data(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    max_rows: Optional[int] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
    max_column_length: int = 20,
) -> Tuple[List, List[str], Tuple[bool, str]]:
    """Generate a table showing the current progress of tuning trials.
    Args:
        trials: List of trials for which progress is to be shown.
        metric_columns: Metrics to be displayed in the table.
        parameter_columns: List of parameters to be included in the data
        max_rows: Maximum number of rows to show. If there's overflow, a
            message will be shown to the user indicating that some rows
            are not displayed
        metric: Metric which is being tuned
        mode: Sort the table in descending order if mode is "max";
            ascending otherwise
        sort_by_metric: If true, the table will be sorted by the metric
        max_column_length: Max number of characters in each column
    Returns:
        - Trial data
        - List of column names
        - Overflow tuple:
            - boolean indicating whether the table has rows which are hidden
            - string with info about the overflowing rows
    """
    num_trials = len(trials)
    trials_by_state = _get_trials_by_state(trials)
    # Sort terminated trials by metric and mode, descending if mode is "max"
    if sort_by_metric:
        trials_by_state[Trial.TERMINATED] = sorted(
            trials_by_state[Trial.TERMINATED],
            reverse=(mode == "max"),
            key=lambda t: unflattened_lookup(metric, t.last_result, default=None),
        )
    state_tbl_order = [
        Trial.RUNNING,
        Trial.PAUSED,
        Trial.PENDING,
        Trial.TERMINATED,
        Trial.ERROR,
    ]
    max_rows = max_rows or float("inf")
    if num_trials > max_rows:
        # TODO(ujvl): suggestion for users to view more rows.
        trials_by_state_trunc = _fair_filter_trials(
            trials_by_state, max_rows, sort_by_metric
        )
        trials = []
        overflow_strs = []
        for state in state_tbl_order:
            if state not in trials_by_state:
                continue
            trials += trials_by_state_trunc[state]
            num = len(trials_by_state[state]) - len(trials_by_state_trunc[state])
            if num > 0:
                overflow_strs.append("{} {}".format(num, state))
        # Build overflow string.
        overflow = num_trials - max_rows
        overflow_str = ", ".join(overflow_strs)
    else:
        overflow = False
        overflow_str = ""
        trials = []
        for state in state_tbl_order:
            if state not in trials_by_state:
                continue
            trials += trials_by_state[state]
    # Pre-process trials to figure out what columns to show.
    if isinstance(metric_columns, Mapping):
        metric_keys = list(metric_columns.keys())
    else:
        metric_keys = metric_columns
    metric_keys = [
        k
        for k in metric_keys
        if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials
        )
    ]
    if not parameter_columns:
        parameter_keys = sorted(set().union(*[t.evaluated_params for t in trials]))
    elif isinstance(parameter_columns, Mapping):
        parameter_keys = list(parameter_columns.keys())
    else:
        parameter_keys = parameter_columns
    # Build trial rows.
    trial_table = [
        _get_trial_info(
            trial, parameter_keys, metric_keys, max_column_length=max_column_length
        )
        for trial in trials
    ]
    # Format column headings
    if isinstance(metric_columns, Mapping):
        formatted_metric_columns = [
            _max_len(
                metric_columns[k], max_len=max_column_length, add_addr=False, wrap=True
            )
            for k in metric_keys
        ]
    else:
        formatted_metric_columns = [
            _max_len(k, max_len=max_column_length, add_addr=False, wrap=True)
            for k in metric_keys
        ]
    if isinstance(parameter_columns, Mapping):
        formatted_parameter_columns = [
            _max_len(
                parameter_columns[k],
                max_len=max_column_length,
                add_addr=False,
                wrap=True,
            )
            for k in parameter_keys
        ]
    else:
        formatted_parameter_columns = [
            _max_len(k, max_len=max_column_length, add_addr=False, wrap=True)
            for k in parameter_keys
        ]
    columns = (
        ["Trial name", "status", "loc"]
        + formatted_parameter_columns
        + formatted_metric_columns
    )
    return trial_table, columns, (overflow, overflow_str)
def _trial_progress_table(
    trials: List[Trial],
    metric_columns: Union[List[str], Dict[str, str]],
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
    fmt: str = "psql",
    max_rows: Optional[int] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    sort_by_metric: bool = False,
    max_column_length: int = 20,
) -> List[str]:
    """Generate a list of trial progress table messages.
    Args:
        trials: List of trials for which progress is to be shown.
        metric_columns: Metrics to be displayed in the table.
        parameter_columns: List of parameters to be included in the data
        fmt: Format of the table; passed to tabulate as the fmtstr argument
        max_rows: Maximum number of rows to show. If there's overflow, a
            message will be shown to the user indicating that some rows
            are not displayed
        metric: Metric which is being tuned
        mode: Sort the table in descending order if mode is "max";
            ascending otherwise
        sort_by_metric: If true, the table will be sorted by the metric
        max_column_length: Max number of characters in each column
    Returns:
        Messages to be shown to the user containing progress tables
    """
    data, columns, (overflow, overflow_str) = _get_progress_table_data(
        trials,
        metric_columns,
        parameter_columns,
        max_rows,
        metric,
        mode,
        sort_by_metric,
        max_column_length,
    )
    messages = [tabulate(data, headers=columns, tablefmt=fmt, showindex=False)]
    if overflow:
        messages.append(f"... {overflow} more trials not shown ({overflow_str})")
    return messages
def _generate_sys_info_str(*sys_info) -> str:
    """Format system info into a string.
        *sys_info: System info strings to be included.
    Returns:
        Formatted string containing system information.
    """
    if sys_info:
        return "<br>".join(sys_info).replace("\n", "<br>")
    return ""
def _trial_errors_str(
    trials: List[Trial], fmt: str = "psql", max_rows: Optional[int] = None
):
    """Returns a readable message regarding trial errors.
    Args:
        trials: List of trials to get progress string for.
        fmt: Output format (see tablefmt in tabulate API).
        max_rows: Maximum number of rows in the error table. Defaults to
            unlimited.
    """
    messages = []
    failed = [t for t in trials if t.error_file]
    num_failed = len(failed)
    if num_failed > 0:
        messages.append("Number of errored trials: {}".format(num_failed))
        if num_failed > (max_rows or float("inf")):
            messages.append(
                "Table truncated to {} rows ({} overflow)".format(
                    max_rows, num_failed - max_rows
                )
            )
        fail_header = ["Trial name", "# failures", "error file"]
        fail_table_data = [
            [
                str(trial),
                str(trial.run_metadata.num_failures)
                + ("" if trial.status == Trial.ERROR else "*"),
                trial.error_file,
            ]
            for trial in failed[:max_rows]
        ]
        messages.append(
            tabulate(
                fail_table_data,
                headers=fail_header,
                tablefmt=fmt,
                showindex=False,
                colalign=("left", "right", "left"),
            )
        )
        if any(trial.status == Trial.TERMINATED for trial in failed[:max_rows]):
            messages.append("* The trial terminated successfully after retrying.")
    delim = "<br>" if fmt == "html" else "\n"
    return delim.join(messages)
def _best_trial_str(
    trial: Trial,
    metric: str,
    parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
):
    """Returns a readable message stating the current best trial."""
    val = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})
    parameter_columns = parameter_columns or list(config.keys())
    if isinstance(parameter_columns, Mapping):
        parameter_columns = parameter_columns.keys()
    params = {p: unflattened_lookup(p, config) for p in parameter_columns}
    return (
        f"Current best trial: {trial.trial_id} with {metric}={val} and "
        f"parameters={params}"
    )
def _fair_filter_trials(
    trials_by_state: Dict[str, List[Trial]],
    max_trials: int,
    sort_by_metric: bool = False,
):
    """Filters trials such that each state is represented fairly.
    The oldest trials are truncated if necessary.
    Args:
        trials_by_state: Dict mapping trial states to lists of trials.
        max_trials: Maximum number of trials to return across all states.
        sort_by_metric: If True, terminated trials keep their existing
            (metric-based) order instead of being re-sorted by trial ID.
    Returns:
        Dict mapping state to List of fairly represented trials.
    """
    num_trials_by_state = collections.defaultdict(int)
    no_change = False
    # Determine number of trials to keep per state.
    while max_trials > 0 and not no_change:
        no_change = True
        for state in sorted(trials_by_state):
            if num_trials_by_state[state] < len(trials_by_state[state]):
                no_change = False
                max_trials -= 1
                num_trials_by_state[state] += 1
    # Within each state, sort trials by trial ID (ascending), unless terminated
    # trials are already sorted by metric.
    sorted_trials_by_state = dict()
    for state in sorted(trials_by_state):
        if state == Trial.TERMINATED and sort_by_metric:
            sorted_trials_by_state[state] = trials_by_state[state]
        else:
            sorted_trials_by_state[state] = sorted(
                trials_by_state[state], reverse=False, key=lambda t: t.trial_id
            )
    # Truncate oldest trials.
    filtered_trials = {
        state: sorted_trials_by_state[state][: num_trials_by_state[state]]
        for state in sorted(trials_by_state)
    }
    return filtered_trials
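# Worked example for `_fair_filter_trials` (sketch): with 3 RUNNING and
# 5 TERMINATED trials and max_trials=4, the per-state quota is assigned
# round-robin, so the result keeps 2 RUNNING and 2 TERMINATED trials (the
# lowest trial IDs per state, unless sort_by_metric preserves the existing
# order of terminated trials).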
def _get_trial_location(trial: Trial, result: dict) -> _Location:
    # We get the location from the result, as the one stored on the trial
    # will be reset when the trial terminates.
    node_ip, pid = result.get(NODE_IP, None), result.get(PID, None)
    if node_ip and pid:
        location = _Location(node_ip, pid)
    else:
        # fallback to trial location if there hasn't been a report yet
        location = trial.temporary_state.location
    return location
def _get_trial_info(
    trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20
):
    """Returns the following information about a trial:
    name | status | loc | params... | metrics...
    Args:
        trial: Trial to get information for.
        parameters: Names of trial parameters to include.
        metrics: Names of metrics to include.
        max_column_length: Maximum column length (in characters).
    """
    result = trial.last_result
    config = trial.config
    location = _get_trial_location(trial, result)
    trial_info = [str(trial), trial.status, str(location)]
    trial_info += [
        _max_len(
            unflattened_lookup(param, config, default=None),
            max_len=max_column_length,
            add_addr=True,
        )
        for param in parameters
    ]
    trial_info += [
        _max_len(
            unflattened_lookup(metric, result, default=None),
            max_len=max_column_length,
            add_addr=True,
        )
        for metric in metrics
    ]
    return trial_info
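# `_get_trial_info` example (sketch, values are illustrative): for a trial with
# config {"lr": 0.01}, last result {"loss": 0.5}, parameters=["lr"] and
# metrics=["loss"], the returned row looks like
# [<trial name>, "RUNNING", <location string>, 0.01, 0.5].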
@DeveloperAPI
class TrialProgressCallback(Callback):
    """Reports (prints) intermediate trial progress.
    This callback is automatically added to the callback stack. When a
    result is obtained, this callback will print the results according to
    the specified verbosity level.
    For ``Verbosity.V3_TRIAL_DETAILS``, a full result list is printed.
    For ``Verbosity.V2_TRIAL_NORM``, only one line is printed per received
    result.
    All other verbosity levels do not print intermediate trial progress.
    Result printing is throttled on a per-trial basis. By default, results are
    printed only once every 30 seconds. Results are always printed when a trial
    finishes or errors.
    """
    def __init__(
        self, metric: Optional[str] = None, progress_metrics: Optional[List[str]] = None
    ):
        self._last_print = collections.defaultdict(float)
        self._last_print_iteration = collections.defaultdict(int)
        self._completed_trials = set()
        self._last_result_str = {}
        self._metric = metric
        self._progress_metrics = set(progress_metrics or [])
        # If progress metrics are given, make sure the optimization metric is included among them
        if self._metric and self._progress_metrics:
            self._progress_metrics.add(self._metric)
        self._last_result = {}
        self._display_handle = None
    def _print(self, msg: str):
        safe_print(msg)
    def on_trial_result(
        self,
        iteration: int,
        trials: List["Trial"],
        trial: "Trial",
        result: Dict,
        **info,
    ):
        self.log_result(trial, result, error=False)
    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        self.log_result(trial, trial.last_result, error=True)
    def on_trial_complete(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # Only log if we haven't already logged that this trial completed
        if trial not in self._completed_trials:
            self._completed_trials.add(trial)
            print_result_str = self._print_result(trial.last_result)
            last_result_str = self._last_result_str.get(trial, "")
            # If this is a new result, print full result string
            if print_result_str != last_result_str:
                self.log_result(trial, trial.last_result, error=False)
            else:
                self._print(f"Trial {trial} completed. Last result: {print_result_str}")
    def log_result(self, trial: "Trial", result: Dict, error: bool = False):
        done = result.get("done", False) is True
        last_print = self._last_print[trial]
        should_print = done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL
        if done and trial not in self._completed_trials:
            self._completed_trials.add(trial)
        if should_print:
            if IS_NOTEBOOK:
                self.display_result(trial, result, error, done)
            else:
                self.print_result(trial, result, error, done)
            self._last_print[trial] = time.time()
            if TRAINING_ITERATION in result:
                self._last_print_iteration[trial] = result[TRAINING_ITERATION]
    def print_result(self, trial: Trial, result: Dict, error: bool, done: bool):
        """Print the most recent results for the given trial to stdout.
        Args:
            trial: Trial for which results are to be printed
            result: Result to be printed
            error: True if an error has occurred, False otherwise
            done: True if the trial is finished, False otherwise
        """
        last_print_iteration = self._last_print_iteration[trial]
        if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
            if result.get(TRAINING_ITERATION) != last_print_iteration:
                self._print(f"Result for {trial}:")
                self._print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            if done:
                self._print(f"Trial {trial} completed.")
        elif has_verbosity(Verbosity.V2_TRIAL_NORM):
            metric_name = self._metric or "_metric"
            metric_value = result.get(metric_name, -99.0)
            error_file = Path(trial.local_path, EXPR_ERROR_FILE).as_posix()
            info = ""
            if done:
                info = " This trial completed."
            print_result_str = self._print_result(result)
            self._last_result_str[trial] = print_result_str
            if error:
                message = (
                    f"The trial {trial} errored with "
                    f"parameters={trial.config}. "
                    f"Error file: {error_file}"
                )
            elif self._metric:
                message = (
                    f"Trial {trial} reported "
                    f"{metric_name}={metric_value:.2f} "
                    f"with parameters={trial.config}.{info}"
                )
            else:
                message = (
                    f"Trial {trial} reported "
                    f"{print_result_str} "
                    f"with parameters={trial.config}.{info}"
                )
            self._print(message)
    def generate_trial_table(
        self, trials: Dict[Trial, Dict], columns: List[str]
    ) -> str:
        """Generate an HTML table of trial progress info.
        Trials (rows) are sorted by name; progress stats (columns) are sorted
        as well.
        Args:
            trials: Trials and their associated latest results
            columns: Columns to show in the table; must be a list of valid
                keys for each Trial result
        Returns:
            HTML template containing a rendered table of progress info
        """
        data = []
        columns = sorted(columns)
        sorted_trials = collections.OrderedDict(
            sorted(self._last_result.items(), key=lambda item: str(item[0]))
        )
        for trial, result in sorted_trials.items():
            data.append([str(trial)] + [result.get(col, "") for col in columns])
        return Template("trial_progress.html.j2").render(
            table=tabulate(
                data, tablefmt="html", headers=["Trial name"] + columns, showindex=False
            )
        )
    def display_result(self, trial: Trial, result: Dict, error: bool, done: bool):
        """Display a formatted HTML table of trial progress results.
        Trial progress is only shown if verbosity is set to level 2 or 3.
        Args:
            trial: Trial for which results are to be printed
            result: Result to be printed
            error: True if an error has occurred, False otherwise
            done: True if the trial is finished, False otherwise
        """
        from IPython.display import HTML, display
        self._last_result[trial] = result
        if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
            ignored_keys = {
                "config",
                "hist_stats",
            }
        elif has_verbosity(Verbosity.V2_TRIAL_NORM):
            ignored_keys = {
                "config",
                "hist_stats",
                "trial_id",
                "experiment_tag",
                "done",
            } | set(AUTO_RESULT_KEYS)
        else:
            return
        table = self.generate_trial_table(
            self._last_result, set(result.keys()) - ignored_keys
        )
        if not self._display_handle:
            self._display_handle = display(HTML(table), display_id=True)
        else:
            self._display_handle.update(HTML(table))
    def _print_result(self, result: Dict):
        if self._progress_metrics:
            # If progress metrics are given, only report these
            flat_result = flatten_dict(result)
            print_result = {}
            for metric in self._progress_metrics:
                print_result[metric] = flat_result.get(metric)
        else:
            # Else, skip auto populated results
            print_result = result.copy()
            for skip_result in SKIP_RESULTS_IN_REPORT:
                print_result.pop(skip_result, None)
            for auto_result in AUTO_RESULT_KEYS:
                print_result.pop(auto_result, None)
        print_result_str = ",".join(
            [f"{k}={v}" for k, v in print_result.items() if v is not None]
        )
        return print_result_str
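# `_print_result` example (sketch): with `_progress_metrics = {"loss"}` and a
# result {"loss": 0.1, "acc": 0.9}, this returns "loss=0.1"; without progress
# metrics, skipped/auto-populated keys are dropped and the remaining items are
# joined as "loss=0.1,acc=0.9".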
def _detect_reporter(_trainer_api: bool = False, **kwargs) -> TuneReporterBase:
    """Detect progress reporter class.
    Will return a :class:`JupyterNotebookReporter` if an IPython/Jupyter-like
    session was detected, and a :class:`CLIReporter` otherwise.
    Keyword arguments are passed on to the reporter class.
    """
    if IS_NOTEBOOK and not _trainer_api:
        kwargs.setdefault("overwrite", not has_verbosity(Verbosity.V2_TRIAL_NORM))
        progress_reporter = JupyterNotebookReporter(**kwargs)
    else:
        progress_reporter = CLIReporter(**kwargs)
    return progress_reporter
def _detect_progress_metrics(
    trainable: Optional[Union["Trainable", Callable]]
) -> Optional[Collection[str]]:
    """Detect progress metrics to report."""
    if not trainable:
        return None
    return getattr(trainable, "_progress_metrics", None)
def _prepare_progress_reporter_for_ray_client(
    progress_reporter: ProgressReporter,
    verbosity: Union[int, Verbosity],
    string_queue: Optional[Queue] = None,
) -> Tuple[ProgressReporter, Queue]:
    """Prepares progress reported for Ray Client by setting the string queue.
    The string queue will be created if it's None."""
    set_verbosity(verbosity)
    progress_reporter = progress_reporter or _detect_reporter()
    # JupyterNotebooks don't work with remote tune runs out of the box
    # (e.g. via Ray client) as they don't have access to the main
    # process stdout. So we introduce a queue here that accepts
    # strings, which will then be displayed on the driver side.
    if isinstance(progress_reporter, RemoteReporterMixin):
        if string_queue is None:
            string_queue = Queue(
                actor_options={"num_cpus": 0, **_force_on_current_node(None)}
            )
        progress_reporter.output_queue = string_queue
    return progress_reporter, string_queue
def _stream_client_output(
    remote_future: ray.ObjectRef,
    progress_reporter: ProgressReporter,
    string_queue: Queue,
) -> Any:
    """
    Stream items from the string queue to the progress reporter until the remote future resolves.
    """
    if string_queue is None:
        return
    def get_next_queue_item():
        try:
            return string_queue.get(block=False)
        except Empty:
            return None
    def _handle_string_queue():
        string_item = get_next_queue_item()
        while string_item is not None:
            # This happens on the driver side
            progress_reporter.display(string_item)
            string_item = get_next_queue_item()
    # ray.wait(...)[1] returns futures that are not ready yet
    while ray.wait([remote_future], timeout=0.2)[1]:
        # Check if we have items to execute
        _handle_string_queue()
    # Handle queue one last time
    _handle_string_queue()