fix various examples (#4691)

* fix examples that used ax1 and ax2 for transpose

* fix that

* update those
chenyu 2024-05-22 20:43:21 -04:00 committed by GitHub
commit 792a494eb8 (parent 30b07f3c5d)
5 changed files with 4 additions and 5 deletions
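
The ax1/ax2 changes in the hunks below all chase the same rename: tinygrad's Tensor.transpose takes dim0/dim1 (or plain positional dims), not the old ax1/ax2 keyword names. A minimal sketch of the difference, assuming a recent tinygrad checkout:

from tinygrad.tensor import Tensor

t = Tensor.arange(6).reshape(2, 3)
print(t.transpose(0, 1).shape)              # (3, 2) -- positional form
print(t.transpose(dim0=-2, dim1=-1).shape)  # (3, 2) -- keyword form
# t.transpose(ax1=0, ax2=1)                 # old keyword names: TypeError on current tinygrad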


@ -2,6 +2,7 @@
import os, sys
os.environ["CLANG"] = '1'
os.environ["JIT"] = '2'
import numpy as np
import subprocess
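
For context on the hunk above: tinygrad picks up flags like CLANG and JIT from the environment, which is why the example sets them before importing anything from tinygrad. A minimal sketch of that ordering (only the set-before-import requirement is the point; the meaning of JIT=2 is not asserted here):

import os
os.environ["CLANG"] = "1"               # select the Clang CPU backend before the first tinygrad import
from tinygrad.tensor import Tensor

print(Tensor([1, 2, 3]).sum().item())   # 6, run on the CLANG backend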


@ -1,5 +1,3 @@
# to start thinking about the $2,000 norm fusion bounty
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters
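
For context, the imports above set up the conv/batchnorm pair that the norm-fusion bounty targets. A hedged sketch of that pattern (channel counts and shapes are illustrative assumptions):

from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.state import get_parameters

conv, bn = Conv2d(3, 16, 3, padding=1), BatchNorm2d(16)
x = Tensor.randn(1, 3, 32, 32)
y = bn(conv(x))                          # the Conv2d -> BatchNorm2d pair a fuser would fold together
print(y.shape)                           # (1, 16, 32, 32)
print(len(get_parameters([conv, bn])))   # the tensors such a fuser would have to rewrite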


@ -223,7 +223,7 @@ class ConvFeatureExtractionModel:
return conv
assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
if is_layer_norm:
return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, ax1=-2, ax2=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, ax1=-2, ax2=-1)], Tensor.gelu]
return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
elif is_group_norm and mode == "default":
return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
elif is_group_norm and mode == "group_norm_masked":
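
The one-line change above only renames the transpose keywords; the surrounding pattern is worth spelling out. LayerNorm normalizes over the last axis, so the layer-norm branch swaps the channel axis to the end, normalizes, and swaps it back. A minimal standalone sketch of that sandwich, with assumed shapes:

from functools import partial
from tinygrad.tensor import Tensor
from tinygrad import nn

dim = 512
x = Tensor.randn(4, dim, 100)                       # assumed (batch, channels, time) layout
swap = partial(Tensor.transpose, dim0=-2, dim1=-1)  # the fixed keyword names
norm = nn.LayerNorm(dim, elementwise_affine=True)
y = swap(norm(swap(x)))                             # normalize over channels, restore layout
print(y.shape)                                      # (4, 512, 100)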


@ -489,14 +489,14 @@ def split(tensor, split_sizes, dim=0): # if split_sizes is an integer, convert
start += size
return [tensor.slice(s) for s in slices]
def gather(x, indices, axis):
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(ax1=axis, ax2=0)
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(0, axis)
permute_args = list(range(x.ndim))
permute_args[0], permute_args[axis] = permute_args[axis], permute_args[0]
permute_args.append(permute_args.pop(0))
x = x.permute(*permute_args)
reshape_arg = [1] * x.ndim + [x.shape[-1]]
return ((indices.unsqueeze(indices.ndim).expand(*indices.shape, x.shape[-1]) ==
Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(ax1=0, ax2=axis)
Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(0, axis)
def norm_except_dim(v, dim):
if dim == -1: return np.linalg.norm(v)
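
The gather above relies on a one-hot trick: comparing the (transposed) indices against an arange over the gathered axis builds a mask, and multiply-plus-sum then selects the elements without any scatter kernel; the final transpose(0, axis), now positional instead of ax1/ax2, puts the gathered axis back in place. A hedged 1-D sketch of the same trick, checked against numpy:

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor([10., 20., 30., 40.])
idx = Tensor([3, 0, 2])
onehot = idx.unsqueeze(1).expand(3, 4) == Tensor.arange(4).reshape(1, 4).expand(3, 4)
out = (onehot * x).sum(axis=1)
print(out.numpy())                                        # [40. 10. 30.]
assert (out.numpy() == np.array([10., 20., 30., 40.])[[3, 0, 2]]).all()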