[GH-1305] Refactor test_dtypes.py to be cleaner (#1306)

Co-authored-by: waifairer <waifairer@gmail.com>
This commit is contained in:
waifairer 2023-07-21 16:18:02 -06:00 committed by GitHub
parent 48c4df1263
commit 7cac5ea16c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 53 additions and 81 deletions

View File

@ -3,29 +3,50 @@ import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG from tinygrad.helpers import getenv, DType, DEBUG
from tinygrad.lazy import Device from tinygrad.lazy import Device
from tinygrad.tensor import Tensor, dtypes from tinygrad.tensor import Tensor, dtypes
from typing import List, Optional
from extra.utils import OSX, temp from extra.utils import OSX, temp
import copy
def _test_to_np(a:Tensor, np_dtype, target): def _test_to_np(a:Tensor, np_dtype, target):
print(a) if DEBUG >= 2: print(a)
na = a.numpy() na = a.numpy()
print(na, na.dtype, a.lazydata.realized) if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized)
try:
assert na.dtype == np_dtype assert na.dtype == np_dtype
np.testing.assert_allclose(na, target) np.testing.assert_allclose(na, target)
except AssertionError as e:
raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e
def _test_op(fxn, target_dtype:DType, target): def _assert_eq(tensor:Tensor, target_dtype:DType, target):
c = fxn() if DEBUG >= 2: print(tensor.numpy())
if DEBUG >= 2: print(c.numpy()) try:
assert c.dtype == target_dtype assert tensor.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target) np.testing.assert_allclose(tensor.numpy(), target)
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target) def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
def _test_add(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
def _test_mul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
# tests no-op casts from source_dtype to target_dtypes
def _test_casts_from(tensor_contents:List, source_dtype:DType, target_dtypes:List[DType], target_contents:Optional[List]=None):
if target_contents is None: target_contents = copy.deepcopy(tensor_contents)
list(map(
lambda t_dtype: _test_cast(Tensor(tensor_contents, dtype=source_dtype), t_dtype, target_contents),
target_dtypes
))
# tests no-op casts from source_dtypes to target_dtype
def _test_casts_to(tensor_contents:List, source_dtypes:List[DType], target_dtype:DType, target_contents:Optional[List]=None):
if target_contents is None: target_contents = copy.deepcopy(tensor_contents)
list(map(
lambda s_dtype: _test_cast(Tensor(tensor_contents, dtype=s_dtype), target_dtype, target_contents),
source_dtypes
))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType):
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
class TestBFloat16DType(unittest.TestCase): class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self): def test_bf16_to_float(self):
@ -63,28 +84,11 @@ class TestBFloat16DType(unittest.TestCase):
# for LLVM, it segfaults because it can't link to the casting function # for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends") @unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
class TestHalfDtype(unittest.TestCase): class TestHalfDtype(unittest.TestCase):
def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4]) def test_float16_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
def test_casts_to_half(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int8, dtypes.uint8], target_dtype=dtypes.float16)
def test_half_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4]) def test_casts_from_half(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float16, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.float32, dtypes.int32, dtypes.int64])
def test_half_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4]) def test_half_upcast_ops(self): _test_ops(a_dtype=dtypes.float16, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_half_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4]) def test_upcast_to_half_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float16, target_dtype=dtypes.float16)
def test_half_to_int32(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int32, [1,2,3,4])
def test_half_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
def test_float_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
def test_int8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
def test_uint8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
def test_half_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
def test_half_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
def test_half_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
def test_half_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
def test_int8_add_upcast_half(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
def test_int8_mul_upcast_half(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
def test_half_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8") @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase): class TestInt8Dtype(unittest.TestCase):
@ -92,35 +96,14 @@ class TestInt8Dtype(unittest.TestCase):
def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4]) def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4]) def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
def test_float_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4]) def test_casts_to_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.float32, target_dtypes=[dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64])
def test_float_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4]) def test_casts_from_int8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int8, target_dtypes=[dtypes.float32, dtypes.uint8, dtypes.int32, dtypes.int64])
def test_float_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4]) def test_casts_from_uint8(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.uint8, target_dtypes=[dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64])
def test_int8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4]) def test_int8_ops(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int8, target_dtype=dtypes.int8)
def test_int8_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4]) def test_int64_ops(self): _test_ops(a_dtype=dtypes.int64, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
def test_int8_to_int32(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int32, [1,2,3,4]) def test_int8_upcast_float(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_int8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4]) def test_int8_upcast_int64(self): _test_ops(a_dtype=dtypes.int8, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
def test_uint8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
def test_uint8_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
def test_uint8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
def test_int8_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
def test_int64_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
def test_int8_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
def test_int64_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
def test_int8_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
def test_int64_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
def test_int8_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
def test_int8_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
def test_int8_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_add_upcast_int64(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
@unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently") @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
@ -130,23 +113,12 @@ class TestInt8Dtype(unittest.TestCase):
class TestInt32Dtype(unittest.TestCase): class TestInt32Dtype(unittest.TestCase):
def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4]) def test_int32_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int32), np.int32, [1,2,3,4])
def test_float_to_int32(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1,2,3,4]) def test_casts_to_int32(self): _test_casts_to([1,2,3,4], source_dtypes=[dtypes.float32, dtypes.int64], target_dtype=dtypes.int32)
def test_int64_to_int32(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int32, [1,2,3,4]) def test_casts_from_int32(self): _test_casts_from([1,2,3,4], source_dtype=dtypes.int32, target_dtypes=[dtypes.float32, dtypes.int64])
def test_int32_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int32), dtypes.float32, [1,2,3,4]) def test_int32_ops(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int32, target_dtype=dtypes.int32)
def test_int32_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int32), dtypes.int64, [1,2,3,4]) def test_int32_upcast_float32(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.float32, target_dtype=dtypes.float32)
def test_int32_upcast_int64(self): _test_ops(a_dtype=dtypes.int32, b_dtype=dtypes.int64, target_dtype=dtypes.int64)
def test_int32_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.int32), dtypes.int32, [2,4,6,8])
def test_int32_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.int32), dtypes.int32, [1,4,9,16])
def test_int32_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int32), Tensor.eye(2, dtype=dtypes.int32), dtypes.int32, [[1,2],[3,4]])
def test_int32_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
def test_int32_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
def test_int32_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int32), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int32_add_upcast_int64(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
def test_int32_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int32), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
def test_int32_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int32), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()