mirror of https://github.com/commaai/tinygrad.git
nn init matches torch (#901)
parent 27845fd3a3
commit 8a928ed2f3
@@ -85,14 +85,18 @@ def train_cifar():
  model = SpeedyResNet()
  # init weights with torch
  # TODO: it doesn't learn with the tinygrad weights, likely since kaiming init
  if getenv("TORCHWEIGHTS"):
    from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
    torch_model = SpeedyResNetTorch()
    model_state_dict = optim.get_state_dict(model)
    torch_state_dict = torch_model.state_dict()
    for k,v in torch_state_dict.items():
      print(f"initting {k} from torch")
      old_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
      new_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
      print(f"initted {k:40s} {str(model_state_dict[k].shape):20s} from torch mean:{old_mean_std[0]:8.5f} -> {new_mean_std[0]:8.5f} std:{old_mean_std[1]:8.5f} -> {new_mean_std[1]:8.5f}")
    exit(0)

  if getenv("ADAM"):
    optimizer = optim.Adam(optim.get_parameters(model), lr=Tensor([0.001]).realize())
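The handoff above hinges on Tensor.assign: each torch parameter is converted to numpy and must have exactly the same shape as the tinygrad parameter it replaces. A minimal, self-contained sketch of that torch-to-tinygrad copy (the Linear layer here is a hypothetical stand-in, not the repo's model):

import torch
from tinygrad.tensor import Tensor

torch_layer = torch.nn.Linear(8, 4)               # stand-in for a real model parameter
w = Tensor(torch_layer.weight.detach().numpy())   # copy the torch weight into tinygrad via numpy
print(w.shape)                                    # (4, 8), same layout as torch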
@@ -8,6 +8,9 @@ from torch import optim
from datasets import fetch_cifar
from tinygrad.helpers import getenv

# allow TF32
torch.set_float32_matmul_precision('high')

OSX = platform.system() == "Darwin"
device = 'mps' if OSX else 'cuda'
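The hard-coded device pick is fine for the two setups this example targets; a hedged variant with availability checks (not in the commit) would fall back to CPU when neither backend is present:

import platform
import torch

if platform.system() == "Darwin" and torch.backends.mps.is_available():
  device = "mps"
elif torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"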
@@ -43,14 +43,14 @@ def normal_test(func, shape=(20, 23), alpha=0.05):
  y = np.random.randn(*shape).flatten()
  return kstest(x, y) >= alpha

-def equal_distribution(tiny_func, torch_func, numpy_func, shape=(20, 23), alpha=0.05):
+def equal_distribution(tiny_func, torch_func, numpy_func=None, shape=(20, 23), alpha=0.05):
  Tensor.manual_seed(1337)
  torch.manual_seed(1337)
  np.random.seed(1337)
  x = tiny_func(*shape).cpu().numpy().flatten()
-  y = numpy_func(shape).flatten()
+  if numpy_func is not None: y = numpy_func(shape).flatten()
  z = torch_func(shape).numpy().flatten()
-  return kstest(x, y) >= alpha and kstest(x, z) >= alpha
+  return (numpy_func is None or kstest(x, y) >= alpha) and kstest(x, z) >= alpha

class TestRandomness(unittest.TestCase):
  def test_rand(self):
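The kstest helper used here appears to return a p-value for a two-sample Kolmogorov-Smirnov comparison, and the samples pass when that p-value is at least alpha. An illustrative, stand-alone version of the same idea using scipy (shown only to make the idea concrete; it is not the repo's own helper):

import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(1337)
x = rng.uniform(-1, 1, size=1000)
y = rng.uniform(-1, 1, size=1000)     # drawn from the same distribution as x
print(ks_2samp(x, y).pvalue >= 0.05)  # usually True: the KS test does not reject equality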
@@ -73,13 +73,12 @@ class TestRandomness(unittest.TestCase):
    self.assertFalse(normal_test(Tensor.glorot_uniform))
    self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))

-  def test_kaiming_uniform(self, shape=(20, 23), a=0.01):
+  def test_kaiming_uniform(self):
    Tensor.manual_seed(1337)
    torch.manual_seed(1337)
    np.random.seed(1337)

-    bound = (math.sqrt(3.0) * (math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * np.prod(shape[2:]))))
-    self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), lambda x: np.random.uniform(low=-bound, high=bound, size=shape)))
+    for shape in [(128, 64, 3, 3), (20, 24)]:
+      self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))

if __name__ == "__main__":
  unittest.main()
@@ -1,3 +1,4 @@
import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
@@ -35,11 +36,12 @@ class BatchNorm2d:
    return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)

class Conv2d:
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, initialization: str='kaiming_uniform'):
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight = getattr(Tensor, initialization)(out_channels, in_channels//groups, *self.kernel_size)
-    self.bias = Tensor.zeros(out_channels) if bias else None
+    self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None

  def __call__(self, x):
    return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
@@ -48,16 +50,18 @@ class ConvTranspose2d:
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
-    self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size)
-    self.bias = Tensor.zeros(out_channels) if bias else None
+    self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None

  def __call__(self, x):
    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

class Linear:
-  def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
-    self.weight = getattr(Tensor, initialization)(out_features, in_features)
-    self.bias = Tensor.zeros(out_features) if bias else None
+  def __init__(self, in_features, out_features, bias=True):
+    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
+    bound = 1 / math.sqrt(self.weight.shape[1])
+    self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

  def __call__(self, x):
    return x.linear(self.weight.transpose(), self.bias)
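With these defaults the layers follow torch's reset_parameters convention: kaiming_uniform weights with a=sqrt(5), and a uniform bias bounded by 1/sqrt(fan_in). A hedged numeric check, not part of the commit, that the weight range matches torch.nn.Linear for one shape:

import math
import torch
from tinygrad.tensor import Tensor

in_f, out_f = 256, 64
bound = 1 / math.sqrt(in_f)                                     # 0.0625 for fan_in=256

w_tiny  = Tensor.kaiming_uniform(out_f, in_f, a=math.sqrt(5)).cpu().numpy()
w_torch = torch.nn.Linear(in_f, out_f).weight.detach().numpy()

print(abs(w_tiny).max() <= bound, abs(w_torch).max() <= bound)  # expected: True True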
@@ -185,7 +185,7 @@ class Tensor:
  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * prod(shape[2:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
    return Tensor.uniform(*shape, low=-bound, high=bound)

  # ***** toposort and backward pass *****
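For any shape with at least two dimensions, shape[1] * prod(shape[2:]) and prod(shape[1:]) are the same fan_in, so the new expression is just a simpler spelling (and it no longer indexes past the end of a 1-D shape). It also explains the nn defaults above: with a=sqrt(5) the gain term sqrt(3)*sqrt(2/(1+a^2)) collapses to 1, leaving bound = 1/sqrt(fan_in). A quick illustrative check:

import math
from math import prod

shape = (128, 64, 3, 3)                   # (out_channels, in_channels, kh, kw)
fan_in_old = shape[1] * prod(shape[2:])   # 64 * 9 = 576
fan_in_new = prod(shape[1:])              # also 576: same fan_in, simpler expression
assert fan_in_old == fan_in_new

a = math.sqrt(5)
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(fan_in_new)
assert abs(bound - 1 / math.sqrt(fan_in_new)) < 1e-12  # gain cancels to 1 when a=sqrt(5)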
@@ -251,7 +251,7 @@ class Tensor:
  # - Strides > 1 and < 0 are now allowed!:
  # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
  # - Idea of stride < 0 support:
  # - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
  # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
  # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
  # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
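The stride > 1 trick described in those comments is easy to sanity-check outside tinygrad; an illustrative numpy version of the pad -> reshape -> take-column-0 sequence (plain numpy is only standing in for the movement ops here):

import numpy as np

x = np.arange(9)                  # dim_sz = 9
s = 2                             # the stride being emulated
pad = (-len(x)) % s               # zeros needed so the length divides s
xp = np.pad(x, (0, pad))          # dim_sz -> dim_sz_padded
strided = xp.reshape(-1, s)[:, 0] # reshape to (dim_sz_padded // s, s), keep column 0
assert (strided == x[::s]).all()  # matches the ordinary strided slice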
@@ -368,7 +368,7 @@ class Tensor:
    out = self.sum(axis=axis, keepdim=keepdim)
    return out * (prod(out.shape)/prod(self.shape))
  def std(self, axis=None, keepdim=False, correction=1):
    square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
  def _softmax(self, axis):
    m = self - self.max(axis=axis, keepdim=True)
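The std here divides the summed squared deviations by the number of reduced elements minus correction, which is Bessel's correction when correction=1. A small illustrative numpy check of the same formula (not taken from the repo's tests):

import numpy as np

x = np.random.randn(4, 5).astype(np.float32)
n = x.shape[1]                                               # elements reduced per row
sq = ((x - x.mean(axis=1, keepdims=True)) ** 2).sum(axis=1)  # summed squared deviations
std = np.sqrt(sq / (n - 1))                                  # correction=1
assert np.allclose(std, x.std(axis=1, ddof=1))               # matches numpy's ddof=1 std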