fix: hlb cifar types (#2099)

Author: Sean D'Souza, 2023-10-17 22:23:50 -04:00, committed by GitHub
parent 9b1c3cd9ca
commit 999c95ea29
2 changed files with 13 additions and 7 deletions

File 1 of 2 (the hlb CIFAR training script):

@@ -11,6 +11,7 @@ if __name__ == "__main__":
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
+from typing import Any, Dict, Optional, SupportsIndex, Type, Union
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from tinygrad import nn
 from tinygrad.nn.state import get_state_dict
@@ -18,6 +19,7 @@ from tinygrad.nn import optim
 from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters
+from tinygrad.shape.symbolic import Node
 from extra.lr_scheduler import OneCycleLR
 from tinygrad.jit import TinyJit
 from extra.dist import collectives
@@ -26,7 +28,7 @@ BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS",
 if getenv("HALF", 0):
   Tensor.default_type = dtypes.float16
-  np_dtype = np.float16
+  np_dtype: Type[Union[np.float16, np.float32]] = np.float16
 else:
   Tensor.default_type = dtypes.float32
   np_dtype = np.float32
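Note on the np_dtype change: without an annotation, mypy infers Type[np.float16] from the first assignment in the HALF branch and then rejects the np.float32 assignment in the else branch, so the fix annotates the first assignment with the union of both dtypes. A minimal standalone sketch of the pattern (half is a stand-in for getenv("HALF", 0)):

    import numpy as np
    from typing import Type, Union

    half = True  # stand-in for getenv("HALF", 0)
    # the explicit union keeps both branches assignable
    np_dtype: Type[Union[np.float16, np.float32]] = np.float16 if half else np.float32
    X = np.zeros(4, dtype=np_dtype)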
@@ -85,7 +87,7 @@ def train_cifar():
   # hyper-parameters were exactly the same as the original repo
   bias_scaler = 58
-  hyp = {
+  hyp: Dict[str, Any] = {
     'seed' : 209,
     'opt': {
       'bias_lr': 1.76 * bias_scaler/512,
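The hyp dict mixes nested dicts, floats, and ints; for such a heterogeneous literal mypy infers a joined value type (roughly Dict[str, object]) and then flags the later nested lookups, so Dict[str, Any] is the pragmatic annotation for a config of this shape. A tiny sketch with made-up entries:

    from typing import Any, Dict

    hyp: Dict[str, Any] = {'seed': 209, 'opt': {'bias_lr': 1.76 * 58 / 512}}
    lr = hyp['opt']['bias_lr']  # indexing through Any needs no cast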
@@ -127,7 +129,8 @@ def train_cifar():
   def _patches(data, patch_size=(kernel_size,kernel_size)):
     h, w = patch_size
     c = data.shape[1]
-    return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=(2,3)).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
+    axis: SupportsIndex = (2, 3) # type: ignore
+    return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
   def _eigens(patches):
     n,c,h,w = patches.shape
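For context on _patches: sliding_window_view creates one view per h×w window along the spatial axes of the NCHW batch, and the annotation plus # type: ignore works around numpy's type stubs, which don't accept a tuple for axis here even though one works at runtime. A small illustration with hypothetical shapes, not taken from the commit:

    import numpy as np

    data = np.arange(2*3*4*4).reshape(2, 3, 4, 4)  # NCHW batch
    h, w = 2, 2
    # one view per 2x2 window along the spatial axes (H, W)
    windows = np.lib.stride_tricks.sliding_window_view(data, window_shape=(h, w), axis=(2, 3))
    print(windows.shape)  # (2, 3, 3, 3, 2, 2): N, C, H-h+1, W-w+1, h, w
    patches = windows.transpose((0, 3, 2, 1, 4, 5)).reshape((-1, 3, h, w))
    print(patches.shape)  # (18, 3, 2, 2): one CHW patch per window position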
@@ -142,7 +145,9 @@ def train_cifar():
   # ========== Loss ==========
   def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
-    y = (1 - label_smoothing)*y + label_smoothing / y.shape[1]
+    divisor = y.shape[1]
+    assert not isinstance(divisor, Node), "sint not supported as divisor"
+    y = (1 - label_smoothing)*y + label_smoothing / divisor
     if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
     if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
     return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
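On the cross_entropy change: tinygrad shape entries can be symbolic Node objects ("sint") rather than plain ints, so y.shape[1] is not statically a number; hoisting it into divisor and asserting it is not a Node both narrows the type for mypy and fails loudly if a symbolic dimension ever reaches the division. The narrowing pattern in isolation, assuming only a Union[int, Node] shape entry:

    from typing import Union

    class Node:  # stand-in for tinygrad.shape.symbolic.Node
      pass

    def smoothing_term(dim: Union[int, Node], label_smoothing: float) -> float:
      # the isinstance assert narrows Union[int, Node] to int for the checker
      assert not isinstance(dim, Node), "sint not supported as divisor"
      return label_smoothing / dim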
@@ -251,6 +256,7 @@ def train_cifar():
   # this import needs to be done here because this is running in a subprocess
   from extra.dist import OOB
+  assert OOB is not None, "OOB should be initialized"
   rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
   X_train, Y_train, X_test, Y_test = fetch_cifar()
@@ -341,7 +347,7 @@ def train_cifar():
   # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
   # 136 TFLOPS is the theoretical max w float16 on 3080 Ti
-  model_ema = None
+  model_ema: Optional[modelEMA] = None
   projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
   best_eval = -1
   i = 0
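model_ema starts out as None and only receives a modelEMA instance once EMA kicks in later in the loop, so annotating the initializer as Optional[modelEMA] is what lets that later assignment type-check; without it mypy pins the variable to type None. The same pattern with a placeholder class:

    from typing import Optional

    class modelEMA:  # placeholder for the EMA wrapper defined in the script
      pass

    model_ema: Optional[modelEMA] = None  # populated later in the training loop
    model_ema = modelEMA()  # now compatible with the declared Optional type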

File 2 of 2 (the extra.dist module):

@@ -1,6 +1,6 @@
 # this file needs to be very careful with its imports as to not accidentally initialize the runtimes
 from multiprocessing.connection import Connection
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 from tinygrad.helpers import DEBUG, getenv
 import multiprocessing as mp
 import os
@@ -22,7 +22,7 @@ class _OOB:
   # receive some data from a target rank, blocks until data is received
   def recv(self, target_rank:int) -> Any:
     return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
-OOB = None
+OOB: Optional[_OOB] = None
 def init_oob(world_size:int):
   os.environ["WORLD_SIZE"] = str(world_size)
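This mirrors the assert added in the first file: extra.dist deliberately leaves OOB as None at import time (the file avoids initializing the runtimes) and fills it in later, so Optional[_OOB] is the honest module-level type and importers narrow it before use. A minimal sketch of that producer/consumer pattern, with placeholder bodies:

    from typing import Optional

    class _OOB:  # placeholder for the pipe-based out-of-band channel
      pass

    OOB: Optional[_OOB] = None  # assigned during distributed setup

    def use_oob() -> _OOB:
      # consumers assert before use, narrowing Optional[_OOB] to _OOB
      assert OOB is not None, "OOB should be initialized"
      return OOB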