mirror of https://github.com/commaai/tinygrad.git
tinytqdm write support (#6359)
* add write support * add test * update test case to compare write outputs * assert final write output * flush when using write * update write logic * Revert "update write logic" This reverts commit 5e0e611b46cde7a22e41aa5770bc4ccad20de073. --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
d1094fce5e
commit
90eff347e2
|
@ -1,10 +1,9 @@
|
|||
import os, time, math, functools
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, tqdm
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
|
||||
|
||||
|
|
|
@ -80,8 +80,6 @@ class TestProgressBar(unittest.TestCase):
|
|||
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||
elapsed = n/iters_per_sec if n>0 else 0
|
||||
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
|
||||
# print(f"tiny: {tinytqdm_output}")
|
||||
# print(f"tqdm: {tqdm_output}")
|
||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||
if n > 3: break
|
||||
|
||||
|
@ -213,6 +211,21 @@ class TestProgressBar(unittest.TestCase):
|
|||
self.assertEqual(tinytqdm_output, tqdm_output)
|
||||
if n > 5: break
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_tqdm_write(self, mock_terminal_size, mock_stderr):
|
||||
ncols, tqdm_fp = random.randint(80, 120), StringIO()
|
||||
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
|
||||
mock_stderr.truncate(0)
|
||||
tqdm_fp.truncate(0)
|
||||
for i in tinytqdm(range(10)):
|
||||
time.sleep(0.01)
|
||||
tinytqdm.write(str(i))
|
||||
tqdm.write(str(i), file=tqdm_fp)
|
||||
tinytqdm_out, tqdm_out = mock_stderr.getvalue(), tqdm_fp.getvalue()
|
||||
self.assertEqual(tinytqdm_out.split("\r\033[K")[-1], tqdm_out.split(f"{i-1}\n")[-1])
|
||||
self.assertEqual(tinytqdm_out, tinytqdm_out)
|
||||
|
||||
def test_tqdm_perf(self):
|
||||
st = time.perf_counter()
|
||||
for _ in tqdm(range(100)): time.sleep(SLEEP_TIME)
|
||||
|
|
|
@ -310,6 +310,8 @@ class tqdm:
|
|||
sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
|
||||
bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
|
||||
print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
|
||||
@classmethod
|
||||
def write(cls, s:str): print(f"\r\033[K{s}", flush=True, file=sys.stderr)
|
||||
|
||||
class trange(tqdm):
|
||||
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue