From be5f185ac0e62388dc0a627b55fc95445d7756cc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 31 Oct 2023 01:38:42 -0400 Subject: [PATCH] Higher test coverage for dtypes (#2156) * refactor unit tests for dtypes * add missing dtypes in llvmir.py and lib.py * skip torch tests * webgpu * cleaner skips * fix llvm bool casting issue using compare * llvm 100% passing * llvm segfault * TEMP decrease timeout mins to 11 debug * add bf16 to setup * skip half tests in cuda cpu * check for CUDACPU insetad * add int16 to triton dtypes * u16 for triton * remove debug - diff is still hard to read * derive from base class TestDType * enhance test_upcast and downcast by running on every possible version * dummy commit to rerun the flakey test * skip the correct tests for CUDA * bf16 should be skipped in the common TestDType cases * re-enable bf16 * more consistent structure * tiny changes to is_dtype_supported 1 * tiny changes 2 add reason * fuzz * fuzzer p2 * run fp32 twice * remove duplicate fp32 run * clang: use stdbool * skip triton on bool casts * merge and resolve conflicts --- test/test_dtype.py | 133 +++++++++++++++++----------------- tinygrad/renderer/llvmir.py | 2 +- tinygrad/renderer/triton.py | 4 +- tinygrad/runtime/lib.py | 2 +- tinygrad/runtime/ops_clang.py | 2 +- 5 files changed, 73 insertions(+), 70 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 61fedfc6..8b551537 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,11 +1,29 @@ import unittest import numpy as np -from tinygrad.helpers import getenv, DType, DEBUG, ImageDType, PtrDType +from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType from tinygrad.ops import Device from tinygrad.tensor import Tensor, dtypes -from typing import List, Optional +from typing import Any, List from extra.utils import OSX, temp -import copy + +def is_dtype_supported(dtype: DType): + # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) + # for LLVM, it segfaults because it can't link to the casting function + if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1 + if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType + if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX + if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"] + if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] + if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"] + if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] + if dtype == dtypes.bool: + # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable + if Device.DEFAULT == "WEBGPU": return False + # TODO remove triton from here once internal casting is fixed. CAST of fp32s between 0-1 is broken in triton + if getenv("TRITON") == 1: return False + return True + +def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes def _test_to_np(a:Tensor, np_dtype, target): if DEBUG >= 2: print(a) @@ -26,31 +44,49 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target): raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target) -def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target) +def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist()) def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target) -# tests no-op casts from source_dtype to target_dtypes -def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None): - if target_contents is None: target_contents = copy.deepcopy(tensor_contents) - list(map( - lambda t_dtype: _test_cast(Tensor(tensor_contents, dtype=source_dtype), t_dtype, target_contents), - target_dtypes +class TestDType(unittest.TestCase): + DTYPE: Any = None + DATA: Any = None + @classmethod + def setUpClass(cls): + if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") + cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist() + def setUp(self): + if self.DTYPE is None: raise unittest.SkipTest("base class") + + def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np)) + + def test_casts_to(self): list(map( + lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE), + get_available_cast_dtypes(self.DTYPE) )) -# tests no-op casts from source_dtypes to target_dtype -def _test_casts_to(tensor_contents:List, source_dtypes:List[DType], target_dtype:DType, target_contents:Optional[List]=None): - if target_contents is None: target_contents = copy.deepcopy(tensor_contents) - list(map( - lambda s_dtype: _test_cast(Tensor(tensor_contents, dtype=s_dtype), target_dtype, target_contents), - source_dtypes + def test_casts_from(self): list(map( + lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype), + get_available_cast_dtypes(self.DTYPE) + )) + + def test_upcast_ops(self): list(map( + lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype, target_dtype=dtype) if dtype.sz > self.DTYPE.sz else None, + get_available_cast_dtypes(self.DTYPE) + )) + def test_upcast_to_ops(self): list(map( + lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE, target_dtype=self.DTYPE) if dtype.sz < self.DTYPE.sz else None, + get_available_cast_dtypes(self.DTYPE) )) def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType): + if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): raise unittest.SkipTest("dtype not supported") _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16]) _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]]) _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy()) class TestBFloat16DType(unittest.TestCase): + def setUp(self): + if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported") def test_bf16_to_float(self): with self.assertRaises(AssertionError): _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000]) @@ -61,14 +97,12 @@ class TestBFloat16DType(unittest.TestCase): # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) - @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM") def test_bf16(self): t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16) t.realize() back = t.cast(dtypes.float32) assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) - @unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM") def test_bf16_disk_write_read(self): t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) t.to(f"disk:{temp('f32')}").realize() @@ -82,44 +116,20 @@ class TestBFloat16DType(unittest.TestCase): back = t.cast(dtypes.float32) assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) -# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) -# for LLVM, it segfaults because it can't link to the casting function -@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends") -class TestHalfDtype(unittest.TestCase): - def test_float16_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4]) - def test_casts_to_half(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int8, dtypes.uint8], target_dtype=dtypes.float16) - def test_casts_from_half(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float16, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.float32, dtypes.int32, dtypes.int64]) - def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32) - def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16) +class TestHalfDtype(TestDType): DTYPE = dtypes.half -@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"] or OSX, "float64 is not supported by some backends") -class TestDoubleDtype(unittest.TestCase): - def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4]) - def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64) - def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64) +class TestFloatDType(TestDType): DTYPE = dtypes.float -@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8") -class TestInt8Dtype(unittest.TestCase): - def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4]) - def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4]) - def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4]) +class TestDoubleDtype(TestDType): DTYPE = dtypes.double - def test_casts_to_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float32, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64]) - def test_casts_to_bool_1(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, True, True, True]) - def test_casts_to_bool_2(self): _test_casts_from([1,0,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, False, True, True]) - def test_casts_from_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.float32, dtypes.uint8, dtypes.int32, dtypes.int64]) - def test_casts_from_uint8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.uint8, target_dtypes=[dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64]) - - def test_int8_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int8, target_dtype=dtypes.int8) - def test_int64_ops(self): _test_ops(a_dtype=dtypes.int64, b_dtype=dtypes.int64, target_dtype=dtypes.int64) - def test_int8_upcast_float(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float32, target_dtype=dtypes.float32) - def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64) - - @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently") - @unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap") +class TestInt8Dtype(TestDType): + DTYPE = dtypes.int8 + @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) - @unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap") +class TestUint8Dtype(TestDType): + DTYPE = dtypes.uint8 + @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) @unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH") @@ -141,23 +151,16 @@ class TestBitCast(unittest.TestCase): with self.assertRaises(AssertionError): _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000]) -class TestInt32Dtype(unittest.TestCase): - def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4]) +class TestInt16Dtype(TestDType): DTYPE = dtypes.int16 +class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16 - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64") - def test_casts_to_int32(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int64], target_dtype=dtypes.int32) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64") - def test_casts_from_int32(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int32, target_dtypes=[dtypes.float32, dtypes.int64]) +class TestInt32Dtype(TestDType): DTYPE = dtypes.int32 +class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32 - def test_int32_ops(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int32, target_dtype=dtypes.int32) - def test_int32_upcast_float32(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.float32, target_dtype=dtypes.float32) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64") - def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64) +class TestInt64Dtype(TestDType): DTYPE = dtypes.int64 +class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64 -@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable") -class TestBoolDtype(unittest.TestCase): - def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32]) - def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool) +class TestBoolDtype(TestDType): DTYPE = dtypes.bool class TestEqStrDType(unittest.TestCase): def test_image_ne(self): diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index f01fcaeb..1f0f8243 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -24,7 +24,7 @@ code_for_op: Final[Dict[Op, Callable]] = { TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=LLVM_FAST_MATH_FLAGS), } -dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32)} +dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} def cast(bb, val, input_type, output_type): if input_type == output_type: return val diff --git a/tinygrad/renderer/triton.py b/tinygrad/renderer/triton.py index 4088552e..2153131b 100644 --- a/tinygrad/renderer/triton.py +++ b/tinygrad/renderer/triton.py @@ -8,8 +8,8 @@ import linecache import math import re -triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64"} -signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64"} +triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"} +signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"} def next_power_of_2(x): return 1 << (x - 1).bit_length() diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index b3d33a78..15a3f2ec 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -43,7 +43,7 @@ class RawBufferMapped(RawBufferCopyIn): # this one is simple enough that i moved it out of the runtimes class RawMallocBuffer(RawBufferMapped): - def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)()) + def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64, dtypes.int16: ctypes.c_int16, dtypes.uint16: ctypes.c_uint16}[dtype] * size)()) def _buffer(self): return memoryview(self._buf) class RawBufferCopyInOut(RawBufferCopyIn): diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 5c4fb1bd..3e82b220 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -18,7 +18,7 @@ args = { 'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''} }[platform.system()] -CLANG_PROGRAM_HEADER = '#include \n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n' +CLANG_PROGRAM_HEADER = '#include \n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include \n' ADDRESS = 0x10000 # Unicorn doesn't support external calls