tinytqdm.set_description and tinytrange (#5101)

This commit is contained in:
chenyu 2024-06-22 14:45:06 -04:00 committed by GitHub
parent 8080298739
commit e356807696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 43 additions and 22 deletions

View File

@ -2,7 +2,7 @@ from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
import gymnasium as gym
from tqdm import trange
from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import
ENVIRONMENT_NAME = 'CartPole-v1'

View File

@ -1,9 +1,8 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
from tqdm import trange
class Model:
def __init__(self):

View File

@ -1,9 +1,8 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad.helpers import getenv, colored
from tinygrad.helpers import getenv, colored, trange
from extra.datasets import fetch_mnist
from tqdm import trange
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]

View File

@ -1,11 +1,10 @@
#!/usr/bin/env python3
from typing import Optional, Union
import argparse
from tqdm import trange
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, trange
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict

View File

@ -1,7 +1,6 @@
import functools, argparse, pathlib
from tqdm import tqdm
from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
from tinygrad.helpers import Timing, Profiling, CI
from tinygrad.helpers import Timing, Profiling, CI, tqdm
from tinygrad.nn.state import torch_load, get_state_dict
from extra.models.llama import FeedForward, Transformer

View File

@ -1,11 +1,10 @@
from pathlib import Path
import numpy as np
from tqdm import trange
import torch
from torchvision.utils import make_grid, save_image
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from tinygrad.helpers import trange
from tinygrad.nn import optim
from extra.datasets import fetch_mnist

View File

@ -8,9 +8,8 @@ from collections import namedtuple
from PIL import Image
import numpy as np
from tqdm import tqdm
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict

View File

@ -2,10 +2,9 @@ import traceback
import time
from multiprocessing import Process, Queue
import numpy as np
from tqdm import trange
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, trange
from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar
from extra.models.efficientnet import EfficientNet

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
import pathlib, json
from tqdm import trange
from tinygrad.helpers import trange
from extra.datasets import fetch_mnist
from PIL import Image
import numpy as np

View File

@ -1,7 +1,6 @@
import numpy as np
from tqdm import trange
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI
from tinygrad.helpers import CI, trange
from tinygrad.engine.jit import TinyJit

View File

@ -1,7 +1,6 @@
import random
from typing import Tuple
from tqdm import trange
from tinygrad.helpers import getenv, DEBUG, colored
from tinygrad.helpers import getenv, DEBUG, colored, trange
from tinygrad.shape.shapetracker import ShapeTracker
from test.external.fuzz_shapetracker import shapetracker_ops
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad

View File

@ -1,9 +1,9 @@
import time, random, unittest
from tqdm import tqdm
from unittest.mock import patch
from io import StringIO
from tinygrad.helpers import tqdm as tinytqdm
from collections import namedtuple
from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
class TestProgressBar(unittest.TestCase):
def _compare_bars(self, bar1, bar2, cmp_prog=False):
@ -31,6 +31,7 @@ class TestProgressBar(unittest.TestCase):
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
self.assertTrue(not cmp_prog or diff <= 1)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
@ -56,6 +57,31 @@ class TestProgressBar(unittest.TestCase):
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tinytqdm_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
for _ in range(5):
total, ncols = random.randint(5, 30), random.randint(80, 240)
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
mock_stderr.truncate(0)
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytrange(total, desc="Test: ")):
time.sleep(0.01)
if bar.i % bar.skip != 0: continue
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
# compare final bars
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = total/iters_per_sec if n>0 else 0
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
self._compare_bars(tiny_output, tqdm_output)
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):

View File

@ -260,6 +260,7 @@ class tqdm:
yield item
self.update(1)
finally: self.update(close=True)
def set_description(self, desc:str): self.desc = desc
def update(self, n:int=0, close:bool=False):
self.n, self.i = self.n+n, self.i+1
if (self.i % self.skip != 0 and not close) or self.dis: return
@ -276,3 +277,6 @@ class tqdm:
sz = max(term-5-len(suf)-len(self.desc), 1)
bar = f'\r{self.desc}{round(100*prog):3}%|{""*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
class trange(tqdm):
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)