minor onnx_op cleanups to prep dtype changes (#2758)

read through it and clean some minor stuff
chenyu 2023-12-14 03:05:59 -05:00 committed by GitHub
parent d8952fc575
commit 38da001b64
1 changed file with 20 additions and 28 deletions


@@ -1,17 +1,15 @@
import functools, io, math, os
from typing import Union, Tuple, Optional, List, Any
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from tinygrad.nn import Embedding
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import io
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math
tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "Tanh", "MatMul",
"Floor", "Ceil", "Tanh", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Softsign", "Asinh", "Acosh", "Atanh"}
tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu"}
# **************** Free Ops ****************
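# why the tensor_methods set above exists: these ONNX ops map 1:1 onto existing Tensor
# methods, so the runner in extra/onnx.py can presumably dispatch them generically instead
# of defining wrappers here. Illustrative sketch only -- run_tensor_method and its
# arguments are hypothetical names, not the runner's actual code:
from tinygrad.tensor import Tensor
def run_tensor_method(op_type: str, *inputs: Tensor):
  assert op_type in tensor_methods               # e.g. "Relu" -> Tensor.relu, "MatMul" -> Tensor.matmul
  return getattr(Tensor, op_type.lower())(*inputs)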
@@ -47,9 +45,9 @@ def Celu(x:Tensor, alpha=1.0): return x.celu(alpha)
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
return (X > 0).where(X, X * slope)
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X > alpha).where(X, 0)
def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
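# quick sanity sketch (not part of the diff) that the new where() forms of PRelu and
# ThresholdedRelu match the old clip/relu-based ones, assuming a working tinygrad install:
from tinygrad.tensor import Tensor
X, slope, alpha = Tensor([-2.0, 0.0, 0.5, 3.0]), Tensor([0.25, 0.25, 0.25, 0.25]), 1.0
print((X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope).numpy())  # old PRelu
print(((X > 0).where(X, X * slope)).numpy())                                 # new PRelu, same values
print(((X - alpha).relu() + (X - alpha).relu().sign() * alpha).numpy())      # old ThresholdedRelu
print(((X > alpha).where(X, 0)).numpy())                                     # new ThresholdedRelu, same values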
@@ -70,8 +68,8 @@ def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0)
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start).item()), stop=int(safe_numpy(limit).item()), step=int(safe_numpy(delta).item())).cast(dtype=start.dtype)
@@ -80,10 +78,10 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.sha
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor): return (x==y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x:Tensor): return (x==1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def And(x:Tensor, y:Tensor): return (x==y).where(x, 0).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, 1).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor): return (x==y).where(0, 1).cast(dtypes.bool)
def Not(x:Tensor): return x.where(0, 1).cast(dtypes.bool)
def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
def Acos(x: Tensor):
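# quick truth-table check (not part of the diff) of the simplified And/Or/Xor/Not above,
# assuming the inputs are 0/1 tensors as the ONNX boolean ops provide:
from tinygrad.tensor import Tensor
x, y = Tensor([0., 0., 1., 1.]), Tensor([0., 1., 0., 1.])
print((x == y).where(x, 0).numpy())   # And -> [0, 0, 0, 1]
print((x == y).where(x, 1).numpy())   # Or  -> [0, 1, 1, 1]
print((x == y).where(0, 1).numpy())   # Xor -> [0, 1, 1, 0]
print(x.where(0, 1).numpy())          # Not -> [1, 1, 0, 0]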
@@ -106,7 +104,7 @@ def Atan(y: Tensor):
return (y < 0).where(-t3, t3)
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0
k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
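# sketch (not part of the diff) of why the isinstance check in Trilu above matters:
# ONNX passes k as a one-element int64 tensor, while the Python default is the plain int 0.
# The old `k != 0` test built a comparison Tensor just to branch on it; the new form only
# inspects the Python type and unpacks with .numpy() when there really is a tensor.
from tinygrad.tensor import Tensor
for k in (0, Tensor([0]), Tensor([2])):
  k_int = int(k.numpy().item()) if isinstance(k, Tensor) else 0
  print(k_int)   # 0, 0, 2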
@@ -131,11 +129,11 @@ def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)
# NOTE: since we only have one type, this is valid!
# TODO: fix this with dtypes
def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return input
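# not part of the diff: CastLike is a no-op today because everything runs in one dtype.
# Once the dtype work this commit preps for lands, the TODO above presumably becomes
# something like the hypothetical sketch below (the commit only adds the TODO, not this):
from tinygrad.tensor import Tensor
def CastLike_with_dtypes(input: Tensor, target_type: Tensor) -> Tensor:
  assert isinstance(target_type, Tensor), "can only CastLike Tensor"
  return input.cast(target_type.dtype)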
@@ -403,10 +401,8 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
b = (b >= 0).where(b+n, b-n)
if equidistant_case == "round_down":
return (x > b).where(b+1-n, b-n)
elif equidistant_case == "round_up":
return (x >= b).where(b+1-n, b-n)
if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
elif equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
elif equidistant_case == "round_to_even":
x_ceil_fraction = x.ceil()/2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
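# worked example (not part of the diff) of the one-line equidistant cases above, taking n = 0.5:
#   x = 2.5 -> b starts as the int cast 2, then b = 2.5 since b >= 0
#   "round_down": (x > b)  is False at the halfway point -> b - n     = 2.0
#   "round_up":   (x >= b) is True  at the halfway point -> b + 1 - n = 3.0
# away from the halfway point both branches agree, e.g. x = 2.7 gives 3.0 in either case.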
@@ -649,16 +645,12 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None
compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
# bert embedding layer
if epsilon is None: epsilon = 1e-12
if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
wrd_embedding_res = Embedding(vocab_size, word_embedding)(input_ids)
pos_embedding_res = Embedding(max_position_embeddings, position_embedding)(position_ids)
seg_embedding_res = Embedding(type_vocab_size, segment_embedding)(segment_ids) if compute_seg_emb else None
embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
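# sketch (not part of the diff) of the lookup the removed local embedding helper spelled
# out by hand: comparing ids against an arange builds a one-hot matrix, and one-hot @ weight
# selects rows of the table, which is the lookup nn.Embedding performs for you. The toy
# weight and ids below are made-up values for illustration only:
from tinygrad.tensor import Tensor
weight = Tensor([[0., 1.], [2., 3.], [4., 5.]])    # vocab of 3 words, 2-dim embeddings
ids = Tensor([[2., 0.]])                           # (batch=1, seq=2) token ids
counter = Tensor.arange(3, dtype=ids.dtype).reshape(1, 1, 3).expand(1, 2, 3)
onehot = counter == ids.unsqueeze(2).expand(1, 2, 3)
print((onehot @ weight).numpy())                   # rows 2 and 0 of the table: [[4., 5.], [0., 1.]]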