import argparse
import collections
import datetime
import logging
import math
import numbers
import os
import sys
import textwrap
import time
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import ray
import ray.widgets.util  # Make the dependency for the notebook check below explicit.
from ray._private.dict import flatten_dict, unflattened_lookup
from ray._private.thirdparty.tabulate.tabulate import (
    DataRow,
    Line,
    TableFormat,
    tabulate,
)
from ray.air._internal.usage import AirEntrypoint
from ray.air.constants import TRAINING_ITERATION
from ray.train import Checkpoint
from ray.tune.callback import Callback
from ray.tune.experiment.trial import Trial
from ray.tune.result import (
    AUTO_RESULT_KEYS,
    EPISODE_REWARD_MEAN,
    MEAN_ACCURACY,
    MEAN_LOSS,
    TIME_TOTAL_S,
    TIMESTEPS_TOTAL,
)
from ray.tune.search.sample import Domain
from ray.tune.utils.log import Verbosity
try:
    import rich
    import rich.layout
    import rich.live
except ImportError:
    rich = None
logger = logging.getLogger(__name__)
# Maps a result key to the column name printed in the progress table.
# Note: this mapping is ordered!
DEFAULT_COLUMNS = collections.OrderedDict(
    {
        MEAN_ACCURACY: "acc",
        MEAN_LOSS: "loss",
        TRAINING_ITERATION: "iter",
        TIME_TOTAL_S: "total time (s)",
        TIMESTEPS_TOTAL: "ts",
        EPISODE_REWARD_MEAN: "reward",
    }
)
# Keys that are excluded when printing intermediate/final training or tuning results.
BLACKLISTED_KEYS = {
    "config",
    "date",
    "done",
    "hostname",
    "iterations_since_restore",
    "node_ip",
    "pid",
    "time_since_restore",
    "timestamp",
    "trial_id",
    "experiment_tag",
    "should_checkpoint",
    "_report_on",  # LIGHTNING_REPORT_STAGE_KEY
}
VALID_SUMMARY_TYPES = {
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
    type(None),
}
# The order in which trial statuses are summarized.
ORDER = [
    Trial.RUNNING,
    Trial.TERMINATED,
    Trial.PAUSED,
    Trial.PENDING,
    Trial.ERROR,
]
class AirVerbosity(IntEnum):
    SILENT = 0
    DEFAULT = 1
    VERBOSE = 2
    def __repr__(self):
        return str(self.value)
IS_NOTEBOOK = ray.widgets.util.in_notebook()
def get_air_verbosity(
    verbose: Union[int, AirVerbosity, Verbosity]
) -> Optional[AirVerbosity]:
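    """Convert a legacy ``Verbosity``/int setting into an ``AirVerbosity``.

    Returns ``None`` if the new output engine is disabled via
    ``RAY_AIR_NEW_OUTPUT=0``. Legacy verbosity levels 2 and 3 both map to
    ``AirVerbosity.VERBOSE``, e.g. ``get_air_verbosity(3)`` returns
    ``AirVerbosity.VERBOSE``.
    """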
    if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
        return None
    if isinstance(verbose, AirVerbosity):
        return verbose
    verbose_int = verbose if isinstance(verbose, int) else verbose.value
    # Verbosity 2 and 3 both map to AirVerbosity 2
    verbose_int = min(2, verbose_int)
    return AirVerbosity(verbose_int)
def _infer_params(config: Dict[str, Any]) -> List[str]:
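    """Infer the swept hyperparameter names from a (possibly nested) config.

    A parameter counts as swept if its value is a search space ``Domain`` or a
    ``grid_search`` specification.
    """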
    params = []
    flat_config = flatten_dict(config)
    for key, val in flat_config.items():
        if isinstance(val, Domain):
            params.append(key)
        # Grid search is a special named field. Because the whole config was
        # flattened, we detect it by its key suffix.
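        # e.g. `{"lr": tune.grid_search([1e-3, 1e-4])}` flattens to the key
        # "lr/grid_search", which maps to the parameter name "lr".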
        if key.endswith("/grid_search"):
            # Truncate `/grid_search`
            params.append(key[: -len("/grid_search")])
    return params
def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
    """Get strings representing the current and elapsed time.
    Args:
        start_time: POSIX timestamp of the start of the tune run
        current_time: POSIX timestamp giving the current time
    Returns:
        Current time and elapsed time for the current run
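    Example:
        For an elapsed time of 5000 seconds, the returned elapsed-time string
        is "1hr 23min 20s".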
    """
    current_time_dt = datetime.datetime.fromtimestamp(current_time)
    start_time_dt = datetime.datetime.fromtimestamp(start_time)
    delta: datetime.timedelta = current_time_dt - start_time_dt
    rest = delta.total_seconds()
    days = int(rest // (60 * 60 * 24))
    rest -= days * (60 * 60 * 24)
    hours = int(rest // (60 * 60))
    rest -= hours * (60 * 60)
    minutes = int(rest // 60)
    seconds = int(rest - minutes * 60)
    running_for_str = ""
    if days > 0:
        running_for_str += f"{days:d}d "
    if hours > 0 or running_for_str:
        running_for_str += f"{hours:d}hr "
    if minutes > 0 or running_for_str:
        running_for_str += f"{minutes:d}min "
    running_for_str += f"{seconds:d}s"
    return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
    trials_by_state = collections.defaultdict(list)
    for t in trials:
        trials_by_state[t.status].append(t)
    return trials_by_state
def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
    return [t for t in trials if t.error_file]
def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
    """Try to infer the metrics to print out.
    By default, only the first 4 meaningful metrics in `last_result` are
    inferred as user-implied metrics.
    """
    # Using OrderedDict for OrderedSet.
    result = collections.OrderedDict()
    for t in trials:
        if not t.last_result:
            continue
        for metric, value in t.last_result.items():
            if metric not in DEFAULT_COLUMNS:
                if metric not in AUTO_RESULT_KEYS:
                    if type(value) in VALID_SUMMARY_TYPES:
                        result[metric] = ""  # not important
            if len(result) >= limit:
                return list(result.keys())
    return list(result.keys())
def _current_best_trial(
    trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> Tuple[Optional[Trial], Optional[str]]:
    """
    Returns the best trial and the metric key. If anything is empty or None,
    returns a trivial result of None, None.
    Args:
        trials: List of trials.
        metric: Metric by which trials are ranked.
        mode: One of "min" or "max".
    Returns:
         Best trial and the metric key.
    """
    if not trials or not metric or not mode:
        return None, None
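    # For mode="min", flip the sign so the loop below can always maximize.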
    metric_op = 1.0 if mode == "max" else -1.0
    best_metric = float("-inf")
    best_trial = None
    for t in trials:
        if not t.last_result:
            continue
        metric_value = unflattened_lookup(metric, t.last_result, default=None)
        if pd.isnull(metric_value):
            continue
        if not best_trial or metric_value * metric_op > best_metric:
            best_metric = metric_value * metric_op
            best_trial = t
    return best_trial, metric
@dataclass
class _PerStatusTrialTableData:
    trial_infos: List[List[str]]
    more_info: str
@dataclass
class _TrialTableData:
    header: List[str]
    data: List[_PerStatusTrialTableData]
def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
    """Abbreviate a string representation of an object to `max_len` characters.
    For numbers, booleans and None, the original value will be returned for
    correct rendering in the table formatting tool.
    Args:
        value: Object to be represented as a string.
        max_len: Maximum return string length.
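        wrap: If True, wrap strings longer than ``max_len`` onto at most two
            lines instead of truncating to a single line.
    Example:
        >>> _max_len("abcdefghij", max_len=5)
        '...ij'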
    """
    if value is None or isinstance(value, (int, float, numbers.Number, bool)):
        return value
    string = str(value)
    if len(string) <= max_len:
        return string
    if wrap:
        # Maximum two rows.
        # Todo: Make this configurable in the refactor
        if len(string) > max_len * 2:
            string = "..." + string[(3 - (max_len * 2)) :]
        wrapped = textwrap.wrap(string, width=max_len)
        return "\n".join(wrapped)
    result = "..." + string[(3 - max_len) :]
    return result
def _get_trial_info(
    trial: Trial, param_keys: List[str], metric_keys: List[str]
) -> List[str]:
    """Returns the following information about a trial:
    name | status | metrics...
    Args:
        trial: Trial to get information for.
        param_keys: Names of parameters to include.
        metric_keys: Names of metrics to include.
    """
    result = trial.last_result
    trial_info = [str(trial), trial.status]
    # params
    trial_info.extend(
        [
            _max_len(
                unflattened_lookup(param, trial.config, default=None),
            )
            for param in param_keys
        ]
    )
    # metrics
    trial_info.extend(
        [
            _max_len(
                unflattened_lookup(metric, result, default=None),
            )
            for metric in metric_keys
        ]
    )
    return trial_info
def _get_trial_table_data_per_status(
    status: str,
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    force_max_rows: bool = False,
) -> Optional[_PerStatusTrialTableData]:
    """Gather all information of trials pertained to one `status`.
    Args:
        status: The trial status of interest.
        trials: all the trials of that status.
        param_keys: *Ordered* list of parameters to be displayed in the table.
        metric_keys: *Ordered* list of metrics to be displayed in the table.
            Including both default and user defined.
        force_max_rows: Whether to enforce a maximum number of rows for this
            status. If True, at most 5 rows are shown.
    Returns:
        All information for trials pertaining to the `status`.
    """
    # TODO: configure it.
    max_row = 5 if force_max_rows else math.inf
    if not trials:
        return None
    trial_infos = list()
    more_info = None
    for t in trials:
        if len(trial_infos) >= max_row:
            remaining = len(trials) - max_row
            more_info = f"{remaining} more {status}"
            break
        trial_infos.append(_get_trial_info(t, param_keys, metric_keys))
    return _PerStatusTrialTableData(trial_infos, more_info)
def _get_trial_table_data(
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    all_rows: bool = False,
    wrap_headers: bool = False,
) -> _TrialTableData:
    """Generate a table showing the current progress of tuning trials.
    Args:
        trials: List of trials for which progress is to be shown.
        param_keys: Ordered list of parameters to be displayed in the table.
        metric_keys: Ordered list of metrics to be displayed in the table.
            Including both default and user defined.
            A metric is only shown if at least one trial has reported it.
        all_rows: Force to show all rows.
        wrap_headers: If True, long header columns may be wrapped onto multiple lines.
    Returns:
        Trial table data, including header and trial table per each status.
    """
    # TODO: configure
    max_trial_num_to_show = 20
    max_column_length = 20
    trials_by_state = _get_trials_by_state(trials)
    # Only keep metrics that at least one trial has actually reported.
    metric_keys = [
        k
        for k in metric_keys
        if any(
            unflattened_lookup(k, t.last_result, default=None) is not None
            for t in trials
        )
    ]
    # get header from metric keys
    formatted_metric_columns = [
        _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys
    ]
    formatted_param_columns = [
        _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys
    ]
    metric_header = [
        DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted
        for metric, formatted in zip(metric_keys, formatted_metric_columns)
    ]
    param_header = formatted_param_columns
    # Map to the abbreviated version if necessary.
    header = ["Trial name", "status"] + param_header + metric_header
    trial_data = list()
    for t_status in ORDER:
        trial_data_per_status = _get_trial_table_data_per_status(
            t_status,
            trials_by_state[t_status],
            param_keys=param_keys,
            metric_keys=metric_keys,
            force_max_rows=not all_rows and len(trials) > max_trial_num_to_show,
        )
        if trial_data_per_status:
            trial_data.append(trial_data_per_status)
    return _TrialTableData(header, trial_data)
def _best_trial_str(
    trial: Trial,
    metric: str,
):
    """Returns a readable message stating the current best trial."""
    # returns something like
    # Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa
    val = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})
    parameter_columns = list(config.keys())
    params = {p: unflattened_lookup(p, config) for p in parameter_columns}
    return (
        f"Current best trial: {trial.trial_id} with {metric}={val} and "
        f"params={params}"
    )
def _render_table_item(
    key: str, item: Any, prefix: str = ""
) -> Iterable[Tuple[str, str]]:
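    """Yield ``(key, value)`` rows for a single result entry.

    Dicts are flattened into ``key/subkey`` rows, floats are formatted with up
    to five decimal places, and other values are abbreviated via ``_max_len``.
    """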
    key = prefix + key
    if isinstance(item, argparse.Namespace):
        item = item.__dict__
    if isinstance(item, float):
        # tabulate does not work well with mixed-type columns, so we format
        # numbers ourselves.
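        # e.g. 1.5 -> "1.5", 0.123456 -> "0.12346"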
        yield key, f"{item:.5f}".rstrip("0")
    elif isinstance(item, dict):
        flattened = flatten_dict(item)
        for k, v in sorted(flattened.items()):
            yield key + "/" + str(k), _max_len(v)
    else:
        yield key, _max_len(item, 20)
def _get_dict_as_table_data(
    data: Dict,
    include: Optional[Collection] = None,
    exclude: Optional[Collection] = None,
    upper_keys: Optional[Collection] = None,
):
    """Get ``data`` dict as table rows.
    If specified, excluded keys are removed. Excluded keys can either be
    fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
    (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
    needed, we can revisit the logic at a later point.
    The same is true for included keys. If a top-level key is included (e.g. ``foo``)
    then all sub keys will be included, too, except if they are excluded.
    If keys are both excluded and included, exclusion takes precedence. Thus, if
    ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.
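    Example (illustrative):
        >>> _get_dict_as_table_data({"a": 1, "b": {"c": 2}})
        [['a', 1], ['b/c', 2]]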
    """
    include = include or set()
    exclude = exclude or set()
    upper_keys = upper_keys or set()
    upper = []
    lower = []
    for key, value in sorted(data.items()):
        # Exclude top-level keys
        if key in exclude:
            continue
        for k, v in _render_table_item(str(key), value):
            # k is now the full subkey, e.g. config/nested/key
            # We can exclude the full key
            if k in exclude:
                continue
            # If we specify includes, top-level includes should take precedence
            # (e.g. if `config` is in include, include config always).
            if include and key not in include and k not in include:
                continue
            if key in upper_keys:
                upper.append([k, v])
            else:
                lower.append([k, v])
    if not upper:
        return lower
    elif not lower:
        return upper
    else:
        return upper + lower
if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
    # Copied/adjusted from tabulate
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("╭", "─", "─", "╮"),
        linebelowheader=Line("├", "─", "─", "┤"),
        linebetweenrows=None,
        linebelow=Line("╰", "─", "─", "╯"),
        headerrow=DataRow("│", " ", "│"),
        datarow=DataRow("│", " ", "│"),
        padding=1,
        with_header_hide=None,
    )
else:
    # For non-utf output, use ascii-compatible characters.
    # This prevents errors e.g. when a legacy Windows encoding is used.
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("+", "-", "-", "+"),
        linebelowheader=Line("+", "-", "-", "+"),
        linebetweenrows=None,
        linebelow=Line("+", "-", "-", "+"),
        headerrow=DataRow("|", " ", "|"),
        datarow=DataRow("|", " ", "|"),
        padding=1,
        with_header_hide=None,
    )
def _print_dict_as_table(
    data: Dict,
    header: Optional[str] = None,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
    division: Optional[Collection[str]] = None,
):
    table_data = _get_dict_as_table_data(
        data=data, include=include, exclude=exclude, upper_keys=division
    )
    headers = [header, ""] if header else []
    if not table_data:
        return
    print(
        tabulate(
            table_data,
            headers=headers,
            colalign=("left", "right"),
            tablefmt=AIR_TABULATE_TABLEFMT,
        )
    )
class ProgressReporter(Callback):
    """Periodically prints out status update."""
    # TODO: Make this configurable
    _heartbeat_freq = 30  # every 30 sec
    # to be updated by subclasses.
    _heartbeat_threshold = None
    _start_end_verbosity = None
    _intermediate_result_verbosity = None
    _addressing_tmpl = None
    def __init__(
        self,
        verbosity: AirVerbosity,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """
        Args:
            verbosity: AirVerbosity level.
        """
        self._verbosity = verbosity
        self._start_time = time.time()
        self._last_heartbeat_time = float("-inf")
        self._progress_metrics = progress_metrics
        self._trial_last_printed_results = {}
        self._in_block = None
    @property
    def verbosity(self) -> AirVerbosity:
        return self._verbosity
    def setup(
        self,
        start_time: Optional[float] = None,
        **kwargs,
    ):
        self._start_time = start_time
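    # Console output is grouped into "blocks" separated by a blank line:
    # `_start_block` closes the previous block if the indicator changed, and
    # `_end_block` prints the separating newline.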
    def _start_block(self, indicator: Any):
        if self._in_block != indicator:
            self._end_block()
        self._in_block = indicator
    def _end_block(self):
        if self._in_block:
            print("")
        self._in_block = None
    def on_experiment_end(self, trials: List["Trial"], **info):
        self._end_block()
    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        self._start_block("exp_start")
        print(f"\nView detailed results here: {experiment_path}")
        if tensorboard_path:
            print(
                f"To visualize your results with TensorBoard, run: "
                f"`tensorboard --logdir {tensorboard_path}`"
            )
    @property
    def _time_heartbeat_str(self):
        current_time_str, running_time_str = _get_time_str(
            self._start_time, time.time()
        )
        return (
            f"Current time: {current_time_str}. Total running time: " + running_time_str
        )
    def print_heartbeat(self, trials, *args, force: bool = False):
        if self._verbosity < self._heartbeat_threshold:
            return
        if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
            self._print_heartbeat(trials, *args, force=force)
            self._last_heartbeat_time = time.time()
    def _print_heartbeat(self, trials, *args, force: bool = False):
        raise NotImplementedError
    def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
        """Only print result if a different result has been reported, or force=True"""
        result = result or trial.last_result
        last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
        this_iter = result.get(TRAINING_ITERATION, 0)
        if this_iter != last_result_iter or force:
            _print_dict_as_table(
                result,
                header=f"{self._addressing_tmpl.format(trial)} result",
                include=self._progress_metrics,
                exclude=BLACKLISTED_KEYS,
                division=AUTO_RESULT_KEYS,
            )
            self._trial_last_printed_results[trial.trial_id] = this_iter
    def _print_config(self, trial):
        _print_dict_as_table(
            trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
        )
    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        if self.verbosity < self._intermediate_result_verbosity:
            return
        self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"finished iteration {result[TRAINING_ITERATION]} "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial, result)
    def on_trial_complete(
        self, iteration: int, trials: List[Trial], trial: Trial, **info
    ):
        if self.verbosity < self._start_end_verbosity:
            return
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]
        self._start_block(f"trial_{trial}_complete")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"completed after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial)
    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]
        self._start_block(f"trial_{trial}_error")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"errored after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: {running_time_str}\n"
            f"Error file: {trial.error_file}"
        )
        self._print_result(trial)
    def on_trial_recover(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)
    def on_checkpoint(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        checkpoint: Checkpoint,
        **info,
    ):
        if self._verbosity < self._intermediate_result_verbosity:
            return
        # This is not expected to happen, but fall back to "?" just in case.
        saved_iter = "?"
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            saved_iter = trial.last_result[TRAINING_ITERATION]
        self._start_block(f"trial_{trial}_result_{saved_iter}")
        loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"saved a checkpoint for iteration {saved_iter} "
            f"at: {loc}"
        )
    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        if self.verbosity < self._start_end_verbosity:
            return
        has_config = bool(trial.config)
        self._start_block(f"trial_{trial}_start")
        if has_config:
            print(
                f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
            )
            self._print_config(trial)
        else:
            print(
                f"{self._addressing_tmpl.format(trial)} "
                f"started without custom configuration."
            ) 
def _detect_reporter(
    verbosity: AirVerbosity,
    num_samples: int,
    entrypoint: Optional[AirEntrypoint] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    config: Optional[Dict] = None,
    progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
    if entrypoint in {
        AirEntrypoint.TUNE_RUN,
        AirEntrypoint.TUNE_RUN_EXPERIMENTS,
        AirEntrypoint.TUNER,
    }:
        reporter = TuneTerminalReporter(
            verbosity,
            num_samples=num_samples,
            metric=metric,
            mode=mode,
            config=config,
            progress_metrics=progress_metrics,
        )
    else:
        reporter = TrainReporter(verbosity, progress_metrics=progress_metrics)
    return reporter
class TuneReporterBase(ProgressReporter):
    _heartbeat_threshold = AirVerbosity.DEFAULT
    _wrap_headers = False
    _intermediate_result_verbosity = AirVerbosity.VERBOSE
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Trial {}"
    def __init__(
        self,
        verbosity: AirVerbosity,
        num_samples: int = 0,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Dict] = None,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        self._num_samples = num_samples
        self._metric = metric
        self._mode = mode
        # will be populated when first result comes in.
        self._inferred_metric = None
        self._inferred_params = _infer_params(config or {})
        super(TuneReporterBase, self).__init__(
            verbosity=verbosity, progress_metrics=progress_metrics
        )
    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        **kwargs,
    ):
        super().setup(start_time=start_time)
        self._num_samples = total_samples
    def _get_overall_trial_progress_str(self, trials):
        result = " | ".join(
            [
                f"{len(trials)} {status}"
                for status, trials in _get_trials_by_state(trials).items()
            ]
        )
        return f"Trial status: {result}"
    # TODO: Return a more structured type to share code with Jupyter flow.
    def _get_heartbeat(
        self, trials, *sys_args, force_full_output: bool = False
    ) -> Tuple[List[str], _TrialTableData]:
        result = list()
        # Trial status: 1 RUNNING | 7 PENDING
        result.append(self._get_overall_trial_progress_str(trials))
        # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
        result.append(self._time_heartbeat_str)
        # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
        result.extend(sys_args)
        # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
        current_best_trial, metric = _current_best_trial(
            trials, self._metric, self._mode
        )
        if current_best_trial:
            result.append(_best_trial_str(current_best_trial, metric))
        # Now populating the trial table data.
        if not self._inferred_metric:
            # try inferring again.
            self._inferred_metric = _infer_user_metrics(trials)
        all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric
        trial_table_data = _get_trial_table_data(
            trials,
            param_keys=self._inferred_params,
            metric_keys=all_metrics,
            all_rows=force_full_output,
            wrap_headers=self._wrap_headers,
        )
        return result, trial_table_data
    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        raise NotImplementedError 
class TuneTerminalReporter(TuneReporterBase):
    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        if total_num_samples > sys.maxsize:
            total_num_samples_str = "infinite"
        else:
            total_num_samples_str = str(total_num_samples)
        print(
            tabulate(
                [
                    ["Search algorithm", searcher_str],
                    ["Scheduler", scheduler_str],
                    ["Number of trials", total_num_samples_str],
                ],
                headers=["Configuration for experiment", experiment_name],
                tablefmt=AIR_TABULATE_TABLEFMT,
            )
        )
        super().experiment_started(
            experiment_name=experiment_name,
            experiment_path=experiment_path,
            searcher_str=searcher_str,
            scheduler_str=scheduler_str,
            total_num_samples=total_num_samples,
            tensorboard_path=tensorboard_path,
            **kwargs,
        )
    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        if self._verbosity < self._heartbeat_threshold and not force:
            return
        heartbeat_strs, table_data = self._get_heartbeat(
            trials, *sys_args, force_full_output=force
        )
        self._start_block("heartbeat")
        for s in heartbeat_strs:
            print(s)
        # Now print the trial table using tabulate.
        more_infos = []
        all_data = []
        header = table_data.header
        for sub_table in table_data.data:
            all_data.extend(sub_table.trial_infos)
            if sub_table.more_info:
                more_infos.append(sub_table.more_info)
        print(
            tabulate(
                all_data,
                headers=header,
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
            )
        )
        if more_infos:
            print(", ".join(more_infos))
        if not force:
            # Only print error table at end of training
            return
        trials_with_error = _get_trials_with_error(trials)
        if not trials_with_error:
            return
        self._start_block("status_errored")
        print(f"Number of errored trials: {len(trials_with_error)}")
        fail_header = ["Trial name", "# failures", "error file"]
        fail_table_data = [
            [
                str(trial),
                str(trial.run_metadata.num_failures)
                + ("" if trial.status == Trial.ERROR else "*"),
                trial.error_file,
            ]
            for trial in trials_with_error
        ]
        print(
            tabulate(
                fail_table_data,
                headers=fail_header,
                tablefmt=AIR_TABULATE_TABLEFMT,
                showindex=False,
                colalign=("left", "right", "left"),
            )
        )
        if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
            print("* The trial terminated successfully after retrying.") 
class TrainReporter(ProgressReporter):
    # The minimum verbosity threshold at which the heartbeat starts being printed.
    _heartbeat_threshold = AirVerbosity.VERBOSE
    _intermediate_result_verbosity = AirVerbosity.DEFAULT
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Training"
    def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False):
        # Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24)  # noqa
        if len(trials) == 0:
            return
        trial = trials[0]
        if trial.status != Trial.RUNNING:
            return " ".join(
                [f"Training is in {trial.status} status.", self._time_heartbeat_str]
            )
        if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
            iter_num = 1
        else:
            iter_num = trial.last_result[TRAINING_ITERATION] + 1
        return " ".join(
            [f"Training on iteration {iter_num}.", self._time_heartbeat_str]
        )
    def _print_heartbeat(self, trials, *args, force: bool = False):
        print(self._get_heartbeat(trials, force_full_output=force))
    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        self._last_heartbeat_time = time.time()
        super().on_trial_result(
            iteration=iteration, trials=trials, trial=trial, result=result, **info
        )