cleanup onnx, pass one more reshape test and remove some casts (#2806)

chenyu 2023-12-16 20:40:43 -05:00 committed by GitHub
parent baa94d6142
commit 157c0be509
4 changed files with 51 additions and 48 deletions

View File

@@ -170,7 +170,7 @@ def get_run_onnx(onnx_model: ModelProto):
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
starts, ends = inp[1:3]
axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
axes = safe_numpy(Tensor.arange(inp[0].ndim) if len(inp) <= 3 else inp[3]).tolist()
steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
arg = [(0,x,1) for x in inp[0].shape]
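
Note: the default `arg` built on the last line above is a full, step-1 slice of every dimension, which starts/ends/steps then overwrite per axis. A quick illustration in plain Python:

shape = (3, 4, 5)
arg = [(0, x, 1) for x in shape]  # (start, end, step) per dim: keep everything
print(arg)  # [(0, 3, 1), (0, 4, 1), (0, 5, 1)]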

View File

@@ -1,15 +1,15 @@
import functools, io, math, os
import functools, io, math
from typing import Union, Tuple, Optional, List, Any
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from tinygrad.nn import Embedding
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu"}
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
"Elu", "Celu"}
# **************** Free Ops ****************
@@ -17,7 +17,7 @@ def Identity(input: Tensor): return input
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor() # TODO: this has dtype issues
def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype) # TODO: this has dtype issues
def Pow(input: Tensor, other: Tensor): return input.float() ** other.float()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
@@ -41,7 +41,6 @@ def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
def Celu(x:Tensor, alpha=1.0): return x.celu(alpha)
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
@@ -55,7 +54,9 @@ def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)
# NOTE ReduceProd would require a new llop
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
def _axes(axes, noop_with_empty_axes):
if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)): return [int(x) for x in safe_numpy(axes)]
return [] if noop_with_empty_axes else None
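
The refactored _axes covers three cases: explicit axes reduce those dims, an empty axes tensor with noop_with_empty_axes set reduces nothing, and no axes at all reduces everything. A sketch of the three branches, assuming this module's namespace:

from tinygrad.tensor import Tensor
assert _axes(Tensor([1, 2]), noop_with_empty_axes=0) == [1, 2]  # reduce dims 1 and 2
assert _axes(Tensor([]), noop_with_empty_axes=1) == []          # no-op: reduce over nothing
assert _axes(None, noop_with_empty_axes=0) is None              # reduce over all dims
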
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
@@ -68,15 +69,16 @@ def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0)
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start).item()), stop=int(safe_numpy(limit).item()), step=int(safe_numpy(delta).item())).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
def Range(start: Tensor, limit, delta): return Tensor.arange(start=safe_numpy(start).item(), stop=safe_numpy(limit).item(), step=safe_numpy(delta).item())
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
def Reshape(data: Tensor, shape: Tensor, allowzero=0):
return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
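
This rewrite is what makes the extra reshape test pass: with the default allowzero=0 a 0 in the target shape copies the input's dimension at that position, while allowzero=1 keeps the literal zero (only meaningful for empty tensors). A sketch, assuming this module's namespace:

import numpy as np
from tinygrad.tensor import Tensor
data = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float32))
print(Reshape(data, Tensor([0, -1])).shape)  # (2, 12): the 0 copies data.shape[0]
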
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, 0).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, 1).cast(dtypes.bool)
@@ -107,13 +109,13 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
def Squeeze(data: Tensor, axes):
if isinstance(axes, Tensor): axes = safe_numpy(axes)
axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes]
return input.reshape([s for i,s in enumerate(input.shape) if i not in axes])
axes = [int(x) + data.ndim if x < 0 else int(x) for x in axes]
return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
new_shape = [1] * (len(data.shape) + len(axes))
axes = [int(x) + data.ndim if x < 0 else int(x) for x in safe_numpy(axes)]
new_shape = [1] * (data.ndim + len(axes))
ptr = iter(data.shape)
for i in range(len(new_shape)):
if i not in axes:
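
A shape walkthrough for Unsqueeze (sketch, assuming the full module; the fill loop is truncated in the hunk above): axes mark the positions that stay 1 in new_shape, and the remaining slots are filled from data.shape in order.

import numpy as np
from tinygrad.tensor import Tensor
data = Tensor(np.zeros((3, 4), dtype=np.float32))
print(Unsqueeze(data, Tensor([0, 2])).shape)  # (1, 3, 1, 4)
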
@@ -155,7 +157,7 @@ def Expand(input: Tensor, shape):
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
# works with tensors where ndim != 4
@@ -174,7 +176,7 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).pow(-0.5)
current_invstd = current_var.add(epsilon).rsqrt()
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
@@ -187,14 +189,14 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
axis = tuple(range(2, len(x.shape)))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
assert stash_type == 1, "only float32 is supported"
axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
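
The three normalization ops above all swap hand-rolled inverse square roots (.pow(-0.5), .sqrt().reciprocal()) for .rsqrt(). The spellings agree up to float error:

from tinygrad.tensor import Tensor
t = Tensor([1.0, 4.0, 9.0])
print(t.rsqrt().numpy())              # [1.  0.5  0.3333...]
print(t.pow(-0.5).numpy())            # same
print(t.sqrt().reciprocal().numpy())  # same
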
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
@@ -380,7 +382,8 @@ def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore
elif reduction == "sum": loss = loss.sum()
return loss, y
def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
def ArrayFeatureExtractor(input: Tensor, indices: Tensor):
return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)]))
def Gather(input: Tensor, indices: Tensor, axis=0):
if indices.numel() < 9: # NOTE: fewer kernels for smaller indices, but kernel count increases with the size of indices
input_sh = list(input.shape)
@@ -393,17 +396,17 @@ def Gather(input: Tensor, indices: Tensor, axis=0):
return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))
def GatherElements(input: Tensor, indices: Tensor, axis):
indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
indices = (indices < 0).where(input.shape[axis], 0) + indices
return input.gather(indices, axis)
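
The new first line replaces the sign/neg/relu dance with a single where that wraps negative indices into the valid range. Checked in isolation (sketch):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
indices, axis_size = Tensor([0, -1, -2], dtype=dtypes.int32), 4
print(((indices < 0).where(axis_size, 0) + indices).numpy())  # [0 3 2]
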
def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
b = (b >= 0).where(b+n, b-n)
if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
elif equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
elif equidistant_case == "round_to_even":
def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
@@ -413,7 +416,9 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
def Round(X:Tensor): return _round(X, 0.5, "round_to_even")
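
ONNX Round is round-half-to-even ("banker's rounding"), which the round_to_even branch above implements: ties go to the nearest even integer. Assuming this module's namespace:

from tinygrad.tensor import Tensor
print(Round(Tensor([0.5, 1.5, 2.5, 3.5])).numpy())  # [0. 2. 2. 4.]
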
# TODO clean this up, it's taking the longest in CI
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel',
cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch',
mode='nearest', nearest_mode='round_prefer_floor'):
def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
@@ -421,22 +426,22 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
elif nearest_mode == "floor": ret = x_resized.floor()
elif nearest_mode == "ceil": ret = x_resized.ceil()
return ret.clip(0, x_len-1)
def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
def _coordinate_transformation(x_out, y_out, output_shape, scales_, roi=None):
if coordinate_transformation_mode == "half_pixel":
x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy.
y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
x_out = (x_out + 0.5)/Tensor(scales_[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy.
y_out = (y_out + 0.5)/Tensor(scales_[-2]) - 0.5
elif coordinate_transformation_mode == "align_corners":
x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
elif coordinate_transformation_mode == "asymmetric":
x_out = x_out/scales_lol[-1]
y_out = y_out/scales_lol[-2]
x_out = x_out/scales_[-1]
y_out = y_out/scales_[-2]
elif coordinate_transformation_mode == "half_pixel_symmetric":
x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_[-1] - 0.5
y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_[-2] - 0.5
elif coordinate_transformation_mode == "pytorch_half_pixel":
x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
x_out = (x_out + 0.5)/scales_[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
y_out = (y_out + 0.5)/scales_[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
elif coordinate_transformation_mode == "tf_crop_and_resize":
x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
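
For intuition, the default half_pixel mode maps output pixel centers back to input coordinates as x_in = (x_out + 0.5)/scale - 0.5. For a width 2 → 4 upsample (scale 2):

import numpy as np
x_out = np.arange(4)
print((x_out + 0.5) / 2 - 0.5)  # [-0.25  0.25  0.75  1.25], then rounded and clipped by _nearest_mode
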
@@ -476,11 +481,11 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
if mode == "nearest":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1])
return _nearest_gather(X, x_out, y_out)
@@ -510,7 +515,7 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
if not axes: axes = list(range(input.ndim))
shrink_arg = [(0,i) for i in input.shape]
pad_arg = [(0,0) for _ in range(input.ndim)]
pad_arg = [(0,0)] * input.ndim
shape = safe_numpy(shape).tolist()
for s, x in zip(shape, axes):
if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2)
@@ -519,23 +524,22 @@ def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = int(safe_numpy(depth).item())
indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape)
indices, rank = (indices < 0).where(indices+depth, indices), indices.ndim
if axis < 0: axis += rank + 1
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
return cond.where(values[1], values[0]).cast(values.dtype)
def Erf(x: Tensor):
sign = x.sign()
x = x.abs()
t = 1.0 / (1.0 + 0.3275911 * x)
t = 1.0 / (1.0 + 0.3275911 * x.abs())
term1 = 0.254829592 * t
term2 = -0.284496736 * t ** 2
term3 = 1.421413741 * t ** 3
term4 = -1.453152027 * t ** 4
term5 = 1.061405429 * t ** 5
y = (term1 + term2 + term3 + term4 + term5)
return sign * (1.0 - y * Tensor.exp(-x * x))
z = 1.0 - y * Tensor.exp(-x * x)
return (x > 0).where(z, -z)
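
This is the Abramowitz & Stegun 7.1.26 polynomial approximation of erf (max error about 1.5e-7); the rewrite folds the sign handling into a single where on the original x. A quick check against the exact value, assuming this module's namespace:

import math
from tinygrad.tensor import Tensor
print(Erf(Tensor([1.0])).numpy())  # ≈ [0.8427008]
print(math.erf(1.0))               # 0.8427007929497149
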
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
@@ -570,7 +574,7 @@ def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1):
axis = axis + x.ndim if axis < 0 else axis
x = x.cast(dtypes.float)
if x_zero_point.__class__ is Tensor: x_zero_point.cast(dtypes.float)
if isinstance(x_zero_point, Tensor): x_zero_point = x_zero_point.cast(dtypes.float)  # assign: a bare .cast() is a no-op
x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
return ((x - x_zer) * x_sc).cast(x_scale.dtype)
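
DequantizeLinear computes y = (x - x_zero_point) * x_scale in float, with scale and zero point reshaped to broadcast along axis for per-channel quantization. A per-tensor example (sketch, assuming this module's namespace):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
x = Tensor([0, 128, 255], dtype=dtypes.uint8)
print(DequantizeLinear(x, Tensor(0.5), Tensor(128, dtype=dtypes.uint8)).numpy())  # [-64.  0.  63.5]
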
@@ -686,7 +690,7 @@ def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Opti
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape

View File

@@ -149,7 +149,6 @@ backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to imple
# rest of the failing tests
backend_test.exclude('test_regex_*') # does not support string Tensors
backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0, also allowzero
backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip

View File

@@ -515,11 +515,11 @@ class Tensor:
def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) # noqa: E501
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
return prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) # noqa: E501
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
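
The dropped dtype argument relies on arange's default integer type. The trick itself: multiply the max mask by a reversed arange, so that after a max reduce, ties resolve to the first occurrence, matching numpy's argmax.

from tinygrad.tensor import Tensor
t = Tensor([3.0, 1.0, 3.0])
# mask [1, 0, 1] * reversed arange [2, 1, 0] → [2, 0, 0]; max is 2; 3-2-1 = 0
print(t.argmax().numpy())  # 0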