mirror of https://github.com/commaai/tinygrad.git
cleanup onnx, pass one more reshape test and remove some casts (#2806)
parent baa94d6142
commit 157c0be509
@@ -170,7 +170,7 @@ def get_run_onnx(onnx_model: ModelProto):
    axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
  else:
    starts, ends = inp[1:3]
-    axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
+    axes = safe_numpy(Tensor.arange(inp[0].ndim) if len(inp) <= 3 else inp[3]).tolist()
    steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
    starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
  arg = [(0,x,1) for x in inp[0].shape]
@@ -1,15 +1,15 @@
-import functools, io, math, os
+import functools, io, math
from typing import Union, Tuple, Optional, List, Any
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from tinygrad.nn import Embedding
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import numpy as np

tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
-  "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu"}
+  "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
+  "Elu", "Celu"}

# **************** Free Ops ****************
@@ -17,7 +17,7 @@ def Identity(input: Tensor): return input
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor() # TODO: this has dtype issues
-def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype) # TODO: this has dtype issues
+def Pow(input: Tensor, other: Tensor): return input.float() ** other.float()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
@@ -41,7 +41,6 @@ def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=

def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
-def Celu(x:Tensor, alpha=1.0): return x.celu(alpha)
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
@@ -55,7 +54,9 @@ def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)

# NOTE ReduceProd would require a new llop
-def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
+def _axes(axes, noop_with_empty_axes):
+  if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)): return [int(x) for x in safe_numpy(axes)]
+  return [] if noop_with_empty_axes else None
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
@@ -68,15 +69,16 @@ def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0)

def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
-def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
-def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
+def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
+def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])

def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
-def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start).item()), stop=int(safe_numpy(limit).item()), step=int(safe_numpy(delta).item())).cast(dtype=start.dtype)
-def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
+def Range(start: Tensor, limit, delta): return Tensor.arange(start=safe_numpy(start).item(), stop=safe_numpy(limit).item(), step=safe_numpy(delta).item())
+def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
-def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
+def Reshape(data: Tensor, shape: Tensor, allowzero=0):
+  return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, 0).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, 1).cast(dtypes.bool)
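For reference, the Reshape above follows the ONNX spec: a 0 in the target shape copies the corresponding input dimension unless allowzero=1, in which case 0 means a genuinely empty dimension. A minimal NumPy sketch of that rule (the helper name is illustrative, not part of the codebase):

import numpy as np

def onnx_reshape_shape(data_shape, shape, allowzero=0):
  # 0 copies the matching input dim unless allowzero is set, in which case it stays 0
  return [int(s) if s != 0 else (0 if allowzero else data_shape[i]) for i, s in enumerate(shape)]

assert onnx_reshape_shape((2, 3, 4), [0, -1]) == [2, -1]               # 0 keeps the input's first dim
assert onnx_reshape_shape((2, 3, 4), [0, -1], allowzero=1) == [0, -1]  # 0 is a real zero-sized dim
x = np.arange(24).reshape(2, 3, 4)
assert x.reshape(onnx_reshape_shape(x.shape, [0, -1])).shape == (2, 12)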
@@ -107,13 +109,13 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
  return x.triu(k) if upper else x.tril(k)

-def Squeeze(input: Tensor, axes):
+def Squeeze(data: Tensor, axes):
  if isinstance(axes, Tensor): axes = safe_numpy(axes)
-  axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes]
-  return input.reshape([s for i,s in enumerate(input.shape) if i not in axes])
+  axes = [int(x) + data.ndim if x < 0 else int(x) for x in axes]
+  return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
-  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
-  new_shape = [1] * (len(data.shape) + len(axes))
+  axes = [int(x) + data.ndim if x < 0 else int(x) for x in safe_numpy(axes)]
+  new_shape = [1] * (data.ndim + len(axes))
  ptr = iter(data.shape)
  for i in range(len(new_shape)):
    if i not in axes:
@@ -155,7 +157,7 @@ def Expand(input: Tensor, shape):

def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
-  if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
+  if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

# works with Tensors.ndim != 4
@@ -174,7 +176,7 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
-    current_invstd = current_var.add(epsilon).pow(-0.5)
+    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
@@ -187,14 +189,14 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
  axis = tuple(range(2, len(x.shape)))
  mean = x.mean(axis=axis, keepdim=True)
-  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
+  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  assert stash_type == 1, "only float32 is supported"
-  axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
+  axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
-  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
+  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()

def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
@@ -380,7 +382,8 @@ def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore
  elif reduction == "sum": loss = loss.sum()
  return loss, y

-def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
+def ArrayFeatureExtractor(input: Tensor, indices: Tensor):
+  return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
def Gather(input: Tensor, indices: Tensor, axis=0):
  if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
    input_sh = list(input.shape)
@@ -393,17 +396,17 @@ def Gather(input: Tensor, indices: Tensor, axis=0):
  return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))

def GatherElements(input: Tensor, indices: Tensor, axis):
-  indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
+  indices = (indices < 0).where(input.shape[axis], 0) + indices
  return input.gather(indices, axis)

def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
-  def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
  assert n <= 1, f"n:{n} shouldn't be larger than 1"
  b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
  b = (b >= 0).where(b+n, b-n)
  if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
  elif equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
  elif equidistant_case == "round_to_even":
+    def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
    x_ceil_fraction = x.ceil()/2
    cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
    x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
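GatherElements picks one input element per output position from an index tensor along axis, and ONNX allows negative indices, so they are wrapped by adding the dimension size before the gather. A small NumPy sketch of the same wrapping (values are illustrative):

import numpy as np

x = np.array([[1, 2], [3, 4]])
idx = np.array([[0, -1], [-2, 0]])
idx = np.where(idx < 0, idx + x.shape[1], idx)  # wrap negatives, like (indices < 0).where(input.shape[axis], 0) + indices
print(np.take_along_axis(x, idx, axis=1))       # [[1 2] [3 3]]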
@@ -413,7 +416,9 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")

# TODO clean this up, it's taking the longest in CI
-def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
+def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel',
+           cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch',
+           mode='nearest', nearest_mode='round_prefer_floor'):
  def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
  def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
    if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
@@ -421,22 +426,22 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
    elif nearest_mode == "floor": ret = x_resized.floor()
    elif nearest_mode == "ceil": ret = x_resized.ceil()
    return ret.clip(0, x_len-1)
-  def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
+  def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
    if coordinate_transformation_mode == "half_pixel":
-      x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy.
-      y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
+      x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy.
+      y_out = (y_out + 0.5)/Tensor(scales_[-2]) - 0.5
    elif coordinate_transformation_mode == "align_corners":
      x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
      y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
    elif coordinate_transformation_mode == "asymmetric":
-      x_out = x_out/scales_lol[-1]
-      y_out = y_out/scales_lol[-2]
+      x_out = x_out/scales_[-1]
+      y_out = y_out/scales_[-2]
    elif coordinate_transformation_mode == "half_pixel_symmetric":
-      x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
-      y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
+      x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_[-1] - 0.5
+      y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_[-2] - 0.5
    elif coordinate_transformation_mode == "pytorch_half_pixel":
-      x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
-      y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
+      x_out = (x_out + 0.5)/scales_[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
+      y_out = (y_out + 0.5)/scales_[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
    elif coordinate_transformation_mode == "tf_crop_and_resize":
      x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
      y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
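The coordinate_transformation_mode branches above map each output pixel index back to a (possibly fractional) input coordinate, following the ONNX Resize spec. A small NumPy sketch of the two most common modes along one axis (names and values are illustrative):

import numpy as np

def to_input_coords(out_len, in_len, mode="half_pixel"):
  x_out = np.arange(out_len, dtype=np.float32)
  scale = out_len / in_len
  if mode == "half_pixel": return (x_out + 0.5) / scale - 0.5
  if mode == "align_corners": return x_out * (in_len - 1) / (out_len - 1)
  raise NotImplementedError(mode)

print(to_input_coords(8, 4, "half_pixel"))     # [-0.25 0.25 ... 3.25], clipped/rounded later by nearest_mode
print(to_input_coords(8, 4, "align_corners"))  # [0. ... 3.], endpoints map exactly onto endpoints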
@@ -476,11 +481,11 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
    sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
  output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
  output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
-  scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
+  scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
  x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
  y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
  if mode == "nearest":
-    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
+    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
    x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
    y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1])
    return _nearest_gather(X, x_out, y_out)
@@ -510,7 +515,7 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
  if not axes: axes = list(range(input.ndim))
  shrink_arg = [(0,i) for i in input.shape]
-  pad_arg = [(0,0) for _ in range(input.ndim)]
+  pad_arg = [(0,0)] * input.ndim
  shape = safe_numpy(shape).tolist()
  for s, x in zip(shape, axes):
    if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2)
@@ -519,23 +524,22 @@ def CenterCropPad(input: Tensor, shape: Tensor, axes=None):

def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
  depth = int(safe_numpy(depth).item())
-  indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape)
+  indices, rank = (indices < 0).where(indices+depth, indices), indices.ndim
  if axis < 0: axis += rank + 1
  ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
  cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
  return cond.where(values[1], values[0]).cast(values.dtype)

def Erf(x: Tensor):
-  sign = x.sign()
-  x = x.abs()
-  t = 1.0 / (1.0 + 0.3275911 * x)
+  t = 1.0 / (1.0 + 0.3275911 * x.abs())
  term1 = 0.254829592 * t
  term2 = -0.284496736 * t ** 2
  term3 = 1.421413741 * t ** 3
  term4 = -1.453152027 * t ** 4
  term5 = 1.061405429 * t ** 5
  y = (term1 + term2 + term3 + term4 + term5)
-  return sign * (1.0 - y * Tensor.exp(-x * x))
+  z = 1.0 - y * Tensor.exp(-x * x)
+  return (x > 0).where(z, -z)

def Compress(inp: Tensor, condition: Tensor, axis=None):
  if axis is None:
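The Erf above is the Abramowitz and Stegun polynomial approximation (formula 7.1.26, maximum absolute error around 1.5e-7); since erf is odd, the sign of the result can be restored with a where on the original x instead of an explicit sign() factor. A quick NumPy check against math.erf, as a sketch:

import math
import numpy as np

def erf_approx(x: np.ndarray) -> np.ndarray:
  # Abramowitz & Stegun 7.1.26 on |x|, then restore the sign since erf(-x) = -erf(x)
  t = 1.0 / (1.0 + 0.3275911 * np.abs(x))
  y = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))))
  z = 1.0 - y * np.exp(-x * x)
  return np.where(x > 0, z, -z)

xs = np.linspace(-3, 3, 13)
assert np.allclose(erf_approx(xs), [math.erf(v) for v in xs], atol=1e-6)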
@@ -570,7 +574,7 @@ def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1):
  axis = axis + x.ndim if axis < 0 else axis
  x = x.cast(dtypes.float)
-  if x_zero_point.__class__ is Tensor: x_zero_point.cast(dtypes.float)
+  if isinstance(x_zero_point, Tensor): x_zero_point.cast(dtypes.float)
  x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
  x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
  return ((x - x_zer) * x_sc).cast(x_scale.dtype)
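DequantizeLinear computes (x - x_zero_point) * x_scale; when scale and zero point are per-channel 1-D tensors, they are reshaped with padding 1s so they broadcast along axis, which is what the [1]*axis reshape above does. A NumPy sketch of the per-axis case (names and values are illustrative):

import numpy as np

def dequantize_linear(x, x_scale, x_zero_point=0, axis=1):
  axis = axis + x.ndim if axis < 0 else axis
  x = x.astype(np.float32)
  if np.ndim(x_scale) > 0:  # per-channel: pad with 1s so scale/zero broadcast along `axis`
    bshape = [1]*axis + list(np.shape(x_scale)) + [1]*(x.ndim - axis - np.ndim(x_scale))
    x_scale = np.reshape(x_scale, bshape)
    if np.ndim(x_zero_point) > 0: x_zero_point = np.reshape(x_zero_point, bshape)
  return (x - x_zero_point) * x_scale

x = np.array([[0, 128, 255], [0, 128, 255]], dtype=np.uint8)
scale, zero = np.array([0.1, 0.01], dtype=np.float32), np.array([128, 0], dtype=np.uint8)
print(dequantize_linear(x, scale, zero, axis=0))  # [[-12.8 0. 12.7] [0. 1.28 2.55]]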
@@ -686,7 +690,7 @@ def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Opti
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
-    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
+    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
    return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
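The causal mask above is a lower-triangular boolean matrix sliced so each of the query_length new positions can only attend to keys at or before its own position among key_length total keys; masked-out scores become -inf before the softmax. A NumPy sketch of that slicing (shapes are illustrative):

import numpy as np

query_length, key_length = 2, 4  # e.g. 2 new tokens attending over 4 cached keys
cdim = max(query_length, key_length) + 1
mask = np.tril(np.ones((cdim, cdim), dtype=bool))[key_length - query_length:key_length, :key_length]
scores = np.zeros((query_length, key_length))
print(np.where(mask, scores, -np.inf))
# [[  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]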
@@ -149,7 +149,6 @@ backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to imple

# rest of the failing tests
backend_test.exclude('test_regex_*') # does not support string Tensors
-backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0, also allowzero
backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip
@@ -515,11 +515,11 @@ class Tensor:

  def argmax(self, axis=None, keepdim=False):
    if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) # noqa: E501
+      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
      return prod(self.shape) - idx.max() - 1
    axis = axis + len(self.shape) if axis < 0 else axis
    m = self == self.max(axis=axis, keepdim=True)
-    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) # noqa: E501
+    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
    return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
  def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
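The argmax shown above works by masking the positions equal to the max and multiplying by a reversed arange, so the first occurrence carries the largest value; subtracting that max from the length recovers a forward index. A NumPy sketch of the flattened (axis=None) case, assuming ties should resolve to the first occurrence as in the code:

import numpy as np

def argmax_via_reversed_arange(x: np.ndarray) -> int:
  flat = x.reshape(-1)
  mask = (flat == flat.max())                    # 1 at every position holding the max
  idx = mask * np.arange(flat.size - 1, -1, -1)  # first occurrence gets the largest weight
  return int(flat.size - idx.max() - 1)          # convert back to a forward index

x = np.array([[3, 7, 7], [1, 7, 0]])
assert argmax_via_reversed_arange(x) == np.argmax(x) == 1  # first 7 sits at flat index 1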