uniform init to match torch (#5494)

This commit is contained in:
George Hotz 2024-07-15 12:07:44 -07:00 committed by GitHub
parent 338b7590b9
commit aab1e8c6dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 7 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
import math
from typing import Optional, Union, Tuple, cast
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from tinygrad.nn import optim, state, datasets # noqa: F401
@@ -98,16 +98,13 @@ class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  """2D convolution layer.

  Weight has shape (out_channels, in_channels//groups, *kernel_size); weight
  and bias are both drawn from U(-k, k) with
  k = 1 / sqrt(in_channels * prod(kernel_size)), matching torch's default
  Conv2d initialization.
  """
  self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
  self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
  # single init path: the earlier kaiming_uniform weight (and the bound derived
  # from weight.shape via cast) was dead code, immediately overwritten here
  scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
  self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
  self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
def __call__(self, x:Tensor):
  """Apply this layer's 2D convolution (weight/bias) to input tensor `x`."""
  conv_opts = {"padding": self.padding, "stride": self.stride, "dilation": self.dilation, "groups": self.groups}
  return x.conv2d(self.weight, self.bias, **conv_opts)
def initialize_weight(self, out_channels, in_channels, groups):
  """Kaiming-uniform weight of shape (out_channels, in_channels//groups, *kernel_size)."""
  shape = (out_channels, in_channels // groups) + self.kernel_size
  return Tensor.kaiming_uniform(*shape, a=math.sqrt(5))
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
"""
Applies a 1D transposed convolution operator over an input signal composed of several input planes.
@@ -144,15 +141,14 @@ class ConvTranspose2d(Conv2d):
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  """Transposed 2D convolution layer.

  Reuses Conv2d's setup, then re-initializes the weight with the transposed
  channel layout (in_channels, out_channels//groups, *kernel_size), drawn
  from U(-k, k) with k = 1 / sqrt(in_channels * prod(kernel_size)).
  """
  super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
  bound = 1 / math.sqrt(in_channels * prod(self.kernel_size))
  self.weight = Tensor.uniform(in_channels, out_channels // groups, *self.kernel_size, low=-bound, high=bound)
  self.output_padding = output_padding
def __call__(self, x:Tensor):
  """Apply the transposed 2D convolution to input tensor `x`."""
  opts = dict(padding=self.padding, output_padding=self.output_padding, stride=self.stride,
              dilation=self.dilation, groups=self.groups)
  return x.conv_transpose2d(self.weight, self.bias, **opts)
def initialize_weight(self, out_channels, in_channels, groups):
  """Kaiming-uniform weight with transposed channel layout:
  (in_channels, out_channels//groups, *kernel_size)."""
  shape = (in_channels, out_channels // groups) + self.kernel_size
  return Tensor.kaiming_uniform(*shape, a=math.sqrt(5))
class Linear:
"""
Applies a linear transformation to the incoming data.
@@ -170,9 +166,8 @@ class Linear:
```
"""
def __init__(self, in_features, out_features, bias=True):
  """Linear (fully-connected) layer.

  Weight has shape (out_features, in_features); weight and bias are drawn
  from U(-k, k) with k = 1 / sqrt(in_features), matching torch's default
  Linear initialization.
  """
  # single init path: the earlier kaiming_uniform weight (and its stale TODO
  # about matching torch) was dead code, immediately overwritten here
  bound = 1 / math.sqrt(in_features)
  self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
  self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
def __call__(self, x:Tensor):