Higher test coverage for dtypes (#2156)

* refactor unit tests for dtypes

* add missing dtypes in llvmir.py and lib.py

* skip torch tests

* webgpu

* cleaner skips

* fix llvm bool casting issue using compare (see the sketch after this list)

* llvm 100% passing

* llvm segfault

* TEMP decrease timeout mins to 11 (debug)

* add bf16 to setup

* skip half tests in cuda cpu

* check for CUDACPU instead

* add int16 to triton dtypes

* u16 for triton

* remove debug - diff is still hard to read

* derive from base class TestDType

* enhance test_upcast and downcast by running them on every available dtype combination

* dummy commit to rerun the flaky test

* skip the correct tests for CUDA

* bf16 should be skipped in the common TestDType cases

* re-enable bf16

* more consistent structure

* tiny changes to is_dtype_supported 1

* tiny changes 2 (add reason)

* fuzz

* fuzzer p2

* run fp32 twice

* remove duplicate fp32 run

* clang: use stdbool

* skip triton on bool casts

* merge and resolve conflicts
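A note on the LLVM bool-cast fix above: truncating an integer to i1 keeps only the lowest bit, so a value like 2 would cast to False; comparing against zero gives the intended C-like truthiness. Below is a minimal, hypothetical llvmlite sketch of the idea (the module and function names are illustrative, not the tinygrad source); for floats the analogous fix compares with fcmp_unordered against 0.0, as in the WHERE op in the llvmir.py hunk further down.

import llvmlite.ir as ir

module = ir.Module(name="bool_cast_demo")
fn = ir.Function(module, ir.FunctionType(ir.IntType(1), [ir.IntType(8)]), name="i8_to_bool")
builder = ir.IRBuilder(fn.append_basic_block("entry"))
x = fn.args[0]

# the bug: trunc keeps only the lowest bit, so 2 (0b10) becomes 0 (False)
wrong = builder.trunc(x, ir.IntType(1), name="wrong")

# the fix: compare against zero, so any nonzero value maps to True
right = builder.icmp_unsigned("!=", x, ir.Constant(ir.IntType(8), 0), name="right")

builder.ret(right)
print(module)  # dump the generated IR; the unused "wrong" instruction is kept only for comparison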
qazal · 2023-10-31 01:38:42 -04:00 · commit be5f185ac0 · parent f294bdd681
5 changed files with 73 additions and 70 deletions

test/test_dtypes.py

@@ -1,11 +1,29 @@
import unittest
import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG, ImageDType, PtrDType
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType
from tinygrad.ops import Device
from tinygrad.tensor import Tensor, dtypes
from typing import List, Optional
from typing import Any, List
from extra.utils import OSX, temp
import copy
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX
if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.bool:
# host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable
if Device.DEFAULT == "WEBGPU": return False
# TODO remove triton from here once internal casting is fixed. CAST of fp32s between 0-1 is broken in triton
if getenv("TRITON") == 1: return False
return True
def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes
def _test_to_np(a:Tensor, np_dtype, target):
if DEBUG >= 2: print(a)
@@ -26,31 +44,49 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
# tests no-op casts from source_dtype to target_dtypes
def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
if target_contents is None: target_contents = copy.deepcopy(tensor_contents)
list(map(
lambda t_dtype: _test_cast(Tensor(tensor_contents, dtype=source_dtype), t_dtype, target_contents),
target_dtypes
class TestDType(unittest.TestCase):
DTYPE: Any = None
DATA: Any = None
@classmethod
def setUpClass(cls):
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")
def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
get_available_cast_dtypes(self.DTYPE)
))
# tests no-op casts from source_dtypes to target_dtype
def _test_casts_to(tensor_contents:List, source_dtypes:List[DType], target_dtype:DType, target_contents:Optional[List]=None):
if target_contents is None: target_contents = copy.deepcopy(tensor_contents)
list(map(
lambda s_dtype: _test_cast(Tensor(tensor_contents, dtype=s_dtype), target_dtype, target_contents),
source_dtypes
def test_casts_from(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
get_available_cast_dtypes(self.DTYPE)
))
def test_upcast_ops(self): list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype, target_dtype=dtype) if dtype.sz > self.DTYPE.sz else None,
get_available_cast_dtypes(self.DTYPE)
))
def test_upcast_to_ops(self): list(map(
lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE, target_dtype=self.DTYPE) if dtype.sz < self.DTYPE.sz else None,
get_available_cast_dtypes(self.DTYPE)
))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType):
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): raise unittest.SkipTest("dtype not supported")
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
class TestBFloat16DType(unittest.TestCase):
def setUp(self):
if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported")
def test_bf16_to_float(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
@@ -61,14 +97,12 @@ class TestBFloat16DType(unittest.TestCase):
# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
@unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
def test_bf16(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
t.realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
@unittest.skipIf(Device.DEFAULT not in ["LLVM"], "bf16 only on LLVM")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
@@ -82,44 +116,20 @@ class TestBFloat16DType(unittest.TestCase):
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
class TestHalfDtype(unittest.TestCase):
def test_float16_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
def test_casts_to_half(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int8, dtypes.uint8], target_dtype=dtypes.float16)
def test_casts_from_half(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float16, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.float32, dtypes.int32, dtypes.int64])
def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)
class TestHalfDtype(TestDType): DTYPE = dtypes.half
@unittest.skipIf(Device.DEFAULT in ["WEBGPU", "METAL"] or OSX, "float64 is not supported by some backends")
class TestDoubleDtype(unittest.TestCase):
def test_float64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.double), np.double, [1,2,3,4])
def test_casts_to_float64(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int32, dtypes.uint8], target_dtype=dtypes.float64)
def test_upcast_to_float64_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float64, target_dtype=dtypes.float64)
class TestFloatDType(TestDType): DTYPE = dtypes.float
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
class TestDoubleDtype(TestDType): DTYPE = dtypes.double
def test_casts_to_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float32, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64])
def test_casts_to_bool_1(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, True, True, True])
def test_casts_to_bool_2(self): _test_casts_from([1,0,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.bool], target_contents=[True, False, True, True])
def test_casts_from_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.float32, dtypes.uint8, dtypes.int32, dtypes.int64])
def test_casts_from_uint8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.uint8, target_dtypes=[dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64])
def test_int8_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int8, target_dtype=dtypes.int8)
def test_int64_ops(self): _test_ops(a_dtype=dtypes.int64, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
def test_int8_upcast_float(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
@unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
class TestInt8Dtype(TestDType):
DTYPE = dtypes.int8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
@unittest.skipIf(getenv("PTX",0)==1, "cuda saturation doesn't wrap")
class TestUint8Dtype(TestDType):
DTYPE = dtypes.uint8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
@@ -141,23 +151,16 @@ class TestBitCast(unittest.TestCase):
with self.assertRaises(AssertionError):
_test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
class TestInt32Dtype(unittest.TestCase):
def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])
class TestInt16Dtype(TestDType): DTYPE = dtypes.int16
class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
def test_casts_to_int32(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int64], target_dtype=dtypes.int32)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
def test_casts_from_int32(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int32, target_dtypes=[dtypes.float32, dtypes.int64])
class TestInt32Dtype(TestDType): DTYPE = dtypes.int32
class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32
def test_int32_ops(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int32, target_dtype=dtypes.int32)
def test_int32_upcast_float32(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int64")
def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
class TestInt64Dtype(TestDType): DTYPE = dtypes.int64
class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable")
class TestBoolDtype(unittest.TestCase):
def test_casts_from_bool(self): _test_casts_from([0,1,1,0], source_dtype=dtypes.bool, target_dtypes=[dtypes.float32, dtypes.int32])
def test_casts_to_bool(self): _test_casts_to([0,1,1,0], source_dtypes=[dtypes.float32, dtypes.int32], target_dtype=dtypes.bool)
class TestBoolDtype(TestDType): DTYPE = dtypes.bool
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):
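As a usage sketch (hypothetical, assuming the repo root is on the path so the test module is importable), the two helpers introduced above can be queried directly to see which casts the suite will exercise on the current backend:

from tinygrad.tensor import dtypes
from test.test_dtypes import is_dtype_supported, get_available_cast_dtypes  # adjust the import path if test/ is not importable as a package

for d in [dtypes.half, dtypes.int8, dtypes.bool]:
    if is_dtype_supported(d):
        print(d.name, "->", [t.name for t in get_available_cast_dtypes(d)])
    else:
        print(d.name, "is skipped on this backend")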

llvmir.py

@@ -24,7 +24,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=LLVM_FAST_MATH_FLAGS),
}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32)}
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val

(Triton renderer)

@@ -8,8 +8,8 @@ import linecache
import math
import re
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64"}
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64"}
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}
def next_power_of_2(x):
return 1 << (x - 1).bit_length()

lib.py

@@ -43,7 +43,7 @@ class RawBufferMapped(RawBufferCopyIn):
# this one is simple enough that i moved it out of the runtimes
class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64}[dtype] * size)())
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64, dtypes.int16: ctypes.c_int16, dtypes.uint16: ctypes.c_uint16}[dtype] * size)())
def _buffer(self): return memoryview(self._buf)
class RawBufferCopyInOut(RawBufferCopyIn):

(Clang runtime)

@@ -18,7 +18,7 @@ args = {
'Darwin': {'cflags':'-lm -fPIC --rtlib=compiler-rt ', 'ext':'dylib', 'exp':''}
}[platform.system()]
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'
ADDRESS = 0x10000
# Unicorn doesn't support external calls