mirror of https://github.com/commaai/tinygrad.git
uniform init to match torch (#5494)
This commit is contained in:
parent 338b7590b9
commit aab1e8c6dc
@@ -1,5 +1,5 @@
 import math
-from typing import Optional, Union, Tuple, cast
+from typing import Optional, Union, Tuple
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
 from tinygrad.nn import optim, state, datasets  # noqa: F401
@@ -98,16 +98,13 @@ class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight = self.initialize_weight(out_channels, in_channels, groups)
-    bound = 1 / math.sqrt(cast(int, prod(self.weight.shape[1:])))  # weight shape is always ints but mypy cannot tell
-    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
+    self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
 
   def __call__(self, x:Tensor):
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
   """
   Applies a 1D transposed convolution operator over an input signal composed of several input planes.
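Editor's note on the Conv2d hunk (not part of the commit): with a = sqrt(5), the kaiming_uniform bound gain*sqrt(3/fan_in), gain = sqrt(2/(1+a^2)), collapses to 1/sqrt(fan_in). So for groups=1 the new explicit uniform bound draws from the same range the old kaiming init did; the change mainly spells the bound out and drops the cast/initialize_weight indirection. A minimal sketch of that arithmetic, assuming the usual kaiming_uniform formula:

```python
import math

# Editor's sketch (not from the commit): kaiming_uniform with a=sqrt(5) reduces to
# uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), assuming bound = gain * sqrt(3 / fan_in)
# with gain = sqrt(2 / (1 + a**2)).
def kaiming_uniform_bound(fan_in: int, a: float = math.sqrt(5)) -> float:
  gain = math.sqrt(2.0 / (1 + a ** 2))   # leaky_relu gain used by kaiming init
  return gain * math.sqrt(3.0 / fan_in)  # half-width of the uniform range

fan_in = 16 * 3 * 3  # e.g. in_channels=16, kernel_size=(3, 3), groups=1
print(kaiming_uniform_bound(fan_in))  # ~0.0833
print(1 / math.sqrt(fan_in))          # ~0.0833 -> same bound, now written explicitly
```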
@@ -144,15 +141,14 @@ class ConvTranspose2d(Conv2d):
   """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
     self.output_padding = output_padding
 
   def __call__(self, x:Tensor):
     return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
                               dilation=self.dilation, groups=self.groups)
 
-  def initialize_weight(self, out_channels, in_channels, groups):
-    return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
-
 class Linear:
   """
   Applies a linear transformation to the incoming data.
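For the ConvTranspose2d hunk, a small illustrative check (editor's sketch, assuming the `from tinygrad import nn` import path and the API shown in this diff): the weight keeps the transposed layout (in_channels, out_channels//groups, kH, kW), and both the re-initialized weight and the bias from Conv2d.__init__ are drawn from uniform(-scale, scale) with scale = 1/sqrt(in_channels * kH * kW).

```python
from tinygrad import nn

# Editor's sketch (not from the commit): sanity-check the new ConvTranspose2d init.
ct = nn.ConvTranspose2d(8, 16, kernel_size=3)          # in_channels=8, out_channels=16
print(ct.weight.shape)                                 # (8, 16, 3, 3): transposed layout
scale = 1 / (8 * 3 * 3) ** 0.5                         # 1/sqrt(in_channels * kH * kW)
assert float(ct.weight.abs().max().numpy()) <= scale   # weight sampled in [-scale, scale]
assert float(ct.bias.abs().max().numpy()) <= scale     # bias uses the same bound
```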
@@ -170,9 +166,8 @@ class Linear:
   ```
   """
   def __init__(self, in_features, out_features, bias=True):
-    # TODO: is this init good? torch inits to uniform(-1/sqrt(in_features), 1/sqrt(in_features))
-    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
     bound = 1 / math.sqrt(in_features)
+    self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
 
   def __call__(self, x:Tensor):
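The Linear hunk resolves the removed TODO: torch's default Linear init effectively draws weight and bias from uniform(-1/sqrt(in_features), 1/sqrt(in_features)), and the new code uses that bound directly. A short illustrative check (editor's sketch, same assumptions as above):

```python
import math
from tinygrad import nn

# Editor's sketch (not from the commit): Linear now samples in torch's default range.
lin = nn.Linear(128, 64)
bound = 1 / math.sqrt(128)                              # 1/sqrt(in_features)
print(lin.weight.shape, lin.bias.shape)                 # (64, 128) (64,)
assert float(lin.weight.abs().max().numpy()) <= bound   # weight in [-bound, bound]
assert float(lin.bias.abs().max().numpy()) <= bound     # bias in [-bound, bound]
```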