removed redundant dtype hacks in onnx_ops (#2939)

* updated most dtype hacks in onnx_ops

* temporarily revert dequantizelinear change

* I think this is right...

* MORE FIXES WOOOO NEW DTYPE IS AWESOME

* ok

* oops missed a print

* half -> float32 for CI

* is npdtype

* some more

* fix if ordering

* more clean ups

* final cleanups

* casting to half not allowed

* k nvm

* revert ArgMax change

* only GPU

* llvm begone

* teeny tiny change

* fix: attempt to add cast tests

* try this

* fix dequantizelinear

* revert some stuff

* tests pass pls

* less lines in onnx_tests

* oops missed string tensor tests

* clean up

* try: revert default behavior changes

* fix: disabled Cast and Castlike tests

* docs: small changes

* fix: fixed isNaN op and enabled associated tests

* fix: forgot about float16

* done

* update disabled test

* gah missed another float16

* disable rest of failing tests

* rm extra line

* try...

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
commit 57817028bb
parent 9f39165188
Author: geohotstan, 2024-01-04 14:45:24 +08:00 (committed by GitHub)
3 changed files with 69 additions and 68 deletions


@@ -2,7 +2,7 @@ from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, DEBUG
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
@@ -24,6 +24,14 @@ def safe_numpy(t) -> np.ndarray:
numpy_cache[t] = tmp
return numpy_cache[t]
# src: onnx/mapping.py
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
# NOTE: 17, 18, 19, 20 are float8, 10 is half
DTYPE_MAP = {1:dtypes.float, 2:dtypes.uint8, 3:dtypes.int8, 4:dtypes.uint16, 5:dtypes.int16, 6:dtypes.int32, 7:dtypes.int64,
9:dtypes.bool, 10:dtypes.float, 11:dtypes.double, 12:dtypes.uint32, 13:dtypes.uint64, 16:dtypes.bfloat16,
17:dtypes.float, 18:dtypes.float, 19:dtypes.float, 20:dtypes.float}
# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
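For reference, a minimal sketch (illustrative, not part of the commit) of how DTYPE_MAP resolves the integer TensorProto enum codes carried by ONNX attributes and initializers; per the NOTE above, FLOAT16 (10) and the float8 codes (17-20) intentionally map to float32:

# sketch: resolving TensorProto enum codes through DTYPE_MAP (excerpted entries; assumes onnx and tinygrad are installed)
from onnx import TensorProto
from tinygrad import dtypes
DTYPE_MAP_EXCERPT = {1: dtypes.float, 7: dtypes.int64, 10: dtypes.float, 11: dtypes.double, 16: dtypes.bfloat16}
assert TensorProto.FLOAT == 1 and TensorProto.FLOAT16 == 10 and TensorProto.DOUBLE == 11
assert DTYPE_MAP_EXCERPT[TensorProto.FLOAT16] == dtypes.float  # half is widened to float32, matching the NOTE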
@@ -34,11 +42,11 @@ def get_run_onnx(onnx_model: ModelProto):
while True:
attr = type_proto.WhichOneof('value')
if attr == 'tensor_type':
if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape
if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
elif not ret:
return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim])
return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
else:
ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim])
ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
return tuple(ret)
elif attr == 'sequence_type':
type_proto = getattr(type_proto, attr).elem_type
@@ -50,7 +58,7 @@ def get_run_onnx(onnx_model: ModelProto):
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (1,10,6,7,5):
if inp.data_type in (1,10,6,7,5,11):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
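The tuple checked in buffer_parse is TensorProto's data_type enum; a small sketch (illustrative, not from the commit) of what the codes stand for, DOUBLE (11) being the newly accepted one:

# sketch: the TensorProto.DataType codes behind (1,10,6,7,5,11)
from onnx import TensorProto
codes = [TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.INT32, TensorProto.INT64, TensorProto.INT16, TensorProto.DOUBLE]
print(codes)  # [1, 10, 6, 7, 5, 11]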


@@ -3,42 +3,43 @@ from typing import Union, Tuple, Optional, List, Any
from tinygrad import Tensor, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, flatten
from extra.onnx import safe_numpy
from extra.onnx import safe_numpy, DTYPE_MAP
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Div", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
"Elu", "Celu"}
# **************** Free Ops ****************
def Identity(x: Tensor): return x
def Add(x: Tensor, other: Tensor, broadcast=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
# TODO: fix buffer_parse
def Add(x: Tensor, other: Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x: Union[Tensor, Any], other: Tensor): return x - other # some test has input as int
def Div(x: Tensor, other: Tensor): return x / other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else x.div(other).floor() # TODO: this has dtype issues
def Pow(x: Tensor, other: Tensor): return x.float() ** other.float()
def Less(x:Tensor,y:Tensor): return (x<y).cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor): return (x<=y).cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor): return (x>y).cast(dtypes.bool)
def GreaterOrEqual(x:Tensor,y:Tensor): return (x>=y).cast(dtypes.bool)
def Equal(x:Tensor,y:Tensor): return (x==y).cast(dtypes.bool)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def Cast(x: Tensor, to): return x.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
# NOTE: does not support saturate
def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
# **************** Simple Ops ****************
def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
if value: return value
if value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string or value_strings: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
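The move from truthiness checks to `is not None` in Constant matters because a zero-valued attribute is falsy; a plain-Python sketch (illustrative only) of the pitfall the new checks avoid:

# sketch: why `if value_int:` was wrong -- a legitimate value of 0 fell through to the next branch
value_int = 0
print(bool(value_int))        # False, so the old check skipped it
print(value_int is not None)  # True, so the new check returns Tensor(0, dtype=dtypes.int64)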
@@ -81,10 +82,10 @@ def Flatten(x: Tensor, axis=1): return x.reshape(prod((1,) + x.shape[0:axis]), -
def Reshape(data: Tensor, shape: Tensor, allowzero=0):
return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, 0).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, 1).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor): return (x==y).where(0, 1).cast(dtypes.bool)
def Not(x:Tensor): return (x==1).where(0, 1).cast(dtypes.bool)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Xor(x:Tensor, y:Tensor): return (x==y).where(False, True)
def Not(x:Tensor): return (x==1).where(False, True)
def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
def Acos(x: Tensor):
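The where-based rewrites of And/Or/Xor/Not lean on simple conditional identities; a plain-Python check of those identities (illustrative, no tinygrad needed):

# sketch: the identities behind the where-based logical ops
for a in (False, True):
  for b in (False, True):
    assert (a if a == b else False) == (a and b)    # And
    assert (a if a == b else True) == (a or b)      # Or
    assert (False if a == b else True) == (a != b)  # Xor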
@@ -107,8 +108,8 @@ def Atan(y: Tensor):
return (y < 0).where(-t3, t3)
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k).cast(dtypes.int64) if upper else x.tril(k).cast(dtypes.int64)
k = safe_numpy(k).item() if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Squeeze(data: Tensor, axes):
if isinstance(axes, Tensor): axes = safe_numpy(axes)
@@ -123,7 +124,7 @@ def Unsqueeze(data: Tensor, axes):
new_shape[i] = next(ptr)
return data.reshape(new_shape)
def Binarizer(x, threshold=0.0): return (x > threshold).cast(dtypes.float32)
def Binarizer(x, threshold=0.0): return (x > threshold).float()
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
axis = axis + x.ndim if axis < 0 else axis
@@ -135,12 +136,6 @@ def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=a
def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis)
def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(len(x.shape))[::-1]) if perm is None else perm)
# NOTE: since we only have one type, this is valid!
# TODO: fix this with dtypes
def CastLike(x, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return x
def ConstantOfShape(x, value:Tensor=None):
if value is None: value=Tensor([0.0])
shape = [int(x) for x in safe_numpy(x)]
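With the one-type placeholder removed, CastLike (defined earlier in this diff) genuinely casts to the target tensor's dtype instead of returning its input; a minimal sketch, assuming tinygrad:

# sketch: the new CastLike behaviour, x.cast(target_type.dtype)
from tinygrad import Tensor, dtypes
x = Tensor([1, 2, 3], dtype=dtypes.int32)
target = Tensor([0.5])              # float32 under the default float dtype
print(x.cast(target.dtype).dtype)   # dtypes.float32; the old stub returned x unchanged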
@@ -239,7 +234,7 @@ def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=
constant_value = value if constant_value is None else float(safe_numpy(constant_value))
seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
seq_pads = [math.ceil(i) for i in seq_pads]
seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
seq_axes = safe_numpy(axes).tolist() if axes is not None else None
base_shape = x.shape
pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
if mode == "wrap":
@@ -290,7 +285,7 @@ def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1
ret = _padding(X, pads, auto_pad, constant_value=float("-inf"), axes=tuple(range(len(X.shape)))[2:], strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
ret_len, X_len = ret.numel(), X.numel()
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len, dtype=dtypes.int64).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape)
if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1)
return ret, indices
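The indices line above recovers argmax positions by comparing the pooled values against the flattened input and summing an arange mask; a small numpy sketch of the same trick (illustrative values):

# sketch (numpy): equality * arange recovers the flat index of each pooled max
import numpy as np
X = np.array([3., 1., 4., 1.])
ret = np.array([4.])  # pooled max
idx = ((ret[:, None] == X[None, :]) * np.arange(4, dtype=np.int64)).sum(1)
print(idx)  # [2] -- position of the max in the flattened input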
@@ -352,7 +347,6 @@ def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)):
return (x - data_mean) / (std + 1e-9)
def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
target = target.cast(dtypes.float32)
N, C, i_shape = x.shape[0], x.shape[1], x.shape
t_shape = target.shape
if len(x.shape) != 3:
@@ -372,7 +366,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
_N, C, *s_dimensions = scores.shape
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels).cast(dtypes.int32)
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
y = scores.log_softmax(axis=1)
if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
@@ -400,15 +394,14 @@ def GatherElements(x: Tensor, indices: Tensor, axis):
def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
b = x.trunc()
b = (b >= 0).where(b+n, b-n)
if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
if equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
if equidistant_case == "round_to_even":
def _and(cond1, cond2): return ((cond1.cast(dtypes.int) + cond2.cast(dtypes.int)) == 2).where(1, 0)
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
x = (And(x == b, cond_ceil_even)).where(x+1-n, x)
x = (x > b).where(b+1-n, b-n)
return x
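The "round_to_even" branch implements banker's rounding (ties go to the nearest even integer), which is also what Python's built-in round does; a quick reference (illustrative only):

# sketch: round-half-to-even on the tie cases handled by the round_to_even branch
print([round(v) for v in (0.5, 1.5, 2.5, 3.5)])  # [0, 2, 2, 4]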
@@ -481,8 +474,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1]).cast(dtypes.default_float)
y_out = Tensor.arange(output_shape[-2]).cast(dtypes.default_float)
x_out = Tensor.arange(output_shape[-1], dtype=dtypes.default_float)
y_out = Tensor.arange(output_shape[-2], dtype=dtypes.default_float)
if mode == "nearest":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
@@ -527,7 +520,7 @@ def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
if axis < 0: axis += rank + 1
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
return cond.where(values[1], values[0]).cast(values.dtype)
return cond.where(values[1], values[0])
def Erf(x: Tensor):
t = 1.0 / (1.0 + 0.3275911 * x.abs())
@@ -551,10 +544,9 @@ def Compress(inp: Tensor, condition: Tensor, axis=None):
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
if dtype is None: dtype = x.dtype
else: dtype = type_map[dtype]
else: dtype = DTYPE_MAP[int(dtype)]
shape = x.shape
dim = min(x.shape)
if shape[0] == shape[1]:
@@ -565,22 +557,16 @@ def EyeLike(x: Tensor, dtype=None, k=0):
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape)
return ret.cast(dtypes.bool)
return (x == float("inf")) * bool(detect_positive) + (x == float("-inf")) * bool(detect_negative)
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1):
axis = axis + x.ndim if axis < 0 else axis
x = x.cast(dtypes.float)
if isinstance(x_zero_point, Tensor): x_zero_point.cast(dtypes.float)
x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
return ((x - x_zer) * x_sc).cast(x_scale.dtype)
return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)
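DequantizeLinear follows the ONNX formula (x - zero_point) * scale, computed in float; a small numpy sketch of the arithmetic with made-up values:

# sketch: dequantizing a uint8 tensor with scale 0.5 and zero_point 128
import numpy as np
x = np.array([0, 128, 255], dtype=np.uint8)
scale, zero_point = np.float32(0.5), 128
print((x.astype(np.float32) - zero_point) * scale)  # [-64.   0.  63.5]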
# Needs work
def IsNaN(x: Tensor):
return (x < float("-inf")).cast(dtypes.bool)
def IsNaN(x: Tensor): return x != x
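The rewritten IsNaN relies on the IEEE-754 rule that NaN compares unequal to everything, including itself; a plain-Python check (illustrative only):

# sketch: x != x is true only for NaN
vals = [1.0, float("nan"), float("inf")]
print([v != v for v in vals])  # [False, True, False]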
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc


@@ -155,14 +155,6 @@ backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill v
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
# issue 1556 https://github.com/tinygrad/tinygrad/issues/1556
backend_test.exclude('test_isinf_cpu')
backend_test.exclude('test_isinf_negative_cpu')
backend_test.exclude('test_isinf_positive_cpu')
backend_test.exclude('test_isinf_float16_cpu')
backend_test.exclude('test_isnan_float16_cpu')
backend_test.exclude('test_isnan_cpu')
# issue 1791 fast math messes with these https://github.com/tinygrad/tinygrad/issues/1791
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
@@ -174,14 +166,25 @@ if Device.DEFAULT in ['METAL']:
backend_test.exclude('test_maxpool_2d_same_lower_cpu')
if Device.DEFAULT in ['GPU', 'METAL']:
backend_test.exclude('test_mish_cpu') # weird inaccuracy
backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy
backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double
# double not supported
backend_test.exclude('test_max_float64_cpu')
backend_test.exclude('test_min_float64_cpu')
backend_test.exclude('test_eyelike_with_dtype_cpu')
# weird inaccuracy
backend_test.exclude('test_mish_cpu')
backend_test.exclude('test_mish_expanded_cpu')
# TODO: llvm has problems with inf
if Device.DEFAULT in ['LLVM']:
backend_test.exclude('test_isinf_cpu')
backend_test.exclude('test_isinf_negative_cpu')
backend_test.exclude('test_isinf_positive_cpu')
# Segfaults in CI, GPU requires cl_khr_fp16
if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')
backend_test.exclude('test_isinf_float16_cpu')
# error: casting to type 'half' is not allowed
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
@@ -194,6 +197,10 @@ if isinstance(Device[Device.DEFAULT], Compiled):
if Device.DEFAULT == 'METAL':
backend_test.exclude('test_maxpool_2d_same_upper_cpu')
# TODO: problems with nan
backend_test.exclude('test_isnan_float16_cpu')
backend_test.exclude('test_isnan_cpu')
# disable model tests for now since they are slow
if not getenv("MODELTESTS"):
for x in backend_test.test_suite: