nn init matches torch (#901)

George Hotz 2023-06-01 21:24:11 -07:00 committed by GitHub
parent 27845fd3a3
commit 8a928ed2f3
5 changed files with 29 additions and 19 deletions

examples/hlb_cifar10.py

@@ -85,14 +85,18 @@ def train_cifar():
   model = SpeedyResNet()
   # init weights with torch
   # TODO: it doesn't learn with the tinygrad weights, likely since kaiming init
   if getenv("TORCHWEIGHTS"):
     from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
     torch_model = SpeedyResNetTorch()
     model_state_dict = optim.get_state_dict(model)
     torch_state_dict = torch_model.state_dict()
     for k,v in torch_state_dict.items():
-      print(f"initting {k} from torch")
+      old_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
       model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
+      new_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy()
+      print(f"initted {k:40s} {str(model_state_dict[k].shape):20s} from torch mean:{old_mean_std[0]:8.5f} -> {new_mean_std[0]:8.5f} std:{old_mean_std[1]:8.5f} -> {new_mean_std[1]:8.5f}")
+    exit(0)
   if getenv("ADAM"):
     optimizer = optim.Adam(optim.get_parameters(model), lr=Tensor([0.001]).realize())
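For context on the debug path above: getenv means the comparison is toggled from the environment, so (assuming the file path inferred from the import, examples/hlb_cifar10.py) something like

  TORCHWEIGHTS=1 python examples/hlb_cifar10.py

overwrites the tinygrad parameters with the torch model's, prints the before/after mean and std of every tensor, then exits instead of training.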

examples/hlb_cifar10_torch.py

@@ -8,6 +8,9 @@ from torch import optim
 from datasets import fetch_cifar
 from tinygrad.helpers import getenv
+# allow TF32
+torch.set_float32_matmul_precision('high')
 OSX = platform.system() == "Darwin"
 device = 'mps' if OSX else 'cuda'
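A note on the two added lines: 'high' lets PyTorch use TF32 tensor cores for float32 matmuls on Ampere-and-newer GPUs, so the torch baseline is not handicapped in speed comparisons. The setting can be verified with the matching getter (standard PyTorch API):

  import torch
  torch.set_float32_matmul_precision('high')  # 'highest' = strict fp32; 'high' permits TF32 matmuls
  print(torch.get_float32_matmul_precision())  # -> 'high'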

test/test_randomness.py

@@ -43,14 +43,14 @@ def normal_test(func, shape=(20, 23), alpha=0.05):
   y = np.random.randn(*shape).flatten()
   return kstest(x, y) >= alpha
-def equal_distribution(tiny_func, torch_func, numpy_func, shape=(20, 23), alpha=0.05):
+def equal_distribution(tiny_func, torch_func, numpy_func=None, shape=(20, 23), alpha=0.05):
   Tensor.manual_seed(1337)
   torch.manual_seed(1337)
   np.random.seed(1337)
   x = tiny_func(*shape).cpu().numpy().flatten()
-  y = numpy_func(shape).flatten()
+  if numpy_func is not None: y = numpy_func(shape).flatten()
   z = torch_func(shape).numpy().flatten()
-  return kstest(x, y) >= alpha and kstest(x, z) >= alpha
+  return (numpy_func is None or kstest(x, y) >= alpha) and kstest(x, z) >= alpha
 class TestRandomness(unittest.TestCase):
   def test_rand(self):
@@ -73,13 +73,12 @@ class TestRandomness(unittest.TestCase):
     self.assertFalse(normal_test(Tensor.glorot_uniform))
     self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
-  def test_kaiming_uniform(self, shape=(20, 23), a=0.01):
+  def test_kaiming_uniform(self):
     Tensor.manual_seed(1337)
     torch.manual_seed(1337)
     np.random.seed(1337)
-    bound = (math.sqrt(3.0) * (math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * np.prod(shape[2:]))))
-    self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), lambda x: np.random.uniform(low=-bound, high=bound, size=shape)))
+    for shape in [(128, 64, 3, 3), (20, 24)]:
+      self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
 if __name__ == "__main__":
   unittest.main()
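The test strategy is worth spelling out: equal_distribution seeds tinygrad, torch, and numpy identically, flattens each sample, and accepts when the file's kstest helper (a two-sample Kolmogorov-Smirnov p-value, judging by the >= alpha usage) does not reject equality; numpy_func becomes optional because the torch comparison alone now covers shapes like (128, 64, 3, 3). A standalone sketch of the same idea using scipy's two-sample KS test (ks_2samp; sample sizes and alpha here are illustrative):

  import numpy as np
  from scipy.stats import ks_2samp

  rng = np.random.default_rng(1337)
  a = rng.uniform(-1.0, 1.0, 2000)  # stand-in for the flattened tinygrad sample
  b = rng.uniform(-1.0, 1.0, 2000)  # stand-in for the torch reference sample
  res = ks_2samp(a, b)
  # same underlying distribution -> typically a large p-value, so this should print True
  print(res.pvalue >= 0.05)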

tinygrad/nn/__init__.py

@@ -1,3 +1,4 @@
+import math
 from typing import Optional, Union, Tuple
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
@@ -35,11 +36,12 @@ class BatchNorm2d:
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
 class Conv2d:
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, initialization: str='kaiming_uniform'):
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight = getattr(Tensor, initialization)(out_channels, in_channels//groups, *self.kernel_size)
-    self.bias = Tensor.zeros(out_channels) if bias else None
+    self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
   def __call__(self, x):
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
@@ -48,16 +50,18 @@ class ConvTranspose2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
-    self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size)
-    self.bias = Tensor.zeros(out_channels) if bias else None
+    self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
+    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
   def __call__(self, x):
     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 class Linear:
-  def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
-    self.weight = getattr(Tensor, initialization)(out_features, in_features)
-    self.bias = Tensor.zeros(out_features) if bias else None
+  def __init__(self, in_features, out_features, bias=True):
+    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
+    bound = 1 / math.sqrt(self.weight.shape[1])
+    self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
   def __call__(self, x):
     return x.linear(self.weight.transpose(), self.bias)
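The seemingly arbitrary a=math.sqrt(5) is exactly what torch.nn.Linear and torch.nn.Conv2d pass to kaiming_uniform_ in their own reset_parameters, and it makes the bound collapse: with gain squared = 2/(1 + a**2) = 1/3, sqrt(3)*sqrt(1/3) = 1, so the weight is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), the same scale the bias lines above use. A two-line check (fan_in value is illustrative):

  import math
  a, fan_in = math.sqrt(5), 64 * 3 * 3
  bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(fan_in)
  print(abs(bound - 1 / math.sqrt(fan_in)) < 1e-12)  # True: reduces to 1/sqrt(fan_in)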

tinygrad/tensor.py

@@ -185,7 +185,7 @@ class Tensor:
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * prod(shape[2:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound)
   # ***** toposort and backward pass *****
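Worth noting: shape[1] * prod(shape[2:]) and prod(shape[1:]) are the same fan_in value, so this rewrite alone does not change behavior; what changes the produced weights is nn (above) now passing a=math.sqrt(5). A quick sanity check, with math.prod as a stand-in for tinygrad's helpers.prod (assumed equivalent on integer shapes):

  from math import prod
  shape = (128, 64, 3, 3)
  assert shape[1] * prod(shape[2:]) == prod(shape[1:]) == 576  # both are fan_in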
@@ -251,7 +251,7 @@ class Tensor:
   # - Strides > 1 and < 0 are now allowed!:
   #   - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
   #   - Idea of stride < 0 support:
   #     - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
   # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
   #   - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
   #   - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
@@ -368,7 +368,7 @@ class Tensor:
     out = self.sum(axis=axis, keepdim=keepdim)
     return out * (prod(out.shape)/prod(self.shape))
   def std(self, axis=None, keepdim=False, correction=1):
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
     return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
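One clarification on the std context shown above: prod(self.shape)/prod(square_sum.shape) is the number of elements reduced over, so the default correction=1 yields the Bessel-corrected sample standard deviation, i.e. numpy's ddof=1. Illustration in plain numpy:

  import numpy as np
  x = np.arange(6.0)
  manual = np.sqrt(((x - x.mean()) ** 2).sum() / (x.size - 1))  # divide by n - 1
  print(np.isclose(manual, x.std(ddof=1)))  # True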