
129 lines
7.3 KiB

import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, all_int
class BatchNorm2d:
  """Batch normalization over axis 1 of a 4-D input (statistics are reduced over axes 0, 2, 3).

  Training vs. eval is selected by the global ``Tensor.training`` flag (assumed to be tinygrad's
  train/eval switch, as in upstream tinygrad — confirm for this version):
    * training: normalize with the current batch's mean/variance and, when ``track_running_stats``
      is set, fold them into ``running_mean``/``running_var`` with an exponential moving average.
    * eval: normalize with the stored ``running_mean``/``running_var``; no state is modified.

  Args:
    sz: number of channels (size of axis 1 of the input).
    eps: added to the variance before taking the inverse square root.
    affine: if True, create per-channel ``weight`` (ones) and ``bias`` (zeros); otherwise both are None.
    track_running_stats: if True, update the running statistics during training.
    momentum: weight of the current batch in the running-statistics moving average.
  """
  def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
    else: self.weight, self.bias = None, None

    self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

  def __call__(self, x:Tensor):
    if Tensor.training:
      # This requires two full memory accesses to x.
      # There are "online" algorithms that fix this, like Welford's online algorithm:
      # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
      batch_mean = x.mean(axis=(0,2,3))
      y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
      batch_var = (y*y).mean(axis=(0,2,3))
      batch_invstd = batch_var.add(self.eps).pow(-0.5)

      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
        self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
        # batch_var is the biased (divide-by-m) variance; prod(shape)/(prod(shape) - C) == m/(m-1),
        # with m = elements per channel, so the running variance stores the unbiased estimate
        self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach())
        self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean
      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
      batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()

    return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
# TODO: these Conv lines are terrible
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  """Build a 1-D convolution as a ``Conv2d`` whose kernel has a single spatial dimension.

  ``kernel_size`` is wrapped into a 1-tuple; all other arguments are forwarded unchanged.
  """
  single_axis_kernel = (kernel_size,)
  return Conv2d(in_channels, out_channels, single_axis_kernel,
                stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
class Conv2d:
  """2-D convolution layer.

  ``weight`` is created by ``initialize_weight`` (Kaiming-uniform, a=sqrt(5)) with shape
  (out_channels, in_channels//groups, *kernel_size). ``bias``, if enabled, is drawn uniformly from
  [-1/sqrt(fan_in), 1/sqrt(fan_in)], where fan_in is the product of weight.shape[1:].
  Subclasses may override ``initialize_weight`` to change the weight layout.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size)
    else: self.kernel_size = tuple(kernel_size)
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    # must run after self.kernel_size is set: initialize_weight reads it
    self.weight = self.initialize_weight(out_channels, in_channels, groups)
    assert all_int(self.weight.shape), "does not support symbolic shape"
    fan_in = prod(self.weight.shape[1:])
    limit = 1 / math.sqrt(fan_in)
    self.bias = Tensor.uniform(out_channels, low=-limit, high=limit) if bias else None

  def __call__(self, x:Tensor):
    return x.conv2d(self.weight, self.bias, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups)

  def initialize_weight(self, out_channels, in_channels, groups):
    # layout: (output channels, input channels per group, *kernel spatial dims)
    weight_shape = (out_channels, in_channels//groups, *self.kernel_size)
    return Tensor.kaiming_uniform(*weight_shape, a=math.sqrt(5))
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  """Build a 1-D transposed convolution as a ``ConvTranspose2d`` with a single-axis kernel.

  ``kernel_size`` is wrapped into a 1-tuple; all other arguments are forwarded unchanged.
  """
  single_axis_kernel = (kernel_size,)
  return ConvTranspose2d(in_channels, out_channels, single_axis_kernel,
                         stride=stride, padding=padding, output_padding=output_padding,
                         dilation=dilation, groups=groups, bias=bias)
class ConvTranspose2d(Conv2d):
  """2-D transposed convolution.

  Reuses ``Conv2d`` for parameter setup but overrides ``initialize_weight`` so the weight is laid out
  as (in_channels, out_channels//groups, *kernel_size), and adds ``output_padding``.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
    # the base constructor calls self.initialize_weight, which dispatches to the override below
    super().__init__(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.output_padding = output_padding

  def __call__(self, x:Tensor):
    return x.conv_transpose2d(self.weight, self.bias, stride=self.stride, padding=self.padding,
                              output_padding=self.output_padding, dilation=self.dilation, groups=self.groups)

  def initialize_weight(self, out_channels, in_channels, groups):
    # transposed layout: input channels first, then output channels per group
    weight_shape = (in_channels, out_channels//groups, *self.kernel_size)
    return Tensor.kaiming_uniform(*weight_shape, a=math.sqrt(5))
class Linear:
  """Fully connected layer: ``x.linear(weight.transpose(), bias)``.

  ``weight`` has shape (out_features, in_features) and is Kaiming-uniform initialized (a=sqrt(5)).
  ``bias``, if enabled, is drawn uniformly from [-1/sqrt(in_features), 1/sqrt(in_features)].
  """
  def __init__(self, in_features, out_features, bias=True):
    self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
    fan_in = self.weight.shape[1]
    # TODO: remove this once we can represent Tensor with int shape in typing
    assert isinstance(fan_in, int), "does not support symbolic shape"
    limit = 1 / math.sqrt(fan_in)
    self.bias = None if not bias else Tensor.uniform(out_features, low=-limit, high=limit)

  def __call__(self, x:Tensor):
    weight_t = self.weight.transpose()
    return x.linear(weight_t, self.bias)
class GroupNorm:
  """Group normalization followed by an optional per-channel affine transform.

  Axis 1 of the input is split into ``num_groups`` groups. Each (sample, group) slice is normalized
  by ``layernorm``, and the result is then scaled by ``weight`` (ones) and shifted by ``bias`` (zeros)
  per channel when ``affine`` is True.
  """
  def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
    self.weight: Optional[Tensor] = None
    self.bias: Optional[Tensor] = None
    if affine: self.weight, self.bias = Tensor.ones(num_channels), Tensor.zeros(num_channels)

  def __call__(self, x:Tensor):
    # fold each group's channels (plus all trailing dims) into one last axis so layernorm,
    # called with its default axis (assumed to be the last one), yields per-(sample, group) stats
    grouped = x.reshape(x.shape[0], self.num_groups, -1)
    x = grouped.layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is not None and self.bias is not None:
      # per-channel parameters broadcast as (1, C, 1, ..., 1)
      param_shape = (1, -1, *[1] * (len(x.shape)-2))
      return x * self.weight.reshape(*param_shape) + self.bias.reshape(*param_shape)
    return x
class InstanceNorm:
  """Instance normalization followed by an optional per-channel affine transform.

  Each (sample, channel) slice — everything past axis 1 — is normalized by ``layernorm``. The result
  is then scaled by ``weight`` (ones) and shifted by ``bias`` (zeros) per channel when ``affine`` is True.
  """
  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
    self.num_features, self.eps = num_features, eps
    self.weight: Optional[Tensor] = None
    self.bias: Optional[Tensor] = None
    if affine: self.weight, self.bias = Tensor.ones(num_features), Tensor.zeros(num_features)

  def __call__(self, x:Tensor):
    # flatten the per-channel trailing dims into one axis so layernorm's default axis
    # (assumed to be the last one) normalizes each (sample, channel) independently
    per_channel = x.reshape(x.shape[0], self.num_features, -1)
    x = per_channel.layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is not None and self.bias is not None:
      # per-channel parameters broadcast as (1, C, 1, ..., 1)
      param_shape = (1, -1, *[1] * (len(x.shape)-2))
      return x * self.weight.reshape(*param_shape) + self.bias.reshape(*param_shape)
    return x
class LayerNorm:
  """Layer normalization over the trailing ``normalized_shape`` dimensions.

  With ``elementwise_affine`` the normalized output is multiplied by ``weight`` (ones) and shifted by
  ``bias`` (zeros), both of shape ``normalized_shape``; otherwise both are None.
  """
  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
    if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,)
    self.normalized_shape = tuple(normalized_shape)
    # negative axes -1, -2, ..., -len(normalized_shape): the trailing dimensions
    self.axis = tuple(range(-1, -len(self.normalized_shape) - 1, -1))
    self.eps, self.elementwise_affine = eps, elementwise_affine
    if elementwise_affine: self.weight, self.bias = Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)
    else: self.weight, self.bias = None, None

  def __call__(self, x:Tensor):
    n_norm = len(self.normalized_shape)
    assert self.normalized_shape == x.shape[-n_norm:], f"last dimensions of {x.shape} must match {self.normalized_shape}"
    normed = x.layernorm(eps=self.eps, axis=self.axis)
    return normed * self.weight + self.bias if self.elementwise_affine else normed
class LayerNorm2d(LayerNorm):
  """``LayerNorm`` for 4-D input with channels on axis 1 (e.g. NCHW).

  Moves axis 1 to the end, applies ``LayerNorm`` over the trailing ``normalized_shape`` dims, then
  restores the original axis order. With an int ``normalized_shape`` the statistics are therefore
  taken over the channel axis.
  """
  def __call__(self, x):
    channels_last = x.permute(0, 2, 3, 1)   # (N, C, H, W) -> (N, H, W, C)
    normed = super().__call__(channels_last)
    return normed.permute(0, 3, 1, 2)       # (N, H, W, C) -> (N, C, H, W)
class Embedding:
  """Token embedding table of shape (vocab_size, embed_size), glorot-uniform initialized.

  Lookup is done by building a one-hot mask of the ids over the vocabulary and matmul-ing it with
  ``weight``.
  """
  def __init__(self, vocab_size:int, embed_size:int):
    self.vocab_size = vocab_size
    self.weight = Tensor.glorot_uniform(vocab_size, embed_size)

  def __call__(self, idx:Tensor) -> Tensor:
    # 0..vocab_size-1 shaped (1, 1, vocab_size); created lazily on the first call and then reused.
    # The (1, 1, V) shape plus unsqueeze(2) assume idx is 2-D, presumably (batch, seq) — TODO confirm.
    if not hasattr(self, 'vocab_counter'):
      counter = Tensor.arange(self.vocab_size, requires_grad=False)
      self.vocab_counter = counter.reshape(1, 1, self.vocab_size)
    one_hot = (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size)
    return one_hot @ self.weight