fix onnx tests

This commit is contained in:
George Hotz 2023-02-24 09:27:18 -08:00
parent e263c0c628
commit 3becefa218
1 changed files with 2 additions and 1 deletions

View File

@ -19,7 +19,8 @@ def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
return ret
# TODO: this is copied from tinygrad/nn/__init__.py
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0):
# spatial is from opset 7 and has since been removed
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))