from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math
# **************** Free Ops ****************
def Identity(input: Tensor): return input
def Neg(input: Tensor): return -input
def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor): return input - other  # some test has input as int
def Mul(input: Tensor, other: Tensor): return (input * other) if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input * other).cast(input.dtype)
# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor()
def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype)
def Reciprocal(input: Tensor): return input.reciprocal()
def Sqrt(input: Tensor): return input.sqrt()
def Sign(input: Tensor): return input.sign()
def Abs(input: Tensor): return input.abs()
def Exp(input: Tensor): return input.exp()
def Log(input: Tensor): return input.log()
def Mish(input: Tensor): return input.mish()
def Sin(x: Tensor): return x.sin()
def Cos(x: Tensor): return x.cos()
def Tan(x: Tensor): return x.tan()
def Relu(input: Tensor): return input.relu()
def Sigmoid(input: Tensor): return input.sigmoid()
def Tanh(input: Tensor): return input.tanh()
def MatMul(input: Tensor, other: Tensor): return input.matmul(other)
def Floor(x: Tensor): return x.floor()
def Ceil(x: Tensor): return x.ceil()
def Less(x: Tensor, y: Tensor): return (x < y).cast(dtypes.bool)
def LessOrEqual(x: Tensor, y: Tensor): return (x <= y).cast(dtypes.bool)
def Greater(x: Tensor, y: Tensor): return (x > y).cast(dtypes.bool)
def GreaterOrEqual(x: Tensor, y: Tensor): return (x >= y).cast(dtypes.bool)
def Equal(x: Tensor, y: Tensor): return (x == y).cast(dtypes.bool)
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def Where(condition: Tensor, X: Tensor, Y: Tensor): return condition.where(X, Y).cast(X.dtype)
def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
# **************** Simple Ops ****************
def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
  if value: return value
  elif value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  elif value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  elif value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  elif value_string or value_strings: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
def Softsign(input: Tensor): return input / (1 + input.abs())
def Cosh(x): return (math.e ** x + math.e ** -x) / 2
def Sinh(x): return (math.e ** x - math.e ** -x) / 2
def Tanh(x): return x.tanh()
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha * input + beta).clip(0, 1)
def HardSwish(input: Tensor): return input * HardSigmoid(input, 1/6, 0.5)
def Celu(X: Tensor, alpha=1.0): return X.relu() - (-alpha * (X / alpha).exp() + alpha).relu()  # the second term must scale with alpha so non-default alpha matches the ONNX definition
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha * X.exp() + alpha).relu())
def Softplus(X: Tensor): return X.softplus()
def PRelu(X: Tensor, slope: Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope  # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0): return (X - alpha).relu() + (X - alpha).relu().sign() * alpha
def Softmax_1(input: Tensor, axis=1): return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13}  # Softmax default axis changed
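# Example (hypothetical caller, not part of this file): a runner that knows the model's opset version
# could dispatch as Softmax[max(v for v in Softmax if v <= opset_version)](x), picking Softmax_1
# (default axis=1) before opset 13 and Softmax_13 (default axis=-1) from opset 13 onwards.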
def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max)
# NOTE ReduceProd would require a new llop
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
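# e.g. _axes(Tensor([0, 2]), 0) -> [0, 2], _axes(None, 0) -> None (reduce over every axis),
# and an empty axes tensor with noop_with_empty_axes=1 -> [] (the reduce becomes a no-op).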
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32)
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64)  # TODO: really?
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i, x in enumerate(safe_numpy(shape))])
def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd) * (input + bias) + (input > lambd) * (input - bias)
def And(x: Tensor, y: Tensor): return Where((x == y), x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x: Tensor, y: Tensor): return Where((x == y), x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x: Tensor, y: Tensor): return Where((x == y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x: Tensor): return Where((x == 1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x))
def Asinh(x): return Tensor.log(x + Tensor.sqrt(x * x + 1))
def Acosh(x): return Tensor.log(x + Tensor.sqrt(x * x - 1))
def Atanh(x): return 0.5 * Tensor.log((1 + x) / (1 - x))
def Acos(x: Tensor):
  negate = (x < 0)
  x = x.abs()
  ret = ((((-0.0187293 * x) + 0.0742610) * x - 0.2121144) * x + 1.5707288) * Tensor.sqrt(1.0 - x)
  ret = ret - 2 * negate * ret
  return negate * 3.14159265358979 + ret
def Atan(y: Tensor):
  x = Tensor.ones(y.shape)
  t3 = x
  t1 = y.abs()
  t0 = (t3 > t1).where(t3, t1)
  t1 = (t3 < t1).where(t3, t1)
  t3 = t1 / t0
  t4 = t3 * t3
  t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630
  t3 = t0 * t3
  t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
  return (y < 0).where(-t3, t3)
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  k = int(k.numpy().item()) if k != 0 else 0  # onnx passes k as a tensor int64 with one element, default is 0
  return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
  if isinstance(axes, Tensor): axes = safe_numpy(axes)
  axes = [int(x) if x >= 0 else int(x + input.ndim) for x in axes]
  return input.reshape([s for i, s in enumerate(input.shape) if i not in axes])
def Unsqueeze(data: Tensor, axes):
  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
  new_shape = [1] * (len(data.shape) + len(axes))
  ptr = iter(data.shape)
  for i in range(len(new_shape)):
    if i not in axes:
      new_shape[i] = next(ptr)
  return data.reshape(new_shape)
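# e.g. Unsqueeze(data of shape (3, 4), axes=[0, 3]) returns shape (1, 3, 4, 1): the new singleton
# axes land at positions 0 and 3 and the original dims fill the remaining slots in order.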
def Binarizer(input, threshold=0.0): return input > threshold
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
  axis = axis + x.ndim if axis < 0 else axis
  m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
  c = Tensor.arange(x.shape[axis]).reshape(*[1] * axis, x.shape[axis], *[1] * (x.ndim - axis - 1)) * m
  return c.max(axis=axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
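# e.g. ArgMax(Tensor([[2., 7., 7.]]), axis=1) -> [[2]]: the arange*mask trick keeps the largest
# matching position, so ties resolve to the last index and select_last_index is effectively ignored.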
def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis)
def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm)
# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
  assert isinstance(target_type, Tensor), "can only CastLike Tensor"
  return input
def ConstantOfShape(input, value: Tensor=None):
  if value is None: value = Tensor([0.0])
  shape = [int(x) for x in safe_numpy(input)]
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0] != 0 else 1)
# TODO: abstract out the broadcast logic in tensor
def Expand(input: Tensor, shape):
  x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
  # copied from _broadcasted
  x_shape, y_shape = [([1] * (max(len(x_shape), len(y_shape)) - len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
  shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))
  return input.reshape(x_shape).expand(shape_ret)
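# e.g. Expand(input of shape (3, 1), shape=[2, 1, 4]): the shapes are left-padded to (1, 3, 1) vs
# (2, 1, 4) and the elementwise max gives the broadcast target (2, 3, 4).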
# **************** Complex Ops ****************
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]))
  return ret
# works with Tensors.ndim != 4
def _batchnorm(self: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], mean: Tensor, invstd: Tensor):
  shape = [1, -1] + [1] * (self.ndim - 2)
  x = (self - mean.reshape(shape=shape))
  if weight: x = x * weight.reshape(shape=shape)
  ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
  return (ret + bias.reshape(shape=shape)) if bias else ret
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0, 2, 3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y * y).mean(axis=(0, 2, 3))
    current_invstd = current_var.add(epsilon).pow(-0.5)
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
    return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var
  else:
    invstd = (input_var + epsilon) ** -0.5
    return _batchnorm(X, scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
  axis = tuple(range(2, len(x.shape)))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  assert stash_type == 1, "only float32 is supported"
  axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
  mean = x.mean(axis=axis, keepdim=True)
  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
  if ndims and len(onnx_pads) // 2 != ndims: onnx_pads = onnx_pads * ndims  # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
  if ndims is None: ndims = len(onnx_pads) // 2
  if axes is None: axes = list(range(ndims))
  num_axes = len(axes)
  np_pads = [(0, 0)] * ndims
  for i in range(num_axes):
    np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
  return np_pads
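# e.g. _format_padding([1, 2, 3, 4], ndims=2) -> [(1, 3), (2, 4)], i.e. ONNX's
# [x1_begin, x2_begin, x1_end, x2_end] layout regrouped into per-axis (begin, end) pairs.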
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
  if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  if pads is None: return X
  pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(pads), value=constant_value)
def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
  strides = [strides] * len(kernel_shape) if isinstance(strides, int) else strides if strides else [1] * len(kernel_shape)
  dilations = [1] * len(kernel_shape) if dilations == 1 else dilations
  pad_shape = [(math.ceil(sh / st) - 1) * st + ((ks - 1) * di + 1) - sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
  if auto_pad == "SAME_UPPER": return [pad_shape[0] // 2, pad_shape[1] // 2, pad_shape[0] - pad_shape[0] // 2, pad_shape[1] - pad_shape[1] // 2]
  elif auto_pad == "SAME_LOWER": return [pad_shape[0] - pad_shape[0] // 2, pad_shape[1] - pad_shape[1] // 2, pad_shape[0] // 2, pad_shape[1] // 2]
  else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")
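# e.g. SAME_UPPER with spatial dims (5, 5), strides [2, 2], kernel_shape [3, 3], dilations [1, 1]:
# pad_shape is [2, 2] and the return value is [1, 1, 1, 1] in ONNX order [x1_begin, x2_begin, x1_end, x2_end];
# SAME_UPPER puts the extra pixel (when the total pad is odd) on the end side.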
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
  constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
  seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
  seq_pads = [math.ceil(i) for i in seq_pads]
  seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
  base_shape = x.shape
  pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
  if mode == "wrap":
    repeat_args = [math.ceil(dim[0] / sh) + math.ceil(dim[1] / sh) + 1 for dim, sh in zip(pads, base_shape)]
    new_shape = [s * r for s, r in zip(base_shape, repeat_args)]
    shrink_args = [(sh - dim[0] % sh if dim[0] % sh != 0 else 0, nsh - (sh - dim[1] % sh) if dim[1] % sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
    return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
  elif mode == "reflect":
    for i, s in enumerate(x.shape):
      if pads[i] == (0, 0): continue
      elif pads[i][0] and not pads[i][1]:
        x = x.flip(i).shrink(tuple([(0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1) for i_, s_ in enumerate(x.shape)])).pad(tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      elif not pads[i][0] and pads[i][1]:
        x = x.flip(i).shrink(tuple([(0, s_) if i_ != i else (1, pads[i][1] + 1) for i_, s_ in enumerate(x.shape)])).pad(tuple([(0, 0) if i_ != i else (s, 0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      else:
        x = x.flip(i).shrink(tuple([(0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1) for i_, s_ in enumerate(x.shape)])).pad(tuple([(0, 0) if i_ != i else (0, s + pads[i][1]) for i_ in range(x.ndim)])) + \
            x.flip(i).shrink(tuple([(0, s_) if i_ != i else (1, pads[i][1] + 1) for i_, s_ in enumerate(x.shape)])).pad(tuple([(0, 0) if i_ != i else (s + pads[i][0], 0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
    return x
  elif mode == "edge":
    for i, s in enumerate(x.shape):
      if pads[i] == (0, 0): continue
      elif pads[i][0] and not pads[i][1]:
        x = x.shrink(tuple([(0, s_) if i_ != i else (0, 1) for i_, s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]).pad(tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      elif not pads[i][0] and pads[i][1]:
        # expand along axis i by the right pad amount (pads[i][1]); pads[i][0] is 0 in this branch and would expand to nothing
        x = x.shrink(tuple([(0, s_) if i_ != i else (s_ - 1, s_) for i_, s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]).pad(tuple([(0, 0) if i_ != i else (s + pads[i][0], 0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
      else:
        x = x.shrink(tuple([(0, s_) if i_ != i else (0, 1) for i_, s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]).pad(tuple([(0, 0) if i_ != i else (0, s + pads[i][1]) for i_ in range(x.ndim)])) + \
            x.shrink(tuple([(0, s_) if i_ != i else (s_ - 1, s_) for i_, s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]).pad(tuple([(0, 0) if i_ != i else (s + pads[i][0], 0) for i_ in range(x.ndim)])) + \
            x.pad(tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]))
    return x
  elif mode == "constant":
    return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations: {dilations}")
  pixel_axes = tuple(range(len(X.shape)))[-2:]
  if ceil_mode: auto_pad = "SAME_UPPER"
  padding_included = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
  if count_include_pad:
    return padding_included
  else:
    div = _padding(Tensor.ones(*X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations).avg_pool2d(kernel_shape, stride=strides)
    return padding_included / div
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  if ceil_mode: auto_pad = "SAME_UPPER"
  ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
  ret_len, X_len = ret.numel(), X.numel()
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
  if storage_order: indices = indices.transpose(indices.ndim - 2, indices.ndim - 1)
  return ret, indices
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None):
  out_sh = [(ks // 2) * 2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  outlength = prod(out_sh)
  xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  if outshape is not None:
    outshape = safe_numpy(outshape).tolist()
    if outshape != ret.shape:
      diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
      pad_args = [diff[0] // 2, diff[1] // 2, diff[0] - diff[0] // 2, diff[1] - diff[1] // 2]
      ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  return ret
def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  if auto_pad != "NOTSET": padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  else: padding = [p for ps in zip(pads[:len(pads) // 2][::-1], pads[len(pads) // 2:][::-1]) for p in ps] if pads is not None else 0  # reorder padding
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  if not kernel_shape: kernel_shape = W.shape
  if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif pads is None and auto_pad == "NOTSET": pads = [0, 0] * (X.ndim - 2)
  strides_ = [1] * (W.ndim - 1) + [strides] if isinstance(strides, int) else [1] * (W.ndim - len(strides)) + list(strides)
  dilations_ = [1] * (W.ndim - 1) + [dilations] if isinstance(dilations, int) else [1] * (W.ndim - len(dilations)) + list(dilations)
  if output_shape and not output_padding:
    out_sh = [st * (xs - 1) + (ks - 1) * di + 1 if n < 2 else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))]
    output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
  if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio)  # ratio is sometimes passed in as a Tensor with shape ()
  if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
  if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool)  # if mask is requested as output it will contain all True's
  rng = np.random.RandomState(seed)
  ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
  mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
  return data * mask * (1 / (1.0 - ratio)), mask
def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
  bs, c, iy, ix = input.shape
  return input / input.mul(input).reshape(bs, 1, c, iy * ix).pad2d((0, 0, (size - 1) // 2, size // 2)).avg_pool2d((size, 1), 1).reshape(bs, c, iy, ix).mul(alpha).add(bias).pow(beta)
def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
  data_mean = input.mean(axis=axis, keepdim=True)
  std = ((input ** 2).mean(axis=axis, keepdim=True) - data_mean ** 2).sqrt()
  return (input - data_mean) / (std + 1e-9)
def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
  target = target.cast(dtypes.float32)
  N, C, i_shape = input.shape[0], input.shape[1], input.shape
  t_shape = target.shape
  if len(input.shape) != 3:
    input = input.reshape((N, C, -1))
    target = target.reshape((N, -1))
  if weight is not None:
    mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
    weight = (mask * weight).sum(axis=-1)
  if ignore_index is not None:
    cond = target == ignore_index
    weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1] * (len(input.shape) - 2))
  loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
  if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
  elif reduction == "sum": return loss.sum()
  return loss.reshape(t_shape) if len(i_shape) != 3 else loss
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
  N, C, *s_dimensions = scores.shape
  if ignore_index is not None: labels = (labels == ignore_index).where(C + 1, labels)
  mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1] * len(s_dimensions))
  y = scores.log_softmax(axis=1)
  if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)] * (weights.ndim - 1)]))
  loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
  if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
  elif reduction == "sum": loss = loss.sum()
  return loss, y
def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim - 1) else indices for i in range(input.ndim)]))
def Gather(input: Tensor, indices: Tensor, axis=0):
  if indices.numel() < 9:  # NOTE fewer kernels for smaller indices, but the kernel count grows with the size of indices
    input_sh = list(input.shape)
    ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis + 1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [int(safe_numpy(indices))] if indices.shape == () else [input_sh[axis] + int(x) if x < 0 else int(x) for x in safe_numpy(indices)]
    args = [[(0, x) if j != axis else (i, i + 1) for j, x in enumerate(input_sh)] for i in indices]
    return input.shrink(arg=tuple(args[0])).cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  else:  # NOTE faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
    return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))
def GatherElements(input: Tensor, indices: Tensor, axis):
  indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices
  return input.gather(indices, axis)
def _round(x: Tensor, n: float, equidistant_case="round_down") -> Tensor:
  def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
  assert n <= 1, f"n: {n} shouldn't be larger than 1"
  b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
  b = (b >= 0).where(b + n, b - n)
  if equidistant_case == "round_down":
    return (x > b).where(b + 1 - n, b - n)
  elif equidistant_case == "round_up":
    return (x >= b).where(b + 1 - n, b - n)
  elif equidistant_case == "round_to_even":
    x_ceil_fraction = x.ceil() / 2
    cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
    x = (_and(x == b, cond_ceil_even)).where(x + 1 - n, x)
    x = (x > b).where(b + 1 - n, b - n)
    return x
def Round(X: Tensor): return _round(X, 0.5, "round_to_even")
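# e.g. Round(Tensor([0.5, 1.5, 2.5, 2.4])) -> [0., 2., 2., 2.]: exact halves round to the nearest
# even integer (banker's rounding), which is what the ONNX Round spec requires.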
def Resize(X: Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
  def _nearest_gather(X: Tensor, x_out, y_out): return X[:, :, y_out, :][:, :, :, x_out]
  def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
    if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
    elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
    elif nearest_mode == "floor": ret = x_resized.floor()
    elif nearest_mode == "ceil": ret = x_resized.ceil()
    return ret.clip(0, x_len - 1)
  def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
    if coordinate_transformation_mode == "half_pixel":
      x_out = (x_out + 0.5) / Tensor(scales_lol[-1]) - 0.5  # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy
      y_out = (y_out + 0.5) / Tensor(scales_lol[-2]) - 0.5
    elif coordinate_transformation_mode == "align_corners":
      x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
      y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
    elif coordinate_transformation_mode == "asymmetric":
      x_out = x_out / scales_lol[-1]
      y_out = y_out / scales_lol[-2]
    elif coordinate_transformation_mode == "half_pixel_symmetric":
      x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
      y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
    elif coordinate_transformation_mode == "pytorch_half_pixel":
      x_out = (x_out + 0.5) / scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
      y_out = (y_out + 0.5) / scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
    elif coordinate_transformation_mode == "tf_crop_and_resize":
      x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
      y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
    return x_out.clip(0, X.shape[-1] - 1), y_out.clip(0, X.shape[-2] - 1)
  if roi is not None:
    roi = safe_numpy(roi)
    roi = [(st, ed) for st, ed in zip(roi[:len(roi) // 2], roi[len(roi) // 2:])]
    roi_ = [(1, 1)] * 4
    if axes is not None:
      for a, r in zip(axes, roi):
        roi_[a] = r
      roi = roi_
  if scales is not None:
    scales = safe_numpy(scales).tolist()
    if axes is not None:
      scales_ = [1] * X.ndim
      for a, s in zip(axes, scales):
        scales_[a] = s
      scales = scales_
  elif sizes is not None:
    sizes = [int(i) for i in safe_numpy(sizes)]
    scales = []
    if axes is not None:
      sizes_ = [1] * X.ndim
      for a, s in zip(axes, sizes):
        sizes_[a] = s
        scales.append(s / X.shape[a])
      sizes = sizes_
    else: scales = [si / xs for xs, si in zip(X.shape, sizes)]
    if keep_aspect_ratio_policy == "not_larger":
      scale = min(scales)
      sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
    elif keep_aspect_ratio_policy == "not_smaller":
      scale = max(scales)
      sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
  output_shape = sizes if sizes else [math.floor(x * s) for x, s in zip(X.shape, scales)]
  output_shape_ = sizes if sizes else [x * s for x, s in zip(X.shape, scales)]
  scales_lol = [os / xs for xs, os in zip(X.shape, output_shape)]
  x_out = Tensor.arange(output_shape[-1])
  y_out = Tensor.arange(output_shape[-2])
  if mode == "nearest":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
    x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
    y_out = _nearest_mode(y_out, nearest_mode, X.shape[-2])  # clip y indices against the height, not the width
    return _nearest_gather(X, x_out, y_out)
  elif mode == "linear":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
    ret = []
    for y in safe_numpy(y_out):
      for x in safe_numpy(x_out):
        x_floor, y_floor = int(x), int(y)
        y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor + 2) if y != y_floor else (y_floor, y_floor + 1)
        x_shrink = (x_floor, x_floor + 2) if x != x_floor else (x_floor, x_floor + 1)
        shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
        corners = safe_numpy(X.shrink(shrink_args))
        x1, x2, y1, y2 = x_floor, x_floor + 1, y_floor, y_floor + 1
        if x == x_floor and y == y_floor:  # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
          ret.append(corners[0, 0, 0, 0])
        elif x == x_floor:
          ret.append((corners[0, 0, 0, 0] * (y2 - y) + corners[0, 0, 1, 0] * (y - y1)) / (y2 - y1))
        elif y == y_floor:
          ret.append((corners[0, 0, 0, 0] * (x2 - x) + corners[0, 0, 0, 1] * (x - x1)) / (x2 - x1))
        else:
          ret.append((corners[0, 0, 0, 0] * (x2 - x) * (y2 - y) + corners[0, 0, 0, 1] * (x - x1) * (y2 - y) + corners[0, 0, 1, 0] * (x2 - x) * (y - y1) + corners[0, 0, 1, 1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1)))
    return Tensor(ret).reshape(output_shape)
  elif mode == "cubic":
    raise Exception("cubic interpolation is not implemented")
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
  if not axes: axes = list(range(input.ndim))
  shrink_arg = [(0, i) for i in input.shape]
  pad_arg = [(0, 0) for _ in range(input.ndim)]
  shape = safe_numpy(shape).tolist()
  for s, x in zip(shape, axes):
    if s < input.shape[x]: shrink_arg[x] = (input.shape[x] // 2 - s // 2, input.shape[x] // 2 + s // 2) if s % 2 == 0 else (input.shape[x] // 2 - s // 2 - 1, input.shape[x] // 2 + s // 2)
    elif s > input.shape[x]: pad_arg[x] = ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2) if (s - input.shape[x]) % 2 == 0 else ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2 + 1)
  return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
  depth = int(safe_numpy(depth).item())
  indices, rank = (indices < 0).where(indices + depth, indices), len(indices.shape)
  if axis < 0: axis += rank + 1
  ls, rs = indices.shape[0:axis], indices.shape[axis:rank]
  cond = indices[:, None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
  return cond.where(values[1], values[0]).cast(values.dtype)
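# e.g. OneHot(indices=Tensor([1, 3]), depth=Tensor([4]), values=Tensor([0., 9.])) returns the (2, 4)
# tensor [[0, 9, 0, 0], [0, 0, 0, 9]]: values[0] is the "off" value and values[1] the "on" value.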
def Erf(x: Tensor):
  sign = x.sign()
  x = x.abs()
  t = 1.0 / (1.0 + 0.3275911 * x)
  term1 = 0.254829592 * t
  term2 = -0.284496736 * t ** 2
  term3 = 1.421413741 * t ** 3
  term4 = -1.453152027 * t ** 4
  term5 = 1.061405429 * t ** 5
  y = (term1 + term2 + term3 + term4 + term5)
  return sign * (1.0 - y * Tensor.exp(-x * x))
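# The constants above are the Abramowitz & Stegun 7.1.26 rational approximation of erf
# (max absolute error around 1.5e-7), evaluated on |x| and re-signed at the end.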
def Compress(inp: Tensor, condition: Tensor, axis=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0
  axis = axis + inp.ndim if axis < 0 else axis
  con_np = safe_numpy(condition)
  con = Tensor(np.arange(condition.shape[0])[con_np])  # no boolean indexing in Tensor
  return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
  if dtype is None: dtype = x.dtype
  else: dtype = type_map[dtype]
  shape = x.shape
  dim = min(x.shape)
  if shape[0] == shape[1]: return Tensor.eye(dim=dim, dtype=dtype)
  else:
    diff = (shape[0] - dim, shape[1] - dim)
    padarg = tuple([(d, d) if d == 0 else (k, d - k) for d in diff])
    return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x, detect_negative=1, detect_positive=1):
  ret = (x == float("inf")) * detect_positive + (x == float("-inf")) * detect_negative + Tensor.zeros(*x.shape)
  return ret.cast(dtypes.bool)
# Needs work
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
  axis = axis + x.ndim if axis < 0 else axis
  x_sc = x_scale.reshape(*[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim))
  x_zer = x_zero_point.reshape(*[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point
  return (x - x_zer) * x_sc
# Needs work
def IsNaN(x):
  return (x < float("-inf")).cast(dtypes.bool)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(input: Tensor, skip: Tensor, gamma, beta: Optional[Tensor]=None, bias: Optional[Tensor]=None, epsilon=None):
  if epsilon is None: epsilon = 1e-12
  x = input + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
def FastGelu(x: Tensor, bias: Optional[Tensor]=None):
  x = x + bias
  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
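# This is the tanh-based GELU approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))),
# with sqrt(2/pi) ~= 0.797885 and 0.044715*sqrt(2/pi) ~= 0.035677 folded into the two constants.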
def EmbedLayerNormalization(input_ids: Tensor, segment_ids: Optional[Tensor]=None, word_embedding: Tensor=None, position_embedding: Tensor=None, segment_embedding: Optional[Tensor]=None, gamma=None, beta=None, mask: Optional[Tensor]=None, position_ids: Optional[Tensor]=None, epsilon=None, mask_index_type=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert (mask is None) is (mask_index_type is None)
  assert mask is None, "functionality not supported yet"  # TODO
  input_shape = input_ids.shape
  bsz, seq_length = input_shape[0], input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
  def embedding(x: Tensor, vocab_size, weight: Tensor) -> Tensor:  # TODO from nn.Embedding. Could probably upstream this to Tensor
    vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
    return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
  # bert embedding layer
  if epsilon is None: epsilon = 1e-12
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None
  embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum
def Attention(input: Tensor, weights, bias: Optional[Tensor]=None, mask_index: Optional[Tensor]=None, past: Optional[Tensor]=None, relative_position_bias: Optional[Tensor]=None, past_sequence_length: Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None  # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  assert relative_position_bias == do_rotary == past_sequence_length == mask_filter_value == past_present_share_buffer == scale == None, "functionality not supported yet"  # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2 * (weights.shape[1] // 3,)
  if unidirectional:  # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = input.linear(weights, bias)
    xq, xk, xv = [xqkv.slice([None, None, (i * hidden_size, (i + 1) * hidden_size)]) for i in range(3)]
  else:  # bert-style
    wq, wk, wv = weights[:, :hidden_size], weights[:, hidden_size:hidden_size + v_hidden_size], weights[:, hidden_size + v_hidden_size:]
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size + v_hidden_size], bias[hidden_size + v_hidden_size:]) if bias is not None else None
    xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
  present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length:key_length, :key_length].cast(dtypes.bool)
    return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value
  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present
# **************** ai.onnx.preview.training Ops ****************
# TODO not entirely sure these optimizers are correct
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  r = R / (1 + T * decay_factor)
  ret = []
  for input in grouped_inputs:
    X, G, H = input
    X.grad = norm_coefficient * X + G
    X.grad.requires_grad, H.requires_grad = False, False  # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
    H.assign(H.detach() + X.grad * X.grad).realize()
    H_adaptive = H.sqrt() + epsilon
    X.assign(X.detach() - r * X.grad / H_adaptive)
    ret.extend([X, H])
  ret = ret[::2] + ret[1::2]
  return tuple(ret)
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  beta_adjusted = beta if T > 0 else 1
  ret = []
  for input in grouped_inputs:
    X, G, V = input
    X.grad = (norm_coefficient * X + G).realize()
    X.grad.requires_grad, V.requires_grad = False, False
    V.assign(alpha * V + beta_adjusted * X.grad).realize()
    if mode == "standard": X.assign(X.detach() - R * V).realize()
    elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha * V)).realize()  # ONNX nesterov update: X -= R * (grad + alpha * V)
    ret.extend([X, V])
  ret = ret[::2] + ret[1::2]
  return tuple(ret)
# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
  groups = len(inputs) // 4
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  ret = []
  for input in grouped_inputs:
    X, G, V, H = input
    X.grad = (norm_coefficient * X + G).realize()
    V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
    V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
    H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
    up = (V / (1.0 - alpha ** T)) / ((H / (1.0 - beta ** T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon)
    X.assign(X.detach() - R * up).realize()
    X = (1 - norm_coefficient_post) * X
    ret.extend([X, V, H])
  ret = ret[::3] + ret[1::3] + ret[2::3]
  return tuple(ret)