# Source code for draugr.writers.terminal.terminal_plot_writer

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Union

import numpy
import torch
from PIL import Image
from tqdm import tqdm

from draugr.metrics import MetricCollection
from draugr.writers.mixins import ImageWriterMixin
from draugr.drawers.terminal import (
    terminal_render_image,
    terminalise_image,
    styled_terminal_plot_stats_shared_x,
    terminal_plot,
)
from draugr.writers.writer import Writer

# Names exported by ``from ... import *``.
__all__ = ["TerminalWriter"]

# Authorship / provenance metadata.
__author__ = "Christian Heider Nielsen"
__doc__ = """
Created on 27/04/2019

@author: cnheider
"""


class TerminalWriter(Writer, ImageWriterMixin):
    """Writer that renders scalars, statistics and images in the terminal.

    All output goes through a ``tqdm`` bar (``self.E``) created in ``_open``:
    plots and images are printed with its ``write`` method, and a short summary
    of the latest values is shown in the bar's description.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``Writer`` and set the tag used by ``stats``."""
        super().__init__(**kwargs)
        # Tag under which ``stats`` calls are filtered and step-counted.
        self._stats_tag = "stats"

    def image(
        self,
        tag: str,
        data: Union[numpy.ndarray, torch.Tensor, Image.Image],
        step,
        *,
        dataformats: str = "NCHW",
        **kwargs,
    ) -> None:
        """Render ``data`` as a terminal image and print it.

        :param tag: Name of the image (not used by this writer).
        :param data: Image passed unchanged to ``terminal_render_image``.
        :param step: Step index (not used by this writer).
        :param dataformats: Layout hint (not used by this writer).
        :param kwargs: May contain ``scale`` (default ``(28, 28)``), forwarded
            to ``terminal_render_image``; other keys are ignored.
        """
        scale = kwargs.get("scale", (28, 28))
        self.E.write(terminalise_image(terminal_render_image(data, scale=scale)))

    def _open(self):
        """Create the tqdm bar used for output and the scalar store; return ``self``."""
        self.E = tqdm()
        self.values = MetricCollection()
        return self

    def _close(self, exc_type=None, exc_val=None, exc_tb=None):
        """Close the tqdm bar; the exception info arguments are ignored."""
        return self.E.close()

    def stats(self, value: MetricCollection, step_i: int = None) -> None:
        """Plot a metric collection (if the stats tag passes ``filter``) and advance the counter.

        :param value: Collection to plot and summarise.
        :param step_i: When given (including ``0``), the stats counter is set
            to this value after plotting; otherwise it is incremented by one.
        """
        tag = self._stats_tag
        if self.filter(tag):
            # NOTE(review): the plot is labelled with the counter value from
            # *before* ``step_i`` is applied -- presumably mirrors the base
            # ``Writer`` counting convention; confirm.
            self._stats(value, self._counter[tag])
        if step_i is not None:
            self._counter[tag] = step_i
        else:
            self._counter[tag] += 1

    @staticmethod
    def _latest_running_value(stats, name):
        """Return ``stats.<name>.running_value[-1]``, or ``None`` if the metric is missing or empty."""
        try:
            return getattr(stats, name).running_value[-1]
        except (AttributeError, KeyError, IndexError):
            return None

    def _stats(self, stats: MetricCollection, step_i):
        """Plot ``stats`` and show the latest value of each known metric in the bar.

        Metrics absent from ``stats`` (or with no values yet) are left out of
        the description instead of raising.
        """
        styled_terminal_plot_stats_shared_x(stats, printer=self.E.write)
        parts = [f"Epi: {step_i}"]
        # (label, metric attribute, format spec) for the summary line.
        for label, name, spec in (
            ("Sig", "signal", ".3f"),
            ("Dur", "duration", ".1f"),
            ("TD Err", "td_error", ".3f"),
            ("Eps", "epsilon", ".3f"),
        ):
            latest = self._latest_running_value(stats, name)
            if latest is not None:
                parts.append(f"{label}: {latest:{spec}}")
        self.E.set_description(", ".join(parts))

    def _scalar(self, tag: str, value: float, step: int):
        """Store ``value`` under ``tag``, plot it and show it in the bar description."""
        self.values[tag] = value
        # NOTE(review): only the current value is plotted, not the history.
        terminal_plot([value], printer=self.E.write)
        self.E.set_description(f"Tag:{tag} Val:{value} Step:{step}")
if __name__ == "__main__":
    # Minimal smoke demo: write one scalar and one random RGB image.
    with TerminalWriter() as w:
        w.scalar("What", 4)
        # randint's upper bound is exclusive, so 256 covers the full 0..255
        # byte range; uint8 is the conventional 8-bit image dtype.
        w.image(
            "bro",
            numpy.random.randint(0, 256, (28, 28, 3), dtype=numpy.uint8),
            0,
        )