mirror of https://github.com/commaai/tinygrad.git
some onnx_ops cleanups (#5094)
This commit is contained in:
parent f4a041af16
commit 0c857ae2d6
@@ -24,7 +24,7 @@ def Equal(x:Tensor,y:Tensor): return x == y
 def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
 def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
 def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
-def Mean(*data_0): return functools.reduce(Tensor.add, data_0) / len(data_0)
+def Mean(*data_0): return Sum(*data_0) / len(data_0)
 # NOTE: does not support saturate
 def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
 def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
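Note: the Mean change only drops the duplicated reduce and reuses Sum; the result is still the element-wise sum divided by the input count. A minimal sketch with plain floats standing in for Tensors (the lambda stands in for Tensor.add):

import functools

def Sum(*data_0): return functools.reduce(lambda a, b: a + b, data_0)
def Mean(*data_0): return Sum(*data_0) / len(data_0)

assert Mean(1.0, 2.0, 6.0) == 3.0  # same value as reducing with add and then dividing by len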
@@ -34,7 +34,7 @@ def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_t
 # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
 def Div(x: Tensor, other: Tensor): return (x/other).cast(x.dtype)
 
-def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
+def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
   if value is not None: return value
   if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
   if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
@@ -70,8 +70,8 @@ def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): retur
 def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
 def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
 
-def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
-def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
+def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
+def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
 def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
 def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([])
 
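Note: X.ndim is just len(X.shape), so the global pooling hunk is a pure spelling change; the reduction still runs over every axis past batch and channel, and keepdim=True keeps those axes as size 1. A small numpy sketch (numpy standing in for the tinygrad Tensor):

import numpy as np

X = np.arange(24, dtype=np.float32).reshape(2, 3, 2, 2)  # NCHW
axes = tuple(range(2, X.ndim))                           # (2, 3), identical to range(2, len(X.shape))
out = X.mean(axis=axes, keepdims=True)
assert out.shape == (2, 3, 1, 1)                         # one pooled value per (batch, channel)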
@@ -101,8 +101,8 @@ def Atan(y: Tensor):
   t4 = t3 * t3
   t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
   t3 = t0 * t3
-  t3 = (y.abs() > 1).where(1.570796327 - t3, t3)
-  return (y < 0).where(-t3, t3)
+  t3 = (t1 > 1).where(1.570796327 - t3, t3)
+  return y.sign() * t3
 
 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
   k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
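Note: this hunk assumes t1 already holds y.abs() from the argument-reduction lines above it (not shown here), so (t1 > 1) picks the same branch as (y.abs() > 1), and y.sign() * t3 matches (y < 0).where(-t3, t3) (at y == 0 both give 0, since t3 is 0 there). A plain-Python sketch of the same approximation, for reference only:

import math

def atan_approx(y: float) -> float:
  # polynomial approximation of atan on [0, 1], mirrored for |y| > 1 and negative y
  t1 = abs(y)
  t3 = t1 if t1 <= 1 else 1.0 / t1  # reduce the argument into [0, 1]
  t4 = t3 * t3
  t0 = ((((-0.013480470*t4 + 0.057477314)*t4 - 0.121239071)*t4 + 0.195635925)*t4 - 0.332994597)*t4 + 0.999995630
  t3 = t0 * t3
  if t1 > 1: t3 = 1.570796327 - t3  # atan(x) = pi/2 - atan(1/x) for x > 1
  return math.copysign(t3, y)       # same effect as y.sign() * t3 here

assert abs(atan_approx(2.0) - math.atan(2.0)) < 1e-4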
@@ -110,15 +110,12 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
 
 def Squeeze(data: Tensor, axes):
   if isinstance(axes, Tensor): axes = to_python_const(axes)
-  axes = [x + data.ndim if x < 0 else x for x in axes]
+  axes = [data._resolve_dim(x) for x in axes]
   return data.reshape([s for i,s in enumerate(data.shape) if i not in axes])
 def Unsqueeze(data: Tensor, axes):
-  axes = [x + data.ndim if x < 0 else x for x in to_python_const(axes)]
-  new_shape = [1] * (data.ndim + len(axes))
-  ptr = iter(data.shape)
-  for i in range(len(new_shape)):
-    if i not in axes:
-      new_shape[i] = next(ptr)
+  axes = sorted([x + data.ndim if x < 0 else x for x in to_python_const(axes)])
+  new_shape = list(data.shape)
+  for axis in axes: new_shape.insert(axis, 1)
   return data.reshape(new_shape)
 
 def Binarizer(x, threshold=0.0): return (x > threshold).float()
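Note: Squeeze now normalizes negative axes through Tensor._resolve_dim, and Unsqueeze builds the output shape by inserting size-1 dims into a copy of data.shape; sorting the axes first means each insert lands at its final position, matching the old iterator-based construction. A shape-only sketch (plain lists, non-negative axes assumed):

def unsqueeze_shape(shape, axes):
  axes = sorted(axes)                          # assumes axes are already normalized to be non-negative
  new_shape = list(shape)
  for axis in axes: new_shape.insert(axis, 1)  # earlier inserts shift later positions; sorting accounts for that
  return new_shape

assert unsqueeze_shape([3, 4], [0, 3]) == [1, 3, 4, 1]
assert unsqueeze_shape([2], [1, 2]) == [2, 1, 1]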
@@ -128,11 +125,11 @@ def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
   return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
 def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
 
-def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis)
-def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(len(x.shape))[::-1]) if perm is None else perm)
+def Concat(*xs: List[Tensor], axis): return Tensor.cat(*xs, dim=axis)
+def Transpose(x: Tensor, perm=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
 
 def ConstantOfShape(x, value:Tensor=None):
-  if value is None: value=Tensor([0.0])
+  if value is None: value = 0.0
   shape = to_python_const(x)
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
 
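Note: for Concat, the unbound call Tensor.cat(*xs, dim=axis) is the same thing as xs[0].cat(*xs[1:], dim=axis): the first element just becomes self. A toy sketch of that equivalence (T is a stand-in class, not tinygrad code):

class T:
  def __init__(self, vals): self.vals = vals
  def cat(self, *others, dim=0): return T(self.vals + sum((o.vals for o in others), []))  # toy concat

xs = (T([1]), T([2]), T([3]))
# unbound call: the first positional argument is bound to self, so both spellings agree
assert T.cat(*xs, dim=0).vals == xs[0].cat(*xs[1:], dim=0).vals == [1, 2, 3]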
@@ -171,11 +168,11 @@ def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05
     running_var = input_var * momentum + current_var * (1 - momentum)
 
     return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
-  invstd = (input_var + epsilon)**-0.5
+  invstd = (input_var + epsilon).rsqrt()
   return X.batchnorm(scale, B, input_mean, invstd)
 
 def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
-  axis = tuple(range(2, len(x.shape)))
+  axis = tuple(range(2, x.ndim))
   mean = x.mean(axis=axis, keepdim=True)
   invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
   return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
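Note: rsqrt is the reciprocal square root, so (input_var + epsilon).rsqrt() computes the same value as (input_var + epsilon)**-0.5; only the spelling changes. Quick scalar check:

import math

var, eps = 0.25, 1e-5
assert math.isclose((var + eps) ** -0.5, 1.0 / math.sqrt(var + eps))  # rsqrt(x) == x**-0.5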
@@ -186,8 +183,6 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ
   mean = x.mean(axis=axis, keepdim=True)
   return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
 
-# TODO: current implmentation fails tests and tried copying onnx's implementation but got poor accuracy
-# https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/groupnormalization.py#L13
 def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
   return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
 
@@ -342,7 +337,7 @@ def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
   return x / x.mul(x).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
 
 def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)):
-  mean = x.mean(axis=axis, keepdim=True)
+  mean = x.mean(axis, keepdim=True)
   std = x.std(axis, keepdim=True, correction=0)
   return (x - mean) / (std + 1e-9)
 
@@ -358,7 +353,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
   if ignore_index is not None:
     cond = target == ignore_index
     weight = cond.where(0, weight) if weight is not None else cond.where(0, 1)
-  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2))
+  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(x.ndim -2))
   loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
   if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
   if reduction == "sum": return loss.sum()
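Note: the mask line builds a one-hot class selector by broadcasting the target indices against arange(C); only the len(x.shape) -> x.ndim spelling changes. A numpy sketch of the broadcasted comparison for a 3-D x of shape (N, C, d), where [1]*(x.ndim - 2) is just [1]:

import numpy as np

N, C, d = 2, 3, 4
target = np.array([[0, 2, 1, 1], [2, 0, 0, 1]])
# (N, 1, d) indices compared against (1, C, 1) class ids broadcast to an (N, C, d) one-hot mask
mask = target[:, None, :] == np.arange(C).reshape(1, C, 1)
assert mask.shape == (N, C, d) and (mask.sum(axis=1) == 1).all()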
@@ -611,8 +606,8 @@ def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=N
   return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
 
 def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
-  x = x + bias
-  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
+  # this is tanh approximated
+  return (x + bias).gelu()
 
 def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
   # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
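Note: tinygrad's Tensor.gelu is the tanh-approximated GELU, which is why (x + bias).gelu() can replace the explicit polynomial: 0.797885 is sqrt(2/pi) and 0.035677 is 0.044715 * sqrt(2/pi) folded together. A quick scalar check of that identity (plain Python, not the Tensor implementation):

import math

def gelu_tanh(v: float) -> float:
  # tanh-approximated GELU: 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v**3)))
  return 0.5 * v * (1 + math.tanh(math.sqrt(2 / math.pi) * (v + 0.044715 * v ** 3)))

def fastgelu_old(v: float) -> float:
  # the removed explicit form, with sqrt(2/pi) folded into the constants
  return 0.5 * v * (1 + math.tanh(v * 0.797885 + 0.035677 * v ** 3))

assert math.isclose(gelu_tanh(1.5), fastgelu_old(1.5), rel_tol=1e-4)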