mirror of https://github.com/commaai/tinygrad.git
minor onnx_op cleanups to prep dtype changes (#2758)
read through it and cleaned up some minor stuff
parent d8952fc575
commit 38da001b64
@@ -1,17 +1,15 @@
+import functools, io, math, os
+from typing import Union, Tuple, Optional, List, Any
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod, dtypes, ImageDType, flatten
 from tinygrad.nn import Embedding
 from extra.onnx import safe_numpy
 from onnx.helper import tensor_dtype_to_np_dtype
 from onnx import TensorProto
-import io
-import os
 import numpy as np
-import functools
-from typing import Union, Tuple, Optional, List, Any
-import math
 
-tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "Tanh", "MatMul",
-                  "Floor", "Ceil", "Tanh", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Softsign", "Asinh", "Acosh", "Atanh"}
+tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
+                  "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu"}
 
 # **************** Free Ops ****************
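
Not part of the diff: the tensor_methods change only drops the duplicate "Tanh" entry and adds "Elu" (whose standalone definition is removed further down). A minimal sketch of what membership in the set implies, assuming extra/onnx.py dispatches these ops straight to the identically named lowercased Tensor method:

from tinygrad.tensor import Tensor

x = Tensor([-1.0, 0.0, 2.0])
for op in ("Relu", "Tanh", "Elu"):
  # assumed dispatch: ONNX op name -> lowercased Tensor method of the same name
  print(op, getattr(Tensor, op.lower())(x).numpy())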
@@ -47,9 +45,9 @@ def Celu(x:Tensor, alpha=1.0): return x.celu(alpha)
 def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
-  return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
+  return (X > 0).where(X, X * slope)
 def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
-def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
+def ThresholdedRelu(X: Tensor, alpha=1.0): return (X > alpha).where(X, 0)
 def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
 def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
 Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
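
Not part of the diff: the PRelu and ThresholdedRelu rewrites both swap arithmetic identities for an explicit mask-and-select. A quick check, assuming plain Python scalars broadcast through Tensor.where, that the new ThresholdedRelu matches the old relu/sign formulation (y = x if x > alpha else 0):

from tinygrad.tensor import Tensor

X, alpha = Tensor([-1.0, 0.5, 1.0, 2.0]), 1.0
old = (X - alpha).relu() + (X - alpha).relu().sign() * alpha
new = (X > alpha).where(X, 0)
print(old.numpy(), new.numpy())  # both should come out as [0. 0. 0. 2.]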
@@ -70,8 +68,8 @@ def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0)
 
 def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
 def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
-def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
-def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
+def OptionalHasElement(x: Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
+def OptionalGetElement(x: Optional[Tensor]=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
 
 def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
 def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start).item()), stop=int(safe_numpy(limit).item()), step=int(safe_numpy(delta).item())).cast(dtype=start.dtype)
@@ -80,10 +78,10 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.sha
 def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
 def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
 def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias)
-def And(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
-def Or(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool)
-def Xor(x:Tensor, y:Tensor): return (x==y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
-def Not(x:Tensor): return (x==1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
+def And(x:Tensor, y:Tensor): return (x==y).where(x, 0).cast(dtypes.bool)
+def Or(x:Tensor, y:Tensor): return (x==y).where(x, 1).cast(dtypes.bool)
+def Xor(x:Tensor, y:Tensor): return (x==y).where(0, 1).cast(dtypes.bool)
+def Not(x:Tensor): return x.where(0, 1).cast(dtypes.bool)
 
 def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
 def Acos(x: Tensor):
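
Not part of the diff: the rewritten logical ops pass scalars to Tensor.where instead of allocating Tensor.zeros/Tensor.ones the size of the input. A rough truth-table check on 0/1-valued tensors (the form the ONNX backend hands them in), assuming scalars broadcast through where:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x, y = Tensor([1.0, 0.0, 1.0, 0.0]), Tensor([1.0, 1.0, 0.0, 0.0])
print((x == y).where(x, 0).cast(dtypes.bool).numpy())  # And -> True, False, False, False
print((x == y).where(x, 1).cast(dtypes.bool).numpy())  # Or  -> True, True, True, False
print((x == y).where(0, 1).cast(dtypes.bool).numpy())  # Xor -> False, True, True, False
print(x.where(0, 1).cast(dtypes.bool).numpy())         # Not -> False, True, False, True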
@@ -106,7 +104,7 @@ def Atan(y: Tensor):
   return (y < 0).where(-t3, t3)
 
 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
-  k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0
+  k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
   return x.triu(k) if upper else x.tril(k)
 
 def Squeeze(input: Tensor, axes):
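
Not part of the diff: the isinstance check makes the intent explicit; ONNX supplies k as a one-element int64 tensor, and the old `k != 0` test leaned on the truthiness of a comparison Tensor rather than saying what it meant. A small sketch with hypothetical values:

from tinygrad.tensor import Tensor

for k in (0, Tensor([0]), Tensor([2])):  # default python int vs the tensor form ONNX sends
  print(int(k.numpy().item()) if isinstance(k, Tensor) else 0)  # 0, 0, 2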
@@ -131,11 +129,11 @@ def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
   return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
 def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
 
-def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
 def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
 def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)
 
 # NOTE: since we only have one type, this is valid!
 # TODO: fix this with dtypes
 def CastLike(input, target_type):
   assert isinstance(target_type, Tensor), "can only CastLike Tensor"
   return input
@@ -403,10 +401,8 @@ def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
   assert n <= 1, f"n:{n} shouldn't be larger than 1"
   b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
   b = (b >= 0).where(b+n, b-n)
-  if equidistant_case == "round_down":
-    return (x > b).where(b+1-n, b-n)
-  elif equidistant_case == "round_up":
-    return (x >= b).where(b+1-n, b-n)
+  if equidistant_case == "round_down": return (x > b).where(b+1-n, b-n)
+  elif equidistant_case == "round_up": return (x >= b).where(b+1-n, b-n)
   elif equidistant_case == "round_to_even":
     x_ceil_fraction = x.ceil()/2
     cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
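
Not part of the diff: folding the two branches onto one line changes nothing about the tie-breaking. Recomputing the shown lines by hand for the equidistant input 2.5 with n=0.5 (illustrative values):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x, n = Tensor([2.5]), 0.5
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)  # truncate toward zero: 2.0
b = (b >= 0).where(b+n, b-n)                         # midpoint: 2.5
print((x > b).where(b+1-n, b-n).numpy())   # round_down -> [2.]
print((x >= b).where(b+1-n, b-n).numpy())  # round_up   -> [3.]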
@@ -649,16 +645,12 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None
   compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
   vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
 
-  def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
-    vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
-    return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
-
   # bert embedding layer
   if epsilon is None: epsilon = 1e-12
   if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
-  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
-  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
-  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
+  wrd_embedding_res = Embedding(vocab_size, word_embedding)(input_ids)
+  pos_embedding_res = Embedding(max_position_embeddings, position_embedding)(position_ids)
+  seg_embedding_res = Embedding(type_vocab_size, segment_embedding)(segment_ids) if compute_seg_emb else None
 
   embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
   out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
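
Not part of the diff: the hand-rolled one-hot helper is dropped in favour of nn.Embedding, which this file already imports; the new lines presumably wire the preloaded weight tensors in through the constructor. For reference, the stock nn.Embedding signature is (vocab_size, embed_size) and it creates its own weight; a rough shape check with made-up sizes:

from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding

vocab_size, embed_size = 8, 4            # hypothetical sizes
emb = Embedding(vocab_size, embed_size)  # random weight of shape (vocab_size, embed_size)
ids = Tensor([[1, 3, 5]])                # (batch=1, seq=3) token ids
print(emb(ids).shape)                    # (1, 3, 4): one embed_size vector per token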