mirror of https://github.com/commaai/tinygrad.git
fix: hlb cifar types (#2099)
This commit is contained in:
parent
9b1c3cd9ca
commit
999c95ea29
|
@ -11,6 +11,7 @@ if __name__ == "__main__":
|
|||
# https://siboehm.com/articles/22/CUDA-MMM
|
||||
import random, time
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Optional, SupportsIndex, Type, Union
|
||||
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
|
||||
from tinygrad import nn
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
|
@ -18,6 +19,7 @@ from tinygrad.nn import optim
|
|||
from tinygrad.ops import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.shape.symbolic import Node
|
||||
from extra.lr_scheduler import OneCycleLR
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.dist import collectives
|
||||
|
@ -26,7 +28,7 @@ BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS",
|
|||
|
||||
if getenv("HALF", 0):
|
||||
Tensor.default_type = dtypes.float16
|
||||
np_dtype = np.float16
|
||||
np_dtype: Type[Union[np.float16, np.float32]] = np.float16
|
||||
else:
|
||||
Tensor.default_type = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
|
@ -85,7 +87,7 @@ def train_cifar():
|
|||
|
||||
# hyper-parameters were exactly the same as the original repo
|
||||
bias_scaler = 58
|
||||
hyp = {
|
||||
hyp: Dict[str, Any] = {
|
||||
'seed' : 209,
|
||||
'opt': {
|
||||
'bias_lr': 1.76 * bias_scaler/512,
|
||||
|
@ -127,7 +129,8 @@ def train_cifar():
|
|||
def _patches(data, patch_size=(kernel_size,kernel_size)):
|
||||
h, w = patch_size
|
||||
c = data.shape[1]
|
||||
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=(2,3)).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
|
||||
axis: SupportsIndex = (2, 3) # type: ignore
|
||||
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
|
||||
|
||||
def _eigens(patches):
|
||||
n,c,h,w = patches.shape
|
||||
|
@ -142,7 +145,9 @@ def train_cifar():
|
|||
|
||||
# ========== Loss ==========
|
||||
def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
|
||||
y = (1 - label_smoothing)*y + label_smoothing / y.shape[1]
|
||||
divisor = y.shape[1]
|
||||
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
||||
y = (1 - label_smoothing)*y + label_smoothing / divisor
|
||||
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
||||
|
@ -251,6 +256,7 @@ def train_cifar():
|
|||
|
||||
# this import needs to be done here because this is running in a subprocess
|
||||
from extra.dist import OOB
|
||||
assert OOB is not None, "OOB should be initialized"
|
||||
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
||||
|
@ -341,7 +347,7 @@ def train_cifar():
|
|||
# https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
|
||||
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
||||
|
||||
model_ema = None
|
||||
model_ema: Optional[modelEMA] = None
|
||||
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
|
||||
best_eval = -1
|
||||
i = 0
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# this file needs to be very careful with its imports as to not accidentally initialize the runtimes
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, Callable, List, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
|
@ -22,7 +22,7 @@ class _OOB:
|
|||
# receive some data from a target rank, blocks until data is received
|
||||
def recv(self, target_rank:int) -> Any:
|
||||
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
||||
OOB = None
|
||||
OOB: Optional[_OOB] = None
|
||||
|
||||
def init_oob(world_size:int):
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
|
Loading…
Reference in New Issue