some onnx_ops cleanups (#5094)

chenyu 2024-06-21 22:01:32 -04:00 committed by GitHub
parent f4a041af16
commit 0c857ae2d6
1 changed file with 19 additions and 24 deletions

@@ -24,7 +24,7 @@ def Equal(x:Tensor,y:Tensor): return x == y
 def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
 def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
 def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
-def Mean(*data_0): return functools.reduce(Tensor.add, data_0) / len(data_0)
+def Mean(*data_0): return Sum(*data_0) / len(data_0)
 # NOTE: does not support saturate
 def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
 def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
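
Note on the Mean change: routing through Sum keeps the two reductions consistent and drops the duplicated functools.reduce call; the arithmetic is unchanged. A rough standalone check of the equivalence in plain Python (stand-in add, no tinygrad):

    import functools
    def add(a, b): return [x + y for x, y in zip(a, b)]          # stand-in for Tensor.add
    def Sum(*data_0): return functools.reduce(add, data_0)
    def Mean(*data_0): return [v / len(data_0) for v in Sum(*data_0)]
    assert Mean([1.0, 2.0], [3.0, 4.0]) == [2.0, 3.0]
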
@@ -34,7 +34,7 @@ def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_t
 # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
 def Div(x: Tensor, other: Tensor): return (x/other).cast(x.dtype)
-def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
+def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
   if value is not None: return value
   if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
   if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
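
The Optional[Tensor] annotation only documents what the body already does: an ONNX Constant node is expected to carry exactly one of its value* attributes, and the op returns the first one that is set. A minimal sketch of that "first attribute that is not None" pattern (illustrative helper name, not from the file):

    from typing import Optional
    def first_set(*candidates: Optional[object]) -> Optional[object]:
      # mirrors Constant's cascade of "if value_x is not None: return ..." checks
      for c in candidates:
        if c is not None: return c
      return None
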
@@ -70,8 +70,8 @@ def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): retur
 def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
 def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
-def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
-def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
+def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
+def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
 def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
 def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])
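
X.ndim and len(X.shape) are the same value, so the global pools still reduce over every axis past the first two (batch and channels). A quick shape-level check in plain Python, assuming an NCHW-style input:

    shape = (2, 3, 8, 8)                           # N, C, H, W
    assert tuple(range(2, len(shape))) == (2, 3)   # pool over the spatial axes only
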
@@ -101,8 +101,8 @@ def Atan(y: Tensor):
   t4 = t3 * t3
   t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
   t3 = t0 * t3
-  t3 = (y.abs() > 1).where(1.570796327 - t3, t3)
-  return (y < 0).where(-t3, t3)
+  t3 = (t1 > 1).where(1.570796327 - t3, t3)
+  return y.sign() * t3
 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
   k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
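
Both the old and new epilogue of Atan rely on atan being odd: atan(y) = sign(y) * atan(|y|), and when |y| > 1 the polynomial is evaluated on the reciprocal and the result reflected about pi/2 (the 1.570796327 constant). Assuming t1 holds y.abs() from earlier in the function (not shown in this hunk), the new code just reuses it instead of recomputing. A numeric check of the two identities with math.atan, not the tinygrad kernel:

    import math
    for y in (-3.0, -0.5, 0.7, 4.0):
      assert math.isclose(math.atan(y), math.copysign(1.0, y) * math.atan(abs(y)))
      if abs(y) > 1:
        assert math.isclose(math.atan(abs(y)), math.pi / 2 - math.atan(1 / abs(y)))
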
@@ -110,15 +110,12 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
 def Squeeze(data: Tensor, axes):
   if isinstance(axes, Tensor): axes = to_python_const(axes)
-  axes = [x + data.ndim if x < 0 else x for x in axes]
+  axes = [data._resolve_dim(x) for x in axes]
   return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
 def Unsqueeze(data: Tensor, axes):
-  axes = [x + data.ndim if x < 0 else x for x in to_python_const(axes)]
-  new_shape = [1] * (data.ndim + len(axes))
-  ptr = iter(data.shape)
-  for i in range(len(new_shape)):
-    if i not in axes:
-      new_shape[i] = next(ptr)
+  axes = sorted([x + data.ndim if x < 0 else x for x in to_python_const(axes)])
+  new_shape = list(data.shape)
+  for axis in axes: new_shape.insert(axis, 1)
   return data.reshape(new_shape)
 def Binarizer(x, threshold=0.0): return (x > threshold).float()
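
The rewritten Unsqueeze builds the output shape by inserting 1s into a copy of the input shape; normalizing and sorting the axes first means each insert lands at its intended index in the final shape. The same logic on plain Python lists:

    def unsqueeze_shape(shape, axes):
      # axes assumed already non-negative, as in the diff above
      new_shape = list(shape)
      for axis in sorted(axes): new_shape.insert(axis, 1)
      return new_shape
    assert unsqueeze_shape((3, 4), [0, 3]) == [1, 3, 4, 1]
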
@@ -128,11 +125,11 @@ def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
   return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
 def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
-def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis)
-def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(len(x.shape))[::-1]) if perm is None else perm)
+def Concat(*xs: List[Tensor], axis): return Tensor.cat(*xs, dim=axis)
+def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
 def ConstantOfShape(x, value:Tensor=None):
-  if value is None: value=Tensor([0.0])
+  if value is None: value = 0.0
   shape = to_python_const(x)
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
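
Two equivalences worth noting in this hunk: Tensor.cat(*xs, dim=axis) is the same call as xs[0].cat(*xs[1:], dim=axis), since the first positional argument binds to self; and the default Transpose permutation is simply the axis order reversed. A plain-Python check of the default perm:

    ndim = 4
    assert list(range(ndim))[::-1] == [3, 2, 1, 0]   # with no perm given, all axes are reversed
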
@@ -171,11 +168,11 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
     running_var = input_var * momentum + current_var * (1 - momentum)
     return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
-  invstd = (input_var + epsilon)**-0.5
+  invstd = (input_var + epsilon).rsqrt()
   return X.batchnorm(scale, B, input_mean, invstd)
 def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
-  axis = tuple(range(2, len(x.shape)))
+  axis = tuple(range(2, x.ndim))
   mean = x.mean(axis=axis, keepdim=True)
   invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
   return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
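
(v + eps) ** -0.5 and (v + eps).rsqrt() compute the same quantity, 1 / sqrt(v + eps), so the BatchNormalization change is purely stylistic. A scalar check of the identity with plain math (not the tinygrad op):

    import math
    eps = 1e-5
    for v in (0.25, 1.0, 7.5):
      assert math.isclose((v + eps) ** -0.5, 1 / math.sqrt(v + eps))
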
@@ -186,8 +183,6 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ
   mean = x.mean(axis=axis, keepdim=True)
   return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
-# TODO: current implmentation fails tests and tried copying onnx's implementation but got poor accuracy
-# https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/groupnormalization.py#L13
 def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
   return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
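
GroupNormalization reshapes (N, C, ...) to (N, num_groups, -1) so the layernorm over the last axis normalizes each group of channels together, with scale and bias broadcast per group via unsqueeze(-1). A shape-only sketch of the grouping, assuming C is divisible by num_groups:

    N, C, H, W, num_groups = 2, 8, 4, 4, 4
    grouped = (N, num_groups, (C // num_groups) * H * W)   # result of x.reshape(x.shape[0], num_groups, -1)
    assert grouped == (2, 4, 32)
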
@@ -342,7 +337,7 @@ def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
   return x / x.mul(x).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
 def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)):
-  mean = x.mean(axis=axis, keepdim=True)
+  mean = x.mean(axis, keepdim=True)
   std = x.std(axis, keepdim=True, correction=0)
   return (x - mean) / (std + 1e-9)
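
MeanVarianceNormalization is (x - mean) / (std + 1e-9) over the given axes, with correction=0 meaning the population (divide-by-N) standard deviation; the edit only drops the redundant axis= keyword. The formula spelled out on a plain Python list, one axis:

    xs = [1.0, 2.0, 3.0, 4.0]
    mean = sum(xs) / len(xs)
    std = (sum((v - mean) ** 2 for v in xs) / len(xs)) ** 0.5   # correction=0: divide by N, not N-1
    normalized = [(v - mean) / (std + 1e-9) for v in xs]
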
@@ -358,7 +353,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
   if ignore_index is not None:
     cond = target == ignore_index
     weight = cond.where(0, weight) if weight is not None else cond.where(0, 1)
-  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2))
+  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(x.ndim -2))
   loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
   if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
   if reduction == "sum": return loss.sum()
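
The class-index range is reshaped to [1, C, 1, ...] so it broadcasts against the targets across an input x of shape (N, C, d1, ..., dk); x.ndim is again just len(x.shape). Shape-level check in plain Python for a 3-D input:

    N, C, d1 = 2, 5, 7                                # x has shape (N, C, d1)
    x_ndim = 3
    assert [1, C] + [1] * (x_ndim - 2) == [1, 5, 1]   # shape given to Tensor.arange(C).reshape(...)
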
@@ -611,8 +606,8 @@ def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=N
   return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
 def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
-  x = x + bias # this is tanh approamixated
-  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
+  return (x + bias).gelu()
 def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
   # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
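
The removed FastGelu lines were an inlined tanh approximation of GELU, 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x**3)). Delegating to Tensor.gelu should be equivalent up to constant precision, assuming tinygrad's gelu also uses the tanh approximation (0.797885 is sqrt(2/pi) and 0.035677 is 0.044715 * sqrt(2/pi)). A standalone check of those constants:

    import math
    assert math.isclose(math.sqrt(2 / math.pi), 0.797885, rel_tol=1e-5)
    assert math.isclose(0.044715 * math.sqrt(2 / math.pi), 0.035677, rel_tol=1e-4)
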