# tinygrad/test/test_tqdm.py
#
# Listing metadata: 109 lines, 4.5 KiB, Python

import time, random, unittest
from tqdm import tqdm
from unittest.mock import patch
from io import StringIO
from tinygrad.helpers import tinytqdm
from collections import namedtuple
class TestProgressBar(unittest.TestCase):
  """Check that tinygrad's ``tinytqdm`` renders the same progress bars as the reference ``tqdm``."""

  # stand-in for the value returned by ``shutil.get_terminal_size``; these tests only populate ``columns``
  _TerminalSize = namedtuple(field_names='columns', typename='terminal_size')

  @staticmethod
  def _parse_timer(timer: str) -> int:
    """Convert a ``MM:SS`` or ``H:MM:SS`` timer string into a number of seconds."""
    # weight the ':'-separated fields from the right (seconds, minutes, hours) so an hours field is handled too
    return sum(int(field) * 60**power for power, field in enumerate(reversed(timer.split(":"))))

  @staticmethod
  def _reset(stream: StringIO) -> None:
    """Empty ``stream`` so the next bar is written at the start of the buffer."""
    # truncate() does not move the stream position; without seek(0) later writes would be padded with NULs
    stream.seek(0)
    stream.truncate(0)

  @staticmethod
  def _last_bar(stream: StringIO) -> str:
    """Return the most recently drawn bar: the text after the final carriage return, right-stripped."""
    return stream.getvalue().split("\r")[-1].rstrip()

  @staticmethod
  def _iters_per_sec(bar: str) -> float:
    """Parse the numeric rate that immediately precedes ``it/s`` in a rendered bar."""
    return float(bar.split("it/s")[-2].split(" ")[-1])

  def _compare_bars(self, bar1: str, bar2: str, cmp_prog: bool = False) -> None:
    """Assert that two rendered progress bars match, tolerating small timer drift.

    Each bar must have the form ``prefix|progress|suffix``. The total lengths and the prefixes
    must match exactly. When neither suffix contains ``?``, the ``[elapsed<remaining`` timers
    may differ by up to 5 seconds each, and the rest of the suffix must match exactly.
    Otherwise the suffixes must match exactly. When ``cmp_prog`` is true, the progress
    sections may differ in at most one character.
    """
    prefix1, prog1, suffix1 = bar1.split("|")
    prefix2, prog2, suffix2 = bar2.split("|")
    self.assertEqual(len(bar1), len(bar2))
    self.assertEqual(prefix1, prefix2)
    if "?" not in suffix1 and "?" not in suffix2:
      # allow for few sec diff in timers (removes flakiness)
      timer1, rm1 = [self._parse_timer(t) for t in suffix1.split("[")[-1].split(",")[0].split("<")]
      timer2, rm2 = [self._parse_timer(t) for t in suffix2.split("[")[-1].split(",")[0].split("<")]
      self.assertLessEqual(abs(timer1 - timer2), 5)
      self.assertLessEqual(abs(rm1 - rm2), 5)
      # drop the "[elapsed<remaining" part so the rest of the suffix can be compared exactly
      suffix1 = suffix1.split("[")[0] + suffix1.split(",")[1]
      suffix2 = suffix2.split("[")[0] + suffix2.split(",")[1]
    self.assertEqual(suffix1, suffix2)
    if cmp_prog:
      # allow 1 char diff (due to tqdm special chars); count MISmatches, and treat extra length as mismatches
      diff = sum(c1 != c2 for c1, c2 in zip(prog1, prog2)) + abs(len(prog1) - len(prog2))
      self.assertLessEqual(diff, 1)

  def _check_against_tqdm(self, tinytqdm_output: str, n: int, total: int, ncols: int) -> None:
    """Render tqdm's bar for ``n`` of ``total`` items at ``ncols`` columns and compare it with ``tinytqdm_output``."""
    # derive the elapsed time from the rate tinytqdm printed, so both bars are rendered for the same elapsed time
    elapsed = n / self._iters_per_sec(tinytqdm_output) if n > 0 else 0
    tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
    self._compare_bars(tinytqdm_output, tqdm_output)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_iter_e2e(self, mock_terminal_size, mock_stderr):
    """Wrap an iterable with tinytqdm and compare every redraw, plus the final bar, against tqdm."""
    for _ in range(10):
      total, ncols = random.randint(5, 30), random.randint(80, 240)
      mock_terminal_size.return_value = self._TerminalSize(ncols)
      self._reset(mock_stderr)
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      for n in (bar := tinytqdm(range(total), desc="Test: ")):
        time.sleep(0.01)
        if bar.i % bar.skip != 0: continue
        self._check_against_tqdm(self._last_bar(mock_stderr), n, total, ncols)
      # compare final bars
      self._check_against_tqdm(self._last_bar(mock_stderr), total, total, ncols)

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_tqdm_output_custom_e2e(self, mock_terminal_size, mock_stderr):
    """Drive tinytqdm through manual ``update`` calls and compare each redraw against tqdm."""
    for _ in range(10):
      total, ncols = random.randint(10000, 100000), random.randint(80, 120)
      mock_terminal_size.return_value = self._TerminalSize(ncols)
      self._reset(mock_stderr)
      # compare bars at each iteration (only when tinytqdm bar has been updated)
      bar = tinytqdm(total=total, desc="Test: ")
      n = 0
      while n < total:
        time.sleep(0.01)
        # random step of roughly a tenth of the total, clamped so the last step lands exactly on total
        incr = min((total // 10) + random.randint(0, 100), total - n)
        bar.update(incr, close=n+incr==total)
        n += incr
        if bar.i % bar.skip != 0: continue
        self._check_against_tqdm(self._last_bar(mock_stderr), n, total, ncols)

  def test_tqdm_perf(self):
    """With slow (10 ms) iterations, tinytqdm must take less than twice tqdm's wall time."""
    st = time.perf_counter()
    for _ in tqdm(range(100)): time.sleep(0.01)
    tqdm_time = time.perf_counter() - st
    st = time.perf_counter()
    for _ in tinytqdm(range(100)): time.sleep(0.01)
    tinytqdm_time = time.perf_counter() - st
    self.assertLess(tinytqdm_time, 2.0 * tqdm_time)

  def test_tqdm_perf_high_iter(self):
    """On an empty loop body, tinytqdm's total time must stay within 5x of tqdm's."""
    # 10**7, not 10^7: ``^`` is bitwise XOR in Python (10^7 == 13), which made this loop trivially short
    iters = 10**7
    st = time.perf_counter()
    for _ in tqdm(range(iters)): pass
    tqdm_time = time.perf_counter() - st
    st = time.perf_counter()
    for _ in tinytqdm(range(iters)): pass
    tinytqdm_time = time.perf_counter() - st
    self.assertLess(tinytqdm_time, 5 * tqdm_time)
# run the test suite when this file is executed directly (the body must be indented under the guard)
if __name__ == '__main__':
  unittest.main()