Fix mypy examples/beautiful_*.py (#6978)

* fix mypy examples/beautiful_*.py

* backwards

* add test

* Revert "add test"

This reverts commit 4d88845ba3f24d83621da0abf55096553abda7fa.

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Friedrich Carl Eichenroth 2024-10-10 17:34:29 +02:00 committed by GitHub
parent 4ef5310039
commit 859d6d0407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 6 deletions

View File

@ -1,7 +1,7 @@
import time
start_tm = time.perf_counter()
import math
from typing import Tuple
from typing import Tuple, cast
import numpy as np
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
@ -63,8 +63,8 @@ class ConvGroup:
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
self.norm1 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
self.norm2 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'])
self.norm1.weight.requires_grad = False
self.norm2.weight.requires_grad = False
cast(Tensor, self.norm1.weight).requires_grad = False
cast(Tensor, self.norm2.weight).requires_grad = False
def __call__(self, x:Tensor) -> Tensor:
x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
@ -133,7 +133,7 @@ if __name__ == "__main__":
eval_batchsize = 2500
@TinyJit
@Tensor.test()
def val_step() -> Tensor:
def val_step() -> Tuple[Tensor, Tensor]:
# TODO with Tensor.no_grad()
Tensor.no_grad = True
loss, acc = [], []
@ -153,7 +153,7 @@ if __name__ == "__main__":
idxs = np.arange(X_train.shape[0])
np.random.shuffle(idxs)
tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) # NOTE: long doesn't fold
train_loss = 0
train_loss:float = 0
for epoch_step in (t:=trange(num_steps_per_epoch)):
st = time.perf_counter()
GlobalCounters.reset()

View File

@ -4,7 +4,7 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
GPUS = tuple(f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2)))
class Model:
def __init__(self):