mirror of https://github.com/commaai/tinygrad.git
tinytqdm.set_description and tinytrange (#5101)
This commit is contained in:
parent
8080298739
commit
e356807696
|
@ -2,7 +2,7 @@ from typing import Tuple
|
|||
import time
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
import gymnasium as gym
|
||||
from tqdm import trange
|
||||
from tinygrad.helpers import trange
|
||||
import numpy as np # TODO: remove numpy import
|
||||
|
||||
ENVIRONMENT_NAME = 'CartPole-v1'
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
from tqdm import trange
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
|
||||
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
from typing import Optional, Union
|
||||
import argparse
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
|
||||
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
|
||||
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored, trange
|
||||
from tinygrad.nn import Embedding, Linear, LayerNorm
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import functools, argparse, pathlib
|
||||
from tqdm import tqdm
|
||||
from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
|
||||
from tinygrad.helpers import Timing, Profiling, CI
|
||||
from tinygrad.helpers import Timing, Profiling, CI, tqdm
|
||||
from tinygrad.nn.state import torch_load, get_state_dict
|
||||
from extra.models.llama import FeedForward, Transformer
|
||||
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
import torch
|
||||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import trange
|
||||
from tinygrad.nn import optim
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
|
|
|
@ -8,9 +8,8 @@ from collections import namedtuple
|
|||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
|
|
|
@ -2,10 +2,9 @@ import traceback
|
|||
import time
|
||||
from multiprocessing import Process, Queue
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv, trange
|
||||
from tinygrad.tensor import Tensor
|
||||
from extra.datasets import fetch_cifar
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import pathlib, json
|
||||
from tqdm import trange
|
||||
from tinygrad.helpers import trange
|
||||
from extra.datasets import fetch_mnist
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import numpy as np
|
||||
from tqdm import trange
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.helpers import CI, trange
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import random
|
||||
from typing import Tuple
|
||||
from tqdm import trange
|
||||
from tinygrad.helpers import getenv, DEBUG, colored
|
||||
from tinygrad.helpers import getenv, DEBUG, colored, trange
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from test.external.fuzz_shapetracker import shapetracker_ops
|
||||
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import time, random, unittest
|
||||
from tqdm import tqdm
|
||||
from unittest.mock import patch
|
||||
from io import StringIO
|
||||
from tinygrad.helpers import tqdm as tinytqdm
|
||||
from collections import namedtuple
|
||||
from tqdm import tqdm
|
||||
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
|
||||
|
||||
class TestProgressBar(unittest.TestCase):
|
||||
def _compare_bars(self, bar1, bar2, cmp_prog=False):
|
||||
|
@ -31,6 +31,7 @@ class TestProgressBar(unittest.TestCase):
|
|||
|
||||
diff = sum([1 for c1, c2 in zip(prog1, prog2) if c1 == c2]) # allow 1 char diff (due to tqdm special chars)
|
||||
self.assertTrue(not cmp_prog or diff <= 1)
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_tqdm_output_iter(self, mock_terminal_size, mock_stderr):
|
||||
|
@ -56,6 +57,31 @@ class TestProgressBar(unittest.TestCase):
|
|||
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||
self._compare_bars(tinytqdm_output, tqdm_output)
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_trange_output_iter(self, mock_terminal_size, mock_stderr):
|
||||
for _ in range(5):
|
||||
total, ncols = random.randint(5, 30), random.randint(80, 240)
|
||||
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
|
||||
mock_stderr.truncate(0)
|
||||
|
||||
# compare bars at each iteration (only when tinytqdm bar has been updated)
|
||||
for n in (bar := tinytrange(total, desc="Test: ")):
|
||||
time.sleep(0.01)
|
||||
if bar.i % bar.skip != 0: continue
|
||||
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
|
||||
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||
elapsed = n/iters_per_sec if n>0 else 0
|
||||
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||
self._compare_bars(tiny_output, tqdm_output)
|
||||
|
||||
# compare final bars
|
||||
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
|
||||
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
|
||||
elapsed = total/iters_per_sec if n>0 else 0
|
||||
tqdm_output = tqdm.format_meter(n=total, total=total, elapsed=elapsed, ncols=ncols, prefix="Test")
|
||||
self._compare_bars(tiny_output, tqdm_output)
|
||||
|
||||
@patch('sys.stderr', new_callable=StringIO)
|
||||
@patch('shutil.get_terminal_size')
|
||||
def test_tqdm_output_custom(self, mock_terminal_size, mock_stderr):
|
||||
|
|
|
@ -260,6 +260,7 @@ class tqdm:
|
|||
yield item
|
||||
self.update(1)
|
||||
finally: self.update(close=True)
|
||||
def set_description(self, desc:str): self.desc = desc
|
||||
def update(self, n:int=0, close:bool=False):
|
||||
self.n, self.i = self.n+n, self.i+1
|
||||
if (self.i % self.skip != 0 and not close) or self.dis: return
|
||||
|
@ -276,3 +277,6 @@ class tqdm:
|
|||
sz = max(term-5-len(suf)-len(self.desc), 1)
|
||||
bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
|
||||
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
|
||||
|
||||
class trange(tqdm):
|
||||
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
Loading…
Reference in New Issue