From c704a77ca0cede1ba1c589ebf84198e4ae35c1c5 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 6 Dec 2023 11:15:46 -0500
Subject: [PATCH] green dtypes ALU tests (#2617)

* dtypes alu test

* those types don't exist in torch

* floats

* more tests

* disable those

* a couple unary tests

* skip float16 tests in CI for GPU

* fix LLVM bool add True+True=1+1=2 which truncates to False in native LLVM

* remove hardcoded float for LLVM ALU fns

* less sensitive atol for fp32, 1e-10 is flaky and sometimes failed even if you revert the merge commit

for non-fp32 math, nothing has changed in our kernels for fp32.

* return on overflows

* fix CUDA exp2

* compute results of op regardless of bounds in a python backend

* skip fp16 in GPU and CUDACPU

* fuzz a smaller range in the float_midcast_int32 test

I sampled this and we overflow ~70% of the time.

because numpy behaves differently on different devices for overflows and Metal seems to do the same, I'm opting to eliminate the non-determinism here

* remove CUDA exp2 overload

it's already there now

---------

Co-authored-by: George Hotz
---
 .gitignore                  |   1 +
 setup.py                    |   1 +
 test/test_dtype_alu.py      | 103 ++++++++++++++++++++++++++++++++++++
 tinygrad/renderer/llvmir.py |  10 ++--
 4 files changed, 110 insertions(+), 5 deletions(-)
 create mode 100644 test/test_dtype_alu.py

diff --git a/.gitignore b/.gitignore
index 3e7c1583..08a5b113 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,3 +47,4 @@ outputs_yolov8
 wandb
 model.safetensors
 quickstart.py
+.hypothesis
diff --git a/setup.py b/setup.py
index 5be6ae8b..a37fbdd1 100644
--- a/setup.py
+++ b/setup.py
@@ -51,6 +51,7 @@ setup(name='tinygrad',
           "tiktoken",
           "librosa",
           "networkx",
+          "hypothesis",
         ]
       },
       include_package_data=True)
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
new file mode 100644
index 00000000..a3cdcb28
--- /dev/null
+++ b/test/test_dtype_alu.py
@@ -0,0 +1,103 @@
+import unittest
+from tinygrad import Tensor, dtypes, Device
+import operator
+import numpy as np
+from hypothesis import given, strategies as st, settings
+
+from tinygrad.helpers import CI, getenv
+
+settings.register_profile("my_profile", max_examples=200, deadline=None)
+settings.load_profile("my_profile")
+print(settings.default)
+
+def skipUnlessFP16Supported(): return unittest.skip("GPU requires cl_khr_fp16") if Device.DEFAULT == "GPU" and CI else unittest.skip("CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs") if getenv("CUDACPU") else lambda _x: None
+
+
+dtypes_float = (dtypes.float32, dtypes.float16)
+dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
+dtypes_bool = (dtypes.bool,)
+# TODO: truediv is broken
+# TODO: lt and eq should cast in tensor
+binary_operations = ((operator.add, operator.add), (operator.sub, operator.sub), (operator.mul, operator.mul)) #, operator.lt, operator.eq) #, operator.truediv)
+unary_operations = ((Tensor.exp, np.exp), (Tensor.log, np.log), (operator.neg, operator.neg))
+
+def universal_test(a, b, dtype, op):
+  tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
+  numpy_value = op[1](np.array([a]).astype(dtype.np), np.array([b]).astype(dtype.np))
+  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
+  else: np.testing.assert_equal(tensor_value, numpy_value)
+
+def universal_test_unary(a, dtype, op):
+  tensor_value = op[0](Tensor([a], dtype=dtype)).numpy()
+  numpy_value = op[1](np.array([a]).astype(dtype.np))
+  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-7, rtol=1e-5 if dtype == dtypes.float32 else 1e-2) # exp and log are approximations
+  else: np.testing.assert_equal(tensor_value, numpy_value)
+
+class TestDTypeALU(unittest.TestCase):
+  @given(st.floats(width=32, allow_subnormal=False), st.floats(width=32, allow_subnormal=False), st.sampled_from(binary_operations))
+  def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)
+
+  @skipUnlessFP16Supported()
+  @given(st.floats(width=16, allow_subnormal=False), st.floats(width=16, allow_subnormal=False), st.sampled_from(binary_operations))
+  def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)
+
+  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(unary_operations))
+  def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
+
+  @skipUnlessFP16Supported()
+  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(unary_operations))
+  def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
+
+  @given(st.integers(0, 255), st.integers(0, 255), st.sampled_from(binary_operations))
+  def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
+
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint16 in torch")
+  @given(st.integers(0, 65535), st.integers(0, 65535), st.sampled_from(binary_operations))
+  def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
+
+  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
+  @given(st.integers(0, 4294967295), st.integers(0, 4294967295), st.sampled_from(binary_operations))
+  def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
+
+  @given(st.integers(-128, 127), st.integers(-128, 127), st.sampled_from(binary_operations))
+  def test_int8(self, a, b, op): universal_test(a, b, dtypes.int8, op)
+
+  @given(st.integers(-32768, 32767), st.integers(-32768, 32767), st.sampled_from(binary_operations))
+  def test_int16(self, a, b, op): universal_test(a, b, dtypes.int16, op)
+
+  @given(st.integers(-2147483648, 2147483647), st.integers(-2147483648, 2147483647), st.sampled_from(binary_operations))
+  def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
+
+  @given(st.booleans(), st.booleans(), st.sampled_from(((operator.add, operator.add), (operator.mul, operator.mul))))
+  def test_bool(self, a, b, op): universal_test(a, b, dtypes.bool, op)
+
+  @given(st.integers(-2147483648, 2147483647), st.integers(-2147483648, 2147483647), st.floats(width=32, allow_subnormal=False), st.sampled_from(binary_operations), st.sampled_from(binary_operations))
+  def test_int32_midcast_float(self, a, b, c, op1, op2):
+    at, bt, ct = Tensor([a], dtype=dtypes.int32), Tensor([b], dtype=dtypes.int32), Tensor([c], dtype=dtypes.float32)
+    an, bn, cn = np.array([a]).astype(np.int32), np.array([b]).astype(np.int32), np.array([c]).astype(np.float32)
+    tensor_value = op2[0](op1[0](at, bt).cast(dtypes.float32), ct).numpy()
+    numpy_value = op2[1](op1[1](an, bn).astype(np.float32), cn)
+    np.testing.assert_almost_equal(tensor_value, numpy_value)
+
+  @given(st.floats(width=32, allow_subnormal=False, min_value=0, max_value=10.0), st.floats(width=32, allow_subnormal=False, min_value=0, max_value=10.0), st.integers(-10, 10), st.sampled_from(binary_operations), st.sampled_from(binary_operations))
+  def test_float_midcast_int32(self, a, b, c, op1, op2):
+    at, bt, ct = Tensor([a], dtype=dtypes.float32), Tensor([b], dtype=dtypes.float32), Tensor([c], dtype=dtypes.int32)
+    an, bn, cn = np.array([a]).astype(np.float32), np.array([b]).astype(np.float32), np.array([c]).astype(np.int32)
+    tensor_value = op2[0](op1[0](at, bt).cast(dtypes.int32), ct).numpy()
+    numpy_value = op2[1](op1[1](an, bn).astype(np.int32), cn)
+    np.testing.assert_equal(tensor_value, numpy_value)
+
+  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
+  def test_float_cast(self, a, dtype):
+    tensor_value = Tensor([a], dtype=dtypes.float32).cast(dtype)
+    numpy_value = np.array([a]).astype(dtype.np)
+    np.testing.assert_equal(tensor_value, numpy_value)
+
+  @given(st.integers(-2147483648, 2147483647), st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
+  def test_int32_cast(self, a, dtype):
+    tensor_value = Tensor([a], dtype=dtypes.int32).cast(dtype)
+    numpy_value = np.array([a]).astype(dtype.np)
+    np.testing.assert_equal(tensor_value, numpy_value)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 0d543d47..80e2f4ae 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -9,11 +9,11 @@ LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from
 def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  BinaryOps.ADD: lambda builder,x,y: builder.or_(x, y) if is_bool(x.type) else builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
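
A quick illustration, not part of the patch, of the llvmir.py change: LLVM models bool as a 1-bit integer, so lowering boolean ADD to an integer add wraps modulo 2 and turns True+True into False, while numpy treats boolean addition as logical OR; that mismatch is what the new test_bool case exposes, and the hunk above switches the bool path of BinaryOps.ADD to builder.or_. The same hunk also declares the exp2/log2/sin/sqrt intrinsics for the operand's own type ([x.type]) instead of a hard-coded ir.FloatType(). The sketch below emulates both lowerings in plain Python; the helper names i1_add and i1_or are illustrative only.

# sketch of the old vs. new bool-ADD lowering, emulated in plain Python
import numpy as np

def i1_add(x: bool, y: bool) -> bool:
  # old lowering: 1-bit integer add, result truncated modulo 2
  return bool((int(x) + int(y)) & 1)

def i1_or(x: bool, y: bool) -> bool:
  # new lowering: logical or, matching numpy's boolean addition
  return x or y

np.testing.assert_equal(np.array([True]) + np.array([True]), np.array([True]))  # numpy reference
assert i1_or(True, True) is True    # matches numpy
assert i1_add(True, True) is False  # 1 + 1 = 2 -> 0 in one bit, the bug test_bool exposed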