green dtypes ALU tests (#2617)

* dtypes alu test

* those types don't exist in torch

* floats

* more tests

* disable those

* a couple unary tests

* skip float16 tests in CI for GPU

* fix LLVM bool add: True+True=1+1=2, which truncates to False in native LLVM (a sketch follows this commit message)

* remove hardcoded float for LLVM ALU fns

* less strict atol for fp32; 1e-10 is flaky and sometimes failed even after reverting the merge commit for non-fp32 math, and nothing has changed in our kernels for fp32

* return on overflows

* fix CUDA exp2

* compute results of op regardless of bounds in a python backend

* skip fp16 in GPU and CUDACPU

* fuzz a smaller range in the float_midcast_int32 test

I sampled this and we overflow ~70% of the time.
Because numpy handles overflows differently on different devices, and Metal seems to do the same, I'm opting to eliminate the non-determinism here (a numpy sketch follows this commit message).

* remove the CUDA exp2 overload, it's already there now

---------

Co-authored-by: George Hotz <geohot@gmail.com>
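
To illustrate the bool add fix above: a minimal llvmlite sketch of my own (not part of this diff) comparing the two lowerings. A plain add on i1 computes 1+1 = 0b10, which truncates to 0 (False), while or_ returns 1 and matches boolean semantics.

from llvmlite import ir

i1 = ir.IntType(1)
module = ir.Module(name="bool_add_demo")
fnty = ir.FunctionType(i1, (i1, i1))

for name, emit in (("add_i1", lambda b, x, y: b.add(x, y)),   # 1+1 wraps to 0 in one bit
                   ("or_i1",  lambda b, x, y: b.or_(x, y))):  # 1|1 stays 1, the bool ADD fix
  fn = ir.Function(module, fnty, name=name)
  builder = ir.IRBuilder(fn.append_basic_block("entry"))
  builder.ret(emit(builder, *fn.args))

print(module)  # the printed IR shows "add i1" in one function and "or i1" in the other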
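
And a rough numpy sketch of my own (also not from the diff) of the non-determinism behind the narrowed float_midcast_int32 range: casting a float that is outside the int32 range is not well defined, so the numpy reference value itself can differ between platforms and backends.

import numpy as np

big = np.float32(4e9)                    # larger than int32 max (2147483647)
print(big.astype(np.int32))              # unspecified result; varies by platform and may warn

print(np.float32(7.0).astype(np.int32))  # values inside the 0..10 fuzz window cast exactly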
qazal authored 2023-12-06 11:15:46 -05:00, committed by GitHub
commit c704a77ca0 (parent 71d989b476)
4 changed files with 110 additions and 5 deletions

.gitignore (1 addition)

@@ -47,3 +47,4 @@ outputs_yolov8
 wandb
 model.safetensors
 quickstart.py
+.hypothesis

setup.py (1 addition)

@@ -51,6 +51,7 @@ setup(name='tinygrad',
         "tiktoken",
         "librosa",
         "networkx",
+        "hypothesis",
       ]
     },
     include_package_data=True)

test/test_dtype_alu.py (new file, 103 additions)

@@ -0,0 +1,103 @@
import unittest
from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as st, settings
from tinygrad.helpers import CI, getenv

settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")
print(settings.default)

def skipUnlessFP16Supported(): return unittest.skip("GPU requires cl_khr_fp16") if Device.DEFAULT == "GPU" and CI else unittest.skip("CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs") if getenv("CUDACPU") else lambda _x: None

dtypes_float = (dtypes.float32, dtypes.float16)
dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
dtypes_bool = (dtypes.bool,)

# TODO: truediv is broken
# TODO: lt and eq should cast in tensor
binary_operations = ((operator.add, operator.add), (operator.sub, operator.sub), (operator.mul, operator.mul)) #, operator.lt, operator.eq) #, operator.truediv)
unary_operations = ((Tensor.exp, np.exp), (Tensor.log, np.log), (operator.neg, operator.neg))

def universal_test(a, b, dtype, op):
  tensor_value = (op[0](Tensor([a], dtype=dtype), Tensor([b], dtype=dtype))).numpy()
  numpy_value = op[1](np.array([a]).astype(dtype.np), np.array([b]).astype(dtype.np))
  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-10)
  else: np.testing.assert_equal(tensor_value, numpy_value)

def universal_test_unary(a, dtype, op):
  tensor_value = op[0](Tensor([a], dtype=dtype)).numpy()
  numpy_value = op[1](np.array([a]).astype(dtype.np))
  if dtype in dtypes_float: np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-7, rtol=1e-5 if dtype == dtypes.float32 else 1e-2)  # exp and log are approximations
  else: np.testing.assert_equal(tensor_value, numpy_value)

class TestDTypeALU(unittest.TestCase):
  @given(st.floats(width=32, allow_subnormal=False), st.floats(width=32, allow_subnormal=False), st.sampled_from(binary_operations))
  def test_float32(self, a, b, op): universal_test(a, b, dtypes.float32, op)

  @skipUnlessFP16Supported()
  @given(st.floats(width=16, allow_subnormal=False), st.floats(width=16, allow_subnormal=False), st.sampled_from(binary_operations))
  def test_float16(self, a, b, op): universal_test(a, b, dtypes.float16, op)

  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(unary_operations))
  def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)

  @skipUnlessFP16Supported()
  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(unary_operations))
  def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)

  @given(st.integers(0, 255), st.integers(0, 255), st.sampled_from(binary_operations))
  def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint16 in torch")
  @given(st.integers(0, 65535), st.integers(0, 65535), st.sampled_from(binary_operations))
  def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)

  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
  @given(st.integers(0, 4294967295), st.integers(0, 4294967295), st.sampled_from(binary_operations))
  def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)

  @given(st.integers(-128, 127), st.integers(-128, 127), st.sampled_from(binary_operations))
  def test_int8(self, a, b, op): universal_test(a, b, dtypes.int8, op)

  @given(st.integers(-32768, 32767), st.integers(-32768, 32767), st.sampled_from(binary_operations))
  def test_int16(self, a, b, op): universal_test(a, b, dtypes.int16, op)

  @given(st.integers(-2147483648, 2147483647), st.integers(-2147483648, 2147483647), st.sampled_from(binary_operations))
  def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)

  @given(st.booleans(), st.booleans(), st.sampled_from(((operator.add, operator.add), (operator.mul, operator.mul))))
  def test_bool(self, a, b, op): universal_test(a, b, dtypes.bool, op)

  @given(st.integers(-2147483648, 2147483647), st.integers(-2147483648, 2147483647), st.floats(width=32, allow_subnormal=False), st.sampled_from(binary_operations), st.sampled_from(binary_operations))
  def test_int32_midcast_float(self, a, b, c, op1, op2):
    at, bt, ct = Tensor([a], dtype=dtypes.int32), Tensor([b], dtype=dtypes.int32), Tensor([c], dtype=dtypes.float32)
    an, bn, cn = np.array([a]).astype(np.int32), np.array([b]).astype(np.int32), np.array([c]).astype(np.float32)
    tensor_value = op2[0](op1[0](at, bt).cast(dtypes.float32), ct).numpy()
    numpy_value = op2[1](op1[1](an, bn).astype(np.float32), cn)
    np.testing.assert_almost_equal(tensor_value, numpy_value)

  @given(st.floats(width=32, allow_subnormal=False, min_value=0, max_value=10.0), st.floats(width=32, allow_subnormal=False, min_value=0, max_value=10.0), st.integers(-10, 10), st.sampled_from(binary_operations), st.sampled_from(binary_operations))
  def test_float_midcast_int32(self, a, b, c, op1, op2):
    at, bt, ct = Tensor([a], dtype=dtypes.float32), Tensor([b], dtype=dtypes.float32), Tensor([c], dtype=dtypes.int32)
    an, bn, cn = np.array([a]).astype(np.float32), np.array([b]).astype(np.float32), np.array([c]).astype(np.int32)
    tensor_value = op2[0](op1[0](at, bt).cast(dtypes.int32), ct).numpy()
    numpy_value = op2[1](op1[1](an, bn).astype(np.int32), cn)
    np.testing.assert_equal(tensor_value, numpy_value)

  @given(st.floats(width=32, allow_subnormal=False), st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
  def test_float_cast(self, a, dtype):
    tensor_value = Tensor([a], dtype=dtypes.float32).cast(dtype)
    numpy_value = np.array([a]).astype(dtype.np)
    np.testing.assert_equal(tensor_value, numpy_value)

  @given(st.integers(-2147483648, 2147483647), st.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
  def test_int32_cast(self, a, dtype):
    tensor_value = Tensor([a], dtype=dtypes.int32).cast(dtype)
    numpy_value = np.array([a]).astype(dtype.np)
    np.testing.assert_equal(tensor_value, numpy_value)

if __name__ == '__main__':
  unittest.main()
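
To make the helpers concrete, here is a hand-written sketch of my own of what a single Hypothesis-generated example boils down to: each op is a pair of (tinygrad op, numpy op), compared exactly for integer dtypes and with allclose for float dtypes.

import numpy as np
from tinygrad import Tensor, dtypes

# what universal_test(3, 5, dtypes.int32, (operator.add, operator.add)) checks
tensor_value = (Tensor([3], dtype=dtypes.int32) + Tensor([5], dtype=dtypes.int32)).numpy()
numpy_value = np.array([3], dtype=np.int32) + np.array([5], dtype=np.int32)
np.testing.assert_equal(tensor_value, numpy_value)            # ints: exact match

# what universal_test_unary(2.0, dtypes.float32, (Tensor.exp, np.exp)) checks
np.testing.assert_allclose(Tensor([2.0], dtype=dtypes.float32).exp().numpy(),
                           np.exp(np.array([2.0], dtype=np.float32)),
                           atol=1e-7, rtol=1e-5)              # floats: approximate match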

tinygrad/renderer/llvmir.py (5 additions, 5 deletions)

@@ -9,11 +9,11 @@ LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from
 def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
-  BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
+  BinaryOps.ADD: lambda builder,x,y: builder.or_(x, y) if is_bool(x.type) else builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
   BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
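
A side note on the x.type change: a small llvmlite sketch of my own (not part of the diff) showing that the intrinsic overload name is mangled from the type you pass in, so hardcoding ir.FloatType() always selected the f32 variant regardless of the operand's dtype.

from llvmlite import ir

m = ir.Module(name="intrinsic_overloads")
exp2_f32 = m.declare_intrinsic('llvm.exp2', [ir.FloatType()])   # resolves to llvm.exp2.f32
exp2_f64 = m.declare_intrinsic('llvm.exp2', [ir.DoubleType()])  # resolves to llvm.exp2.f64
print(exp2_f32.name, exp2_f64.name)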