# tinygrad/test/external/external_test_onnx_backend.py
import unittest
from typing import Any, Tuple
from onnx.backend.base import Backend, BackendRep
import onnx.backend.test
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, CI
from tinygrad.device import Device, Compiled
# pip3 install tabulate
pytest_plugins = 'onnx.backend.test.report',
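# (the report plugin, as far as I can tell, tallies per-op pass/fail across the
# suite and renders a coverage table via tabulate when run under pytest)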
from extra.onnx import get_run_onnx

class TinygradModel(BackendRep):
  def __init__(self, run_onnx, input_names):
    super().__init__()
    self.fxn = run_onnx
    self.input_names = input_names

  def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
    real_inputs = {k:v for k,v in zip(self.input_names, inputs)}
    ret = self.fxn(real_inputs, debug=True)
    return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values())
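# minimal usage sketch (illustrative only, nothing here executes it): the harness
# passes inputs positionally, run() zips them onto the graph's input names and
# hands back numpy outputs, e.g.
#   rep = TinygradModel(run_onnx, ["x"])
#   (y,) = rep.run([np.zeros((2, 3), dtype=np.float32)])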

class TinygradBackend(Backend):
  @classmethod
  def prepare(cls, model, device):
    input_all = [x.name for x in model.graph.input]
    input_initializer = [x.name for x in model.graph.initializer]
    net_feed_input = [x for x in input_all if x not in input_initializer]
    print("prepare", cls, device, net_feed_input)
    run_onnx = get_run_onnx(model)
    return TinygradModel(run_onnx, net_feed_input)

  @classmethod
  def supports_device(cls, device: str) -> bool:
    return device == "CPU"

backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
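# BackendTest generates one unittest case per ONNX node/model test, each driven
# through TinygradBackend.prepare(). include()/exclude() take regex patterns
# (matched with re.search in onnx's runner, I believe), so patterns like
# 'test_reduce_prod_*' below act as plain substring matches.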
# no support for reduce with multiply (needs llop)
backend_test.exclude('test_reduce_prod_*')
# TODO figure out why it's returning wrong values, geohotstan's uneducated guess is it's due to imprecision from float64 (double) -> float32
# see Type Constraints: https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adam.html#type-constraints
backend_test.exclude('test_adam_multiple_cpu')
backend_test.exclude('test_nesterov_momentum_cpu')
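# (for intuition: np.float32(0.1) is ~0.10000000149, so reference values computed
# in float64 can drift from float32 results as optimizer steps accumulate)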
# about different dtypes
backend_test.exclude('int8') # OverflowError: cannot convert float infinity to integer
if Device.DEFAULT in ["TORCH"]:
  backend_test.exclude('uint16')
  backend_test.exclude('uint32')
  backend_test.exclude('uint64')
if Device.DEFAULT in ["METAL"]:
  backend_test.exclude('float64')
  backend_test.exclude('string')
backend_test.exclude('test_pow_types_int*')
backend_test.exclude('test_cast_*')
backend_test.exclude('test_castlike_*')
backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_matmulinteger_*')
backend_test.exclude('test_reduce_log_sum_exp*') # dependent on actual float64 implementation for backends
backend_test.exclude('test_operator_add*') # dependent on float64 math. Without it values default to 0 or inf
# we don't support indexes
# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index
# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index
backend_test.exclude('test_nonzero_*')
# no support for mod
backend_test.exclude('test_mod_*')
# no boolean ops (2d, 3d, 4d)
backend_test.exclude('test_bitshift_*')
# no scatternd gathernd
backend_test.exclude('test_gathernd_*')
backend_test.exclude('test_scatternd_*')
# no quantize
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_qlinearmatmul_*')
backend_test.exclude('test_qlinearconv_*')
backend_test.exclude('test_quantizelinear_*')
# no rnn
backend_test.exclude('test_gru_*')
backend_test.exclude('test_rnn_*')
backend_test.exclude('test_lstm_*')
backend_test.exclude('test_simple_rnn_*')
# no control flow
# control flow uses AttributeProto.GRAPH
backend_test.exclude('test_if_*')
backend_test.exclude('test_loop*')
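# e.g. an If node carries its branches as GRAPH-typed attributes (then_branch /
# else_branch) and Loop carries a body subgraph; supporting these would mean
# recursively building a run_onnx for each subgraph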
backend_test.exclude('test_range_float_type_positive_delta_expanded_cpu') # requires loop
backend_test.exclude('test_affine_grid_2d_align_corners_expanded_cpu')
backend_test.exclude('test_affine_grid_2d_expanded_cpu')
backend_test.exclude('test_affine_grid_3d_align_corners_expanded_cpu')
backend_test.exclude('test_affine_grid_3d_expanded_cpu')
backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
# unsupported (strange) ops
backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_cumsum_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
backend_test.exclude('test_col2im_*')
backend_test.exclude('test_hammingwindow_*')
backend_test.exclude('test_hannwindow_*')
backend_test.exclude('test_hardmax_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_*')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
backend_test.exclude('test_reversesequence_*')
backend_test.exclude('test_roialign_*')
backend_test.exclude('test_top_k_*')
backend_test.exclude('test_tfidfvectorizer_*')
backend_test.exclude('test_stft_*')
backend_test.exclude('test_melweightmatrix_*')
# more strange ops
backend_test.exclude('test_basic_deform_conv_*')
backend_test.exclude('test_deform_conv_*')
backend_test.exclude('test_lppool_*')
backend_test.exclude('test_depthtospace_*')
backend_test.exclude('test_spacetodepth_*')
backend_test.exclude('test_scan*')
backend_test.exclude('test_split_to_sequence_*')
backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic
backend_test.exclude('test_resize_downsample_sizes_cubic_*') # unsure how to implement cubic
backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to implement cubic
backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic
# rest of the failing tests
backend_test.exclude('test_regex_*') # does not support string Tensors
backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0, also allowzero
backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
# issue 1556 https://github.com/tinygrad/tinygrad/issues/1556
backend_test.exclude('test_isinf_cpu')
backend_test.exclude('test_isinf_negative_cpu')
backend_test.exclude('test_isinf_positive_cpu')
backend_test.exclude('test_isinf_float16_cpu')
backend_test.exclude('test_isnan_float16_cpu')
backend_test.exclude('test_isnan_cpu')
# issue 1791 fast math messes with these https://github.com/tinygrad/tinygrad/issues/1791
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_cpu')
# issue 2067 potentially also a fastmath issue https://github.com/tinygrad/tinygrad/issues/2067
if Device.DEFAULT in ['METAL']:
  backend_test.exclude('test_maxpool_2d_pads_cpu')
  backend_test.exclude('test_maxpool_2d_same_lower_cpu')
if Device.DEFAULT in ['GPU', 'METAL']:
  backend_test.exclude('test_mish_cpu') # weird inaccuracy
  backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy
  backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double
# Segfaults in CI, GPU requires cl_khr_fp16
if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
  backend_test.exclude('test_max_float16_cpu')
  backend_test.exclude('test_min_float16_cpu')
  # error: casting to type 'half' is not allowed
  backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
# TODO: this somehow passes in CI but does not pass if run locally
if isinstance(Device[Device.DEFAULT], Compiled):
  backend_test.exclude('test_MaxPool3d_stride_padding_cpu')
# disable model tests for now since they are slow
if not getenv("MODELTESTS"):
  for x in backend_test.test_suite:
    if 'OnnxBackendRealModelTest' in str(type(x)):
      backend_test.exclude(str(x).split(" ")[0])
else:
  # model tests all pass!
  backend_test.include('test_resnet50')
  backend_test.include('test_inception_v1')
  backend_test.include('test_inception_v2')
  backend_test.include('test_densenet121')
  backend_test.include('test_shufflenet')
  backend_test.include('test_squeezenet')
  backend_test.include('test_bvlc_alexnet')
  backend_test.include('test_zfnet512')
  backend_test.include('test_vgg19')
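# e.g. MODELTESTS=1 python -m pytest test/external/external_test_onnx_backend.py
# (once include() is used, onnx's runner skips anything not matching an include
# pattern, so this runs just the model tests listed above)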
globals().update(backend_test.enable_report().test_cases)
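# test_cases is a dict of {name: unittest.TestCase subclass}; pushing it into
# globals() is what lets unittest/pytest discover the generated tests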
if __name__ == '__main__':
  unittest.main()