From 0c857ae2d661959e977c9ad1c4d809be2fd026c7 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 21 Jun 2024 22:01:32 -0400 Subject: [PATCH] some onnx_ops cleanups (#5094) --- extra/onnx_ops.py | 43 +++++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 80469c23..e2566b00 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -24,7 +24,7 @@ def Equal(x:Tensor,y:Tensor): return x == y def Max(*data_0): return functools.reduce(Tensor.maximum, data_0) def Min(*data_0): return functools.reduce(Tensor.minimum, data_0) def Sum(*data_0): return functools.reduce(Tensor.add, data_0) -def Mean(*data_0): return functools.reduce(Tensor.add, data_0) / len(data_0) +def Mean(*data_0): return Sum(*data_0) / len(data_0) # NOTE: does not support saturate def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to]) def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype) @@ -34,7 +34,7 @@ def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_t # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py def Div(x: Tensor, other: Tensor): return (x/other).cast(x.dtype) -def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None): +def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None): if value is not None: return value if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) @@ -70,8 +70,8 @@ def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): retur def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() -def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True) -def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True) +def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) +def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0) def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([]) @@ -101,8 +101,8 @@ def Atan(y: Tensor): t4 = t3 * t3 t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630 t3 = t0 * t3 - t3 = (y.abs() > 1).where(1.570796327 - t3, t3) - return (y < 0).where(-t3, t3) + t3 = (t1 > 1).where(1.570796327 - t3, t3) + return y.sign() * t3 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1): k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0 @@ -110,15 +110,12 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1): def Squeeze(data: Tensor, axes): if isinstance(axes, Tensor): axes = to_python_const(axes) - axes = [x + data.ndim if x < 0 else x for x in axes] + axes = [data._resolve_dim(x) for x in axes] return data.reshape([s for i,s in enumerate(data.shape) if i not in axes]) def Unsqueeze(data: Tensor, axes): - axes = [x + data.ndim if x < 0 else x for x in to_python_const(axes)] - new_shape = [1] * (data.ndim + len(axes)) - ptr = iter(data.shape) - for i in range(len(new_shape)): - if i not in axes: - new_shape[i] = next(ptr) + axes = sorted([x + data.ndim if x < 0 else x for x in to_python_const(axes)]) + new_shape = list(data.shape) + for axis in axes: new_shape.insert(axis, 1) return data.reshape(new_shape) def Binarizer(x, threshold=0.0): return (x > threshold).float() @@ -128,11 +125,11 @@ def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0): return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) -def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis) -def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(len(x.shape))[::-1]) if perm is None else perm) +def Concat(*xs: List[Tensor], axis): return Tensor.cat(*xs, dim=axis) +def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) def ConstantOfShape(x, value:Tensor=None): - if value is None: value=Tensor([0.0]) + if value is None: value = 0.0 shape = to_python_const(x) return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1) @@ -171,11 +168,11 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05 running_var = input_var * momentum + current_var * (1 - momentum) return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - invstd = (input_var + epsilon)**-0.5 + invstd = (input_var + epsilon).rsqrt() return X.batchnorm(scale, B, input_mean, invstd) def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05): - axis = tuple(range(2, len(x.shape))) + axis = tuple(range(2, x.ndim)) mean = x.mean(axis=axis, keepdim=True) invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt() return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) @@ -186,8 +183,6 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ mean = x.mean(axis=axis, keepdim=True) return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt() -# TODO: current implmentation fails tests and tried copying onnx's implementation but got poor accuracy -# https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/groupnormalization.py#L13 def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05): return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) @@ -342,7 +337,7 @@ def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0): return x / x.mul(x).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta) def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)): - mean = x.mean(axis=axis, keepdim=True) + mean = x.mean(axis, keepdim=True) std = x.std(axis, keepdim=True, correction=0) return (x - mean) / (std + 1e-9) @@ -358,7 +353,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind if ignore_index is not None: cond = target == ignore_index weight = cond.where(0, weight) if weight is not None else cond.where(0, 1) - mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2)) + mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(x.ndim -2)) loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight) if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum() if reduction == "sum": return loss.sum() @@ -611,8 +606,8 @@ def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=N return x.layernorm(eps=epsilon) * gamma + beta, None, None, x def FastGelu(x:Tensor, bias:Optional[Tensor]=None): - x = x + bias - return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh()) + # this is tanh approamixated + return (x + bias).gelu() def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None): # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization