tinytqdm write support (#6359)

* add write support

* add test

* update test case to compare write outputs

* assert final write output

* flush when using write

* update write logic

* Revert "update write logic"

This reverts commit 5e0e611b46cde7a22e41aa5770bc4ccad20de073.

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Francis Lata 2024-10-16 14:51:41 -04:00 committed by GitHub
parent d1094fce5e
commit 90eff347e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 4 deletions

View File

@ -1,10 +1,9 @@
import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, tqdm
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

View File

@ -80,8 +80,6 @@ class TestProgressBar(unittest.TestCase):
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
# print(f"tiny: {tinytqdm_output}")
# print(f"tqdm: {tqdm_output}")
self._compare_bars(tinytqdm_output, tqdm_output)
if n > 3: break
@ -213,6 +211,21 @@ class TestProgressBar(unittest.TestCase):
self.assertEqual(tinytqdm_output, tqdm_output)
if n > 5: break
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_write(self, mock_terminal_size, mock_stderr):
ncols, tqdm_fp = random.randint(80, 120), StringIO()
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
tqdm_fp.truncate(0)
for i in tinytqdm(range(10)):
time.sleep(0.01)
tinytqdm.write(str(i))
tqdm.write(str(i), file=tqdm_fp)
tinytqdm_out, tqdm_out = mock_stderr.getvalue(), tqdm_fp.getvalue()
self.assertEqual(tinytqdm_out.split("\r\033[K")[-1], tqdm_out.split(f"{i-1}\n")[-1])
self.assertEqual(tinytqdm_out, tinytqdm_out)
def test_tqdm_perf(self):
st = time.perf_counter()
for _ in tqdm(range(100)): time.sleep(SLEEP_TIME)

View File

@ -310,6 +310,8 @@ class tqdm:
sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{(""*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
@classmethod
def write(cls, s:str): print(f"\r\033[K{s}", flush=True, file=sys.stderr)
class trange(tqdm):
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)