mirror of https://github.com/commaai/tinygrad.git
fix various examples (#4691)
* fix examples that used ax1 and ax2 for transpose * fix that * update those
This commit is contained in:
parent
30b07f3c5d
commit
792a494eb8
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os, sys
|
||||
os.environ["CLANG"] = '1'
|
||||
os.environ["JIT"] = '2'
|
||||
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
# to start thinking about the $2,000 norm fusion bounty
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
|
|
@ -223,7 +223,7 @@ class ConvFeatureExtractionModel:
|
|||
return conv
|
||||
assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
|
||||
if is_layer_norm:
|
||||
return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, ax1=-2, ax2=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, ax1=-2, ax2=-1)], Tensor.gelu]
|
||||
return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
|
||||
elif is_group_norm and mode == "default":
|
||||
return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
|
||||
elif is_group_norm and mode == "group_norm_masked":
|
||||
|
|
|
@ -489,14 +489,14 @@ def split(tensor, split_sizes, dim=0): # if split_sizes is an integer, convert
|
|||
start += size
|
||||
return [tensor.slice(s) for s in slices]
|
||||
def gather(x, indices, axis):
|
||||
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(ax1=axis, ax2=0)
|
||||
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(0, axis)
|
||||
permute_args = list(range(x.ndim))
|
||||
permute_args[0], permute_args[axis] = permute_args[axis], permute_args[0]
|
||||
permute_args.append(permute_args.pop(0))
|
||||
x = x.permute(*permute_args)
|
||||
reshape_arg = [1] * x.ndim + [x.shape[-1]]
|
||||
return ((indices.unsqueeze(indices.ndim).expand(*indices.shape, x.shape[-1]) ==
|
||||
Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(ax1=0, ax2=axis)
|
||||
Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(0, axis)
|
||||
|
||||
def norm_except_dim(v, dim):
|
||||
if dim == -1: return np.linalg.norm(v)
|
||||
|
|
Loading…
Reference in New Issue