Report

from torch_snippets.torch_loader import Report
import numpy as np
import time
# Demo of torch_snippets' Report: log two independent Gaussian random walks
# as "loss" and "val_loss" over a few epochs, then plot the logged history.
n_epochs = 3
steps_per_epoch = 1000
report = Report(n_epochs)
random_walker1 = random_walker2 = 0

for epoch in range(n_epochs):
    for step in range(1, steps_per_epoch + 1):
        # Fractional epoch position: epoch + 0.001, ..., epoch + 1.0.
        position = epoch + step / steps_per_epoch
        report.record(
            pos=position,
            loss=random_walker1,
            val_loss=random_walker2,
            # presumably forwarded to print(): "\r" keeps redrawing one line
            end="\r",
        )
        # Advance each walk by one standard-normal increment *after* logging.
        random_walker1 += np.random.normal()
        random_walker2 += np.random.normal()
        # Artificial delay; presumably so the live progress line is watchable.
        time.sleep(0.001)
    # Per-epoch summary, keyed by the 1-based epoch number
    # (presumably prints the epoch's averaged metrics).
    report.report_avgs(epoch + 1)

report.plot()
EPOCH: 1.000    loss: -6.503    val_loss: -3.093    (1.19s - 2.38s remaining)
EPOCH: 2.000    loss: 48.754    val_loss: -6.265    (2.37s - 1.18s remaining)
EPOCH: 3.000    loss: 38.115    val_loss: -29.732   (3.54s - 0.00s remaining)

# Continue the same two random walks for 5 more epochs. The new Report is
# built from the previous one via `old_report` (the right-hand side still
# refers to the old `report` when evaluated). NOTE(review): how the old
# history is merged is defined inside torch_snippets — confirm there.
n_epochs = 5
report = Report(n_epochs, old_report=report)
steps_per_epoch = 1000

for epoch in range(n_epochs):
    for step in range(1, steps_per_epoch + 1):
        # Position restarts at epoch + 0.001 for this new report.
        position = epoch + step / steps_per_epoch
        report.record(
            pos=position,
            loss=random_walker1,
            val_loss=random_walker2,
            end="\r",
        )
        # Walkers carry on from wherever the previous cell left them.
        random_walker1 += np.random.normal()
        random_walker2 += np.random.normal()
        time.sleep(0.001)
    report.report_avgs(epoch + 1)
EPOCH: 1.000    loss: 29.338    val_loss: -74.955   (1.17s - 4.70s remaining)
EPOCH: 2.000    loss: 0.340 val_loss: -110.763  (2.35s - 3.52s remaining)
EPOCH: 3.000    loss: 30.617    val_loss: -84.599   (3.51s - 2.34s remaining)
EPOCH: 4.000    loss: 34.309    val_loss: -27.520   (4.68s - 1.17s remaining)
EPOCH: 5.000    loss: 15.252    val_loss: -46.033   (5.85s - 0.00s remaining)
import matplotlib.pyplot as plt

# Create one axes, draw a red vertical reference line at x=0 covering
# y in [-100, 100], then have the report draw its logged curves onto the
# same axes.
fig, ax = plt.subplots()
ax.vlines(x=0, ymin=-100, ymax=100, colors=["red"])
report.plot(ax=ax)