mirror of https://github.com/commaai/tinygrad.git
fix onnx tests
This commit is contained in:
parent
e263c0c628
commit
3becefa218
|
@ -19,7 +19,8 @@ def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
|
|||
return ret
|
||||
|
||||
# TODO: this is copied from tinygrad/nn/__init__.py
|
||||
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0):
|
||||
# spatial is from opset 7 and has since been removed
|
||||
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1):
|
||||
if training_mode:
|
||||
x_detached = X.detach()
|
||||
current_mean = x_detached.mean(axis=(0,2,3))
|
||||
|
|
Loading…
Reference in New Issue