2024-03-16 02:33:26 +08:00
|
|
|
import unittest, operator
|
2023-03-11 08:56:07 +08:00
|
|
|
import numpy as np
|
2023-12-13 13:10:29 +08:00
|
|
|
import torch
|
2024-02-24 22:22:06 +08:00
|
|
|
from typing import Any, List
|
2024-03-16 02:33:26 +08:00
|
|
|
from tinygrad.helpers import getenv, DEBUG
|
2024-01-02 06:58:48 +08:00
|
|
|
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
|
2024-02-24 22:22:06 +08:00
|
|
|
from tinygrad import Device, Tensor, dtypes
|
2024-01-20 00:55:26 +08:00
|
|
|
from hypothesis import given, settings, strategies as strat
|
2024-03-16 02:33:26 +08:00
|
|
|
from test.helpers import is_dtype_supported
|
2023-10-31 13:38:42 +08:00
|
|
|
|
2024-01-20 03:03:01 +08:00
|
|
|
# Hypothesis profile used by every @given test in this module: up to 200 examples, no per-example deadline.
settings.register_profile("my_profile", deadline=None, max_examples=200)
settings.load_profile("my_profile")
|
|
|
|
|
2023-12-26 00:33:17 +08:00
|
|
|
# every DType registered in DTYPES_DICT, as a mutable list
core_dtypes = [*DTYPES_DICT.values()]
|
2024-03-10 07:30:34 +08:00
|
|
|
# On the CPU backend bfloat16 is dropped from core_dtypes, so the hypothesis strategies built from it never sample it.
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
|
2024-03-23 02:22:06 +08:00
|
|
|
# integer dtypes from core_dtypes that this device supports (sampled by the @given tests)
dtype_ints = list(filter(lambda dt: dtypes.is_int(dt) and is_dtype_supported(dt), core_dtypes))
|
2024-03-20 04:31:27 +08:00
|
|
|
# floating-point dtypes from core_dtypes that this device supports (sampled by the @given tests)
dtype_floats = list(filter(lambda dt: dtypes.is_float(dt) and is_dtype_supported(dt), core_dtypes))
|
2023-10-31 13:38:42 +08:00
|
|
|
|
2023-12-16 12:46:57 +08:00
|
|
|
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
  """Return the supported dtypes `dtype` can be cast to/from in tests.

  Excludes `dtype` itself, dtypes unsupported on this device, and DTYPES_DICT entries whose
  name starts with "_". Returns an empty list when `dtype` itself is unsupported.
  """
  if not is_dtype_supported(dtype): return []
  candidates: List[DType] = []
  for name, other in DTYPES_DICT.items():
    if other == dtype or not is_dtype_supported(other): continue
    if name.startswith("_"): continue  # dont cast internal dtypes
    candidates.append(other)
  return candidates
|
2023-05-31 08:49:26 +08:00
|
|
|
|
|
|
|
def _test_to_np(a:Tensor, np_dtype, target):
  """Convert `a` with .numpy() and check the numpy dtype equals `np_dtype` and the values are close to `target`.

  Any AssertionError is re-raised with the tensor contents, the target and the expected numpy dtype.
  """
  if DEBUG >= 2: print(a)
  arr = a.numpy()
  if DEBUG >= 2: print(arr, arr.dtype, a.lazydata.base.realized)
  try:
    assert arr.dtype == np_dtype
    np.testing.assert_allclose(arr, target)
  except AssertionError as err:
    raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from err
|
|
|
|
|
|
|
|
def _assert_eq(tensor:Tensor, target_dtype:DType, target):
  """Assert `tensor` has dtype `target_dtype` and values close to `target`.

  The relative tolerance is 1e-3 for float16, 1e-2 for bfloat16 and 1e-7 for everything else.
  Failures are re-raised as an AssertionError naming the tensor, its dtype, the target and target dtype.
  """
  if DEBUG >= 2: print(tensor.numpy())
  rtol_by_dtype = {dtypes.float16:1e-3, dtypes.bfloat16:1e-2}
  try:
    assert tensor.dtype == target_dtype
    np.testing.assert_allclose(tensor.numpy(), target, rtol=rtol_by_dtype.get(target_dtype, 1e-7))
  except AssertionError as err:
    raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from err
|
|
|
|
|
2023-12-19 05:06:09 +08:00
|
|
|
def _test_op(fxn, target_dtype:DType, target):
  """Build a tensor by calling `fxn()` and check its dtype and values via `_assert_eq`."""
  result = fxn()
  _assert_eq(result, target_dtype, target)
|
|
|
|
def _test_cast(a:Tensor, target_dtype:DType):
  """Check `a.cast(target_dtype)` against numpy's `astype` applied to the same data."""
  expected = list(a.numpy().astype(target_dtype.np))
  _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
|
|
|
|
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
  """Check that `a.bitcast(target_dtype)` has `target_dtype` and the expected values.

  When `target` is omitted (None), the expected values are numpy's reinterpretation of `a`'s data
  (`ndarray.view(target_dtype.np)`). An explicitly passed target is always used, even if it is falsy
  (e.g. `[]` or `0`). Raises unittest.SkipTest for bfloat16 targets, which have no reference check yet.
  """
  if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
  # `target or ...` would silently replace a falsy-but-explicit target; only fall back when it was omitted
  expected = target if target is not None else a.numpy().view(target_dtype.np).tolist()
  _test_op(lambda: a.bitcast(target_dtype), target_dtype, expected)
|
2023-05-31 08:49:26 +08:00
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestDType(unittest.TestCase):
  """Shared per-dtype tests: concrete subclasses set DTYPE and inherit every test here.

  setUpClass fills DATA with 10 random values of DTYPE (skipping the class when DTYPE is unset or
  unsupported); setUp skips the base class itself.
  """
  DTYPE: Any = None  # the dtype under test; None on the base class
  DATA: Any = None   # numpy array of sample values, generated in setUpClass
  @classmethod
  def setUpClass(cls):
    if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
    DATA_SIZE = 10
    if dtypes.is_unsigned(cls.DTYPE):
      cls.DATA = np.random.randint(0, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
    elif dtypes.is_int(cls.DTYPE):
      cls.DATA = np.random.randint(-100, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
    elif cls.DTYPE == dtypes.bool:
      cls.DATA = np.random.choice([True, False], size=DATA_SIZE)
    else:
      # TODO: include negative numbers here and fix negative number cast to uint
      cls.DATA = np.random.uniform(0, 10, size=DATA_SIZE).astype(cls.DTYPE.np)
  def setUp(self):
    if self.DTYPE is None: raise unittest.SkipTest("base class")

  def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))

  # NOTE: plain for-loops instead of list(map(lambda ...)): each call is made only for its assertions
  def test_casts_to(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
  def test_casts_from(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)

  def test_same_size_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize == self.DTYPE.itemsize: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize > self.DTYPE.itemsize: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_to_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize < self.DTYPE.itemsize: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
  def test_bitcast(self):
    if Device.DEFAULT == "WEBGL": raise unittest.SkipTest("no bitcast in WebGL GLSL")
    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
    for dtype in get_available_cast_dtypes(self.DTYPE):
      # NOTE(review): _test_bitcast raises SkipTest for a bfloat16 target, which ends this loop (and test) early
      if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool:
        _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype)

  def test_dtypes_fields(self):
    fields = dtypes.fields()
    self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
    self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))

  def test_resulting_and_init_dtypes_match(self):
    # named np_dtypes (not `dtypes`) so the module-level tinygrad `dtypes` is not shadowed in this test
    np_dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
    data = [1., 2., 0., 0.5, -1.5, 5.25]
    for dt in np_dtypes:
      arr = np.asarray(data, dtype=dt)
      tin = Tensor(arr).numpy()
      tor = torch.as_tensor(arr).detach().numpy()
      # compare with ==: dtype equality is what matters; object identity of np.dtype instances is an implementation detail
      assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
      np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
|
|
|
|
2023-11-10 07:17:43 +08:00
|
|
|
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
  """Check +, *, @ and broadcast-add between an `a_dtype` tensor and a `b_dtype` tensor.

  Results must have `target_dtype`, which defaults to `least_upper_dtype(a_dtype, b_dtype)`. Also checks
  that 1 cast a_dtype -> b_dtype, added to itself and cast back, gives 2 in a_dtype. Returns without
  checking anything when any involved dtype is unsupported or either operand dtype is bool.
  """
  target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
  if not all(is_dtype_supported(dt) for dt in (a_dtype, b_dtype, target_dtype)): return
  if dtypes.bool in (a_dtype, b_dtype): return
  seq = [1,2,3,4]
  _assert_eq(Tensor(seq, dtype=a_dtype)+Tensor(seq, dtype=b_dtype), target_dtype, [2*x for x in seq])
  def one_in_b(): return Tensor([1], dtype=a_dtype).cast(b_dtype)
  _assert_eq((one_in_b()+one_in_b()).cast(a_dtype), a_dtype, [2])
  _assert_eq(Tensor(seq, dtype=a_dtype)*Tensor(seq, dtype=b_dtype), target_dtype, [x*x for x in seq])
  _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
  # a (4,) tensor broadcast against (4,4) ones
  _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
|
2023-07-20 11:18:32 +08:00
|
|
|
|
2024-03-18 06:51:22 +08:00
|
|
|
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16(unittest.TestCase):
  """Creating bfloat16 tensors and reading them back through numpy."""
  def test_bf16_creation_numpy(self):
    values = [-1, 1, 2]
    bf = Tensor(values, dtype=dtypes.bfloat16)
    assert bf.dtype == dtypes.bfloat16
    out = bf.numpy()
    # numpy has no bfloat16 type; the readback is expected as float32
    assert out.dtype == np.float32
    np.testing.assert_allclose(out, np.array(values))

  def test_bf16_ones(self):
    # TODO: fix this with correct bfloat16 cast
    bf = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
    assert bf.dtype == dtypes.bfloat16
    np.testing.assert_allclose(bf.numpy(), np.ones((3, 5)))

  def test_bf16_eye(self):
    # TODO: fix this with correct bfloat16 cast
    bf = Tensor.eye(3, dtype=dtypes.bfloat16)
    assert bf.dtype == dtypes.bfloat16
    np.testing.assert_allclose(bf.numpy(), np.eye(3))
|
|
|
|
|
2024-03-18 06:51:22 +08:00
|
|
|
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16DType(unittest.TestCase):
  """Casts between bfloat16 and float32."""
  def test_bf16_to_float(self):
    src = Tensor([100000], dtype=dtypes.bfloat16)
    _test_cast(src, dtypes.float32)

  def test_float_to_bf16(self):
    src = Tensor([100000], dtype=dtypes.float32)
    _test_cast(src, dtypes.bfloat16)

  def test_bf16(self):
    # reference: torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
    bf = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
    bf.realize()
    as_f32 = bf.cast(dtypes.float32)
    assert tuple(as_f32.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
|
|
|
|
|
2024-03-18 06:51:22 +08:00
|
|
|
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16DTypeCast(unittest.TestCase):
  """float16 -> bfloat16 casts; each result is compared after casting both sides to float32."""
  def test_f16_to_bf16_conversion(self):
    original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
    converted_tensor = original_tensor.cast(dtypes.bfloat16)
    self.assertEqual(converted_tensor.dtype, dtypes.bfloat16)
    back_to_float32 = converted_tensor.cast(dtypes.float32)
    original_to_float32 = original_tensor.cast(dtypes.float32)
    np.testing.assert_allclose(back_to_float32.numpy(), original_to_float32.numpy(), rtol=1e-2, atol=1e-3)

  def test_f16_to_bf16_edge_cases(self):
    # zeros, infinities and NaN must survive the cast (np.testing.assert_equal treats NaN as equal to NaN)
    edge_cases = Tensor([0.0, -0.0, float('inf'), float('-inf'), float('nan')], dtype=dtypes.float16)
    converted = edge_cases.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_equal(converted.numpy(), edge_cases.cast(dtypes.float32).numpy())

  def test_f16_to_bf16_range_precision(self):
    large_value = Tensor([65504.0], dtype=dtypes.float16) # Max representable in float16
    small_value = Tensor([6.1035e-5], dtype=dtypes.float16) # Smallest positive normal float16
    large_converted = large_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    small_converted = small_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    # 65504 needs 11 significand bits, more than bfloat16 has, so it is compared with a tolerance
    np.testing.assert_allclose(large_converted.numpy(), large_value.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
    # 2**-14 is exactly representable in bfloat16, so it must round-trip exactly
    np.testing.assert_equal(small_converted.numpy(), small_value.cast(dtypes.float32).numpy())

  def test_f16_to_bf16_randomized(self):
    # a local Generator keeps this reproducible without reseeding numpy's global RNG,
    # which other tests in this module (e.g. TestDType.setUpClass) draw their data from
    rng = np.random.default_rng(42)
    random_values = Tensor(rng.uniform(-65504, 65504, 1000), dtype=dtypes.float16)
    converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
|
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestHalfDtype(TestDType):
  """Runs the shared TestDType checks for dtypes.half."""
  DTYPE = dtypes.half
|
|
|
|
|
2024-02-20 16:20:43 +08:00
|
|
|
class TestFloatDType(TestDType):
  """Shared TestDType checks for dtypes.float, plus float32 -> uint32 conversion."""
  DTYPE = dtypes.float

  def test_float_to_uint(self):
    # expected results truncate toward zero, so the small negatives become 0
    values = [-0.9, -0.3, 1.2]
    _test_op(lambda: Tensor(values, dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32, [0, 0, 1])
|
2023-10-31 13:38:42 +08:00
|
|
|
|
2024-02-16 17:08:59 +08:00
|
|
|
class TestDoubleDtype(TestDType):
  """Shared TestDType checks for dtypes.double, plus float64 precision and float64 -> float32 overflow."""
  DTYPE = dtypes.double
  @unittest.skipIf(getenv("CUDACPU") or getenv("PTX"), "conversion not supported on CUDACPU and PTX") # TODO: why not?
  def test_float64_increased_precision(self):
    inputs = [2, 3, 4]
    # every unary op must match torch float64 to 1e-12, far tighter than float32 could reach
    for op in ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid"):
      ours = getattr(Tensor(inputs, dtype=self.DTYPE), op)().numpy()
      ref = getattr(torch.tensor(inputs, dtype=torch.float64), op)()
      np.testing.assert_allclose(ours, ref, rtol=1e-12, atol=1e-12)

  def test_float64_to_float32_cast_inf(self):
    # 3.4e40 is beyond float32's range and must become inf; 3.4e38 still fits
    _test_op(lambda: Tensor([3.4e40, 3.4e38, 1, 0], dtype=dtypes.float64).cast(dtypes.float32),
             dtypes.float32, [float('inf'), 3.4e38, 1, 0])
|
|
|
|
|
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestInt8Dtype(TestDType):
  """Shared TestDType checks for int8, plus negative values cast to unsigned types."""
  DTYPE = dtypes.int8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_int8_to_uint8_negative(self):
    # expected wrap-around: -k becomes 2**8-k
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [2**8-k for k in (1, 2, 3, 4)])

  def test_int8_to_uint16_negative(self):
    # expected wrap-around: -k becomes 2**16-k
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-k for k in (1, 2, 3, 4)])
|
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestUint8Dtype(TestDType):
  """Shared TestDType checks for uint8, plus large values cast to int8."""
  DTYPE = dtypes.uint8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_uint8_to_int8_overflow(self):
    # expected wrap-around: 2**8-k becomes -k
    _test_op(lambda: Tensor([2**8-k for k in (1, 2, 3, 4)], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
|
2023-03-23 23:02:52 +08:00
|
|
|
|
2024-01-09 01:29:13 +08:00
|
|
|
@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL")
class TestBitCast(unittest.TestCase):
  """Direct Tensor.bitcast checks outside the per-dtype TestDType matrix."""
  def test_shape_change_bitcast(self):
    # float32 -> uint8 changes the element size; the test expects an AssertionError
    with self.assertRaises(AssertionError):
      _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])

  def test_bitcast_float_to_int32(self):
    ints = Tensor([1.,2,3]).bitcast(dtypes.int32)
    assert ints.numpy()[0] == 0x3f800000  # IEEE-754 bit pattern of 1.0f

  def test_bitcast_upcasted(self):
    bits = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
    floats = bits.bitcast(dtypes.float32)
    assert floats.numpy()[0,0] == 1.
|
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestInt16Dtype(TestDType):
  """Runs the shared TestDType checks for int16."""
  DTYPE = dtypes.int16
|
2024-02-20 16:20:43 +08:00
|
|
|
|
|
|
|
class TestUint16Dtype(TestDType):
  """Shared TestDType checks for uint16, plus narrowing to int8."""
  DTYPE = dtypes.uint16

  def test_uint16_to_int8_overflow(self):
    # expected results keep only the low byte: 0xFFFF -> -1, 0xFFFE -> -2
    src = [2**16-1, 2**16-2, 1, 0]
    _test_op(lambda: Tensor(src, dtype=dtypes.uint16).cast(dtypes.int8), dtypes.int8, [-1, -2, 1, 0])
|
2023-07-13 01:29:15 +08:00
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestInt32Dtype(TestDType):
  """Runs the shared TestDType checks for int32."""
  DTYPE = dtypes.int32

class TestUint32Dtype(TestDType):
  """Runs the shared TestDType checks for uint32."""
  DTYPE = dtypes.uint32
|
2023-07-13 01:29:15 +08:00
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestInt64Dtype(TestDType):
  """Runs the shared TestDType checks for int64."""
  DTYPE = dtypes.int64

class TestUint64Dtype(TestDType):
  """Runs the shared TestDType checks for uint64."""
  DTYPE = dtypes.uint64
|
2023-07-13 01:29:15 +08:00
|
|
|
|
2023-10-31 13:38:42 +08:00
|
|
|
class TestBoolDtype(TestDType):
  """Runs the shared TestDType checks for bool."""
  DTYPE = dtypes.bool
|
2023-08-10 01:12:52 +08:00
|
|
|
|
2023-12-06 03:42:28 +08:00
|
|
|
class TestImageDType(unittest.TestCase):
  """scalar() and vec() of image dtypes: both imagef and imageh map to float32."""
  def test_image_scalar(self):
    for make_image in (dtypes.imagef, dtypes.imageh):
      assert make_image((10,10)).scalar() == dtypes.float32
  def test_image_vec(self):
    for make_image in (dtypes.imagef, dtypes.imageh):
      assert make_image((10,10)).vec(4) == dtypes.float32.vec(4)
|
|
|
|
|
2023-10-17 08:52:38 +08:00
|
|
|
class TestEqStrDType(unittest.TestCase):
  """Equality and str() of ImageDType and PtrDType."""
  def test_image_ne(self):
    if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
    shape, swapped = (1,2,4), (1,4,2)
    assert dtypes.float == dtypes.float32, "float doesn't match?"
    # image dtypes differ by kind (f vs h) and by shape, and are equal when both match
    assert dtypes.imagef(shape) != dtypes.imageh(shape), "different image dtype doesn't match"
    assert dtypes.imageh(shape) != dtypes.imageh(swapped), "different shape doesn't match"
    assert dtypes.imageh(shape) == dtypes.imageh(shape), "same shape matches"
    assert isinstance(dtypes.imageh(shape), ImageDType)
  def test_ptr_ne(self):
    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
    # TODO: is this the wrong behavior?
    # currently a PtrDType compares equal both to its base dtype and to an identical PtrDType
    for other in (dtypes.float32, PtrDType(dtypes.float32)):
      assert PtrDType(dtypes.float32) == other
      assert not (PtrDType(dtypes.float32) != other)
    # possibly desired instead (see TODO): PtrDType(dtypes.float32) != dtypes.float32
  def test_strs(self):
    if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
    expectations = [(dtypes.imagef((1,2,4)), "dtypes.imagef((1, 2, 4))"), (PtrDType(dtypes.float32), "ptr.dtypes.float")]
    for dt, text in expectations:
      self.assertEqual(str(dt), text)
|
|
|
|
|
2023-12-12 01:28:19 +08:00
|
|
|
class TestHelpers(unittest.TestCase):
  """Property tests for dtypes.is_int / is_unsigned / is_float and DType.scalar()."""
  signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
  uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  floats = (dtypes.float16, dtypes.float32, dtypes.float64)

  # below, amt == 1 checks the scalar dtype itself and amt > 1 its vector form dtype.vec(amt)
  @given(strat.sampled_from(signed_ints+uints), strat.integers(min_value=1, max_value=8))
  def test_is_int(self, dtype, amt):
    dt = dtype.vec(amt) if amt > 1 else dtype
    assert dtypes.is_int(dt)
    assert not dtypes.is_float(dt)

  @given(strat.sampled_from(uints), strat.integers(min_value=1, max_value=8))
  def test_is_unsigned_uints(self, dtype, amt):
    dt = dtype.vec(amt) if amt > 1 else dtype
    assert dtypes.is_unsigned(dt)

  @given(strat.sampled_from(signed_ints), strat.integers(min_value=1, max_value=8))
  def test_is_unsigned_signed_ints(self, dtype, amt):
    dt = dtype.vec(amt) if amt > 1 else dtype
    assert not dtypes.is_unsigned(dt)

  @given(strat.sampled_from(floats), strat.integers(min_value=1, max_value=8))
  def test_is_float(self, dtype, amt):
    dt = dtype.vec(amt) if amt > 1 else dtype
    assert dtypes.is_float(dt)
    assert not dtypes.is_int(dt)
    assert not dtypes.is_unsigned(dt)

  def test_bf16_is_float(self):
    assert dtypes.is_float(dtypes.bfloat16)

  @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
  def test_scalar(self, dtype, amt):
    vec = dtype.vec(amt)
    assert vec.scalar() == dtype
|
|
|
|
|
2023-12-12 08:33:49 +08:00
|
|
|
class TestTypeSpec(unittest.TestCase):
  """Dtype selection for tensor creation, full/zeros/ones, reductions, arange, comparisons and index functions.

  Tests overwrite the global dtypes.default_int / dtypes.default_float; setUp saves them and tearDown restores them.
  """
  def setUp(self):
    self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
  def tearDown(self):
    dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float

  def test_set_dtype_default(self):
    for int_dt in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64):
      dtypes.default_int = int_dt
      assert dtypes.default_int == int_dt

    for float_dt in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64):
      dtypes.default_float = float_dt
      assert dtypes.default_float == float_dt

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_creation(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    # python scalars/lists infer their dtype; None and empty lists use the default float
    _assert_eq(Tensor(True), dtypes.bool, True)
    _assert_eq(Tensor(None), default_float, [])
    _assert_eq(Tensor(2), default_int, 2)
    _assert_eq(Tensor(2.34), default_float, 2.34)
    _assert_eq(Tensor([]), default_float, [])
    _assert_eq(Tensor([1]), default_int, [1])
    _assert_eq(Tensor([1.1]), default_float, [1.1])

    for n in (0, 3):
      _assert_eq(Tensor.eye(n), default_float, np.eye(n))
    _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_full(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    shape = (2, 3)

    for make, ref in ((Tensor.zeros, np.zeros), (Tensor.ones, np.ones)):
      _assert_eq(make(shape), default_float, ref(shape))
      _assert_eq(make(shape, dtype=dtypes.int64), dtypes.int64, ref(shape))
      if is_dtype_supported(dtypes.float16):
        _assert_eq(make(shape, dtype=dtypes.float16), dtypes.float16, ref(shape))

    # without an explicit dtype, full() takes its dtype from the fill value
    _assert_eq(Tensor.full(shape, 3.0), default_float, np.full(shape, 3.0))
    _assert_eq(Tensor.full(shape, 3), default_int, np.full(shape, 3))
    _assert_eq(Tensor.full(shape, True), dtypes.bool, np.full(shape, True))
    _assert_eq(Tensor.full(shape, 3, dtype=dtypes.int64), dtypes.int64, np.full(shape, 3))
    _assert_eq(Tensor.full(shape, 3.0, dtype=dtypes.int64), dtypes.int64, np.full(shape, 3))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.full(shape, 3, dtype=dtypes.float16), dtypes.float16, np.full(shape, 3))
      _assert_eq(Tensor.full(shape, 3.0, dtype=dtypes.float16), dtypes.float16, np.full(shape, 3))

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_reduce_0d_default(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    # summing over an empty axis gives zeros of the reduced shape
    _assert_eq(Tensor.ones((2,3,0)).sum(2), default_float, np.zeros((2, 3)))
    # TODO: what should this one be?
    # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
    _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))

  @unittest.skipIf(Device.DEFAULT=="RHIP", "failed in HIP CI")
  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_arange(self, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float

    for stop in (5, 120):
      _assert_eq(Tensor.arange(stop), default_int, np.arange(stop))
    _assert_eq(Tensor.arange(5.0), default_float, np.arange(5))
    for int_dt in (dtypes.int16, dtypes.int64):
      _assert_eq(Tensor.arange(5, dtype=int_dt), int_dt, np.arange(5))
    if is_dtype_supported(dtypes.float16):
      _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
    _assert_eq(Tensor.arange(3, 9, 0.7), default_float, np.arange(3, 9, 0.7))
    _assert_eq(Tensor.arange(3, 8.5, 3), default_float, np.arange(3, 8.5, 3))

  @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
  def test_bool_ops(self, dtype, op):
    lhs, rhs = Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)
    assert op(lhs, rhs).dtype == dtypes.bool

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_functions_return_index(self, dtype, default_int, default_float):
    dtypes.default_int, dtypes.default_float = default_int, default_float
    # index-returning functions give int32 regardless of the input dtype or the defaults
    for fn in ("argmax", "argmin", "multinomial"):
      assert getattr(Tensor([0, 1], dtype=dtype), fn)().dtype == dtypes.int32
|
2023-12-26 00:33:17 +08:00
|
|
|
|
2023-12-12 12:14:23 +08:00
|
|
|
class TestTypePromotion(unittest.TestCase):
  """least_upper_dtype / least_upper_float promotion rules."""
  @given(strat.sampled_from(core_dtypes))
  def test_self_promo_to_self(self, dtype):
    for count in (1, 2, 3):
      assert least_upper_dtype(*[dtype]*count) == dtype

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
    result = least_upper_dtype(dtype1, dtype2)
    assert result >= dtype1 and result >= dtype2

  def test_dtype_promo(self):
    cases = (
      (dtypes.bool, dtypes.int8, dtypes.int8),
      (dtypes.int8, dtypes.uint8, dtypes.int16),
      (dtypes.uint8, dtypes.int16, dtypes.int16),
      (dtypes.int16, dtypes.uint16, dtypes.int32),
      (dtypes.uint16, dtypes.int32, dtypes.int32),
      (dtypes.int32, dtypes.uint32, dtypes.int64),
      (dtypes.uint32, dtypes.int64, dtypes.int64),
      # similar to jax but we don't use weak type
      (dtypes.int64, dtypes.uint64, dtypes.float16),
      (dtypes.float16, dtypes.float32, dtypes.float32),
      (dtypes.float32, dtypes.float64, dtypes.float64),
      (dtypes.bool, dtypes.float32, dtypes.float32),
      (dtypes.bool, dtypes.float64, dtypes.float64),
      (dtypes.float16, dtypes.int64, dtypes.float16),
      (dtypes.float16, dtypes.uint64, dtypes.float16),
    )
    for lhs, rhs, expected in cases:
      assert least_upper_dtype(lhs, rhs) == expected

  @given(strat.sampled_from(dtype_floats))
  def test_float_to_float(self, dt):
    assert least_upper_float(dt) == dt
|
|
|
|
|
2023-12-13 13:10:29 +08:00
|
|
|
class TestAutoCastType(unittest.TestCase):
|
2023-12-22 11:00:21 +08:00
|
|
|
def setUp(self):
  # snapshot the global default dtypes so tearDown can restore anything a test changes
  saved = (dtypes.default_int, dtypes.default_float)
  self.old_default_int, self.old_default_float = saved
|
|
|
|
def tearDown(self):
  # put back the default dtypes recorded in setUp
  restored = (self.old_default_int, self.old_default_float)
  dtypes.default_int, dtypes.default_float = restored
|
|
|
|
|
2024-01-20 00:55:26 +08:00
|
|
|
@given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
  """Float-valued unary ops on integer tensors must match torch's results within 1e-3."""
  inputs = [2, 3, 4]
  for op in ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid"):
    # float16 can have larger precision errors
    ours = getattr(Tensor(inputs, dtype=dtype), op)().numpy()
    ref = getattr(torch.tensor(inputs), op)()
    np.testing.assert_allclose(ours, ref, rtol=1e-3, atol=1e-3)
|
2023-12-13 13:10:29 +08:00
|
|
|
|
2024-01-26 01:26:04 +08:00
|
|
|
@given(strat.sampled_from(core_dtypes))
def test_broadcast_scalar(self, dt):
  """Adding a python scalar: a float promotes non-float tensors to default_float, an int promotes only
  bool tensors (to default_int), and True keeps the tensor's dtype."""
  is_float, is_int = dtypes.is_float(dt), dtypes.is_int(dt)
  assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if is_float else dtypes.default_float)
  assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if is_float or is_int else dtypes.default_int)
  if dt != dtypes.bool and Device.DEFAULT != "WEBGPU":
    assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
|
2023-12-12 12:14:23 +08:00
|
|
|
|
2023-12-24 14:14:54 +08:00
|
|
|
def test_sum(self):
  """sum() widens bool and sub-32-bit ints to 32 bits (keeping signedness; bool -> int32); wider ints and floats keep their dtype."""
  cases = [
    (dtypes.bool, dtypes.int32), (dtypes.int8, dtypes.int32), (dtypes.int16, dtypes.int32),
    (dtypes.int32, dtypes.int32), (dtypes.int64, dtypes.int64),
    (dtypes.uint8, dtypes.uint32), (dtypes.uint16, dtypes.uint32),
    (dtypes.uint32, dtypes.uint32), (dtypes.uint64, dtypes.uint64),
    (dtypes.float16, dtypes.float16),
    # (dtypes.bfloat16, dtypes.bfloat16),  # disabled
    (dtypes.float32, dtypes.float32), (dtypes.float64, dtypes.float64),
  ]
  for src_dtype, result_dtype in cases:
    assert (Tensor([0, 1], dtype=src_dtype)).sum().dtype == result_dtype
|
|
|
|
|
2023-12-26 00:33:17 +08:00
|
|
|
def test_cumsum(self):
  # (input dtype, expected dtype of .cumsum(0)): bool and narrow ints widen, floats keep their dtype.
  expectations = [
    (dtypes.bool, dtypes.int32),
    (dtypes.int8, dtypes.int32),
    (dtypes.int16, dtypes.int32),
    (dtypes.int32, dtypes.int32),
    (dtypes.int64, dtypes.int64),
    (dtypes.uint8, dtypes.uint32),
    (dtypes.uint16, dtypes.uint32),
    (dtypes.uint32, dtypes.uint32),
    (dtypes.uint64, dtypes.uint64),
    (dtypes.float16, dtypes.float16),
    # bfloat16 check is currently disabled: (dtypes.bfloat16, dtypes.bfloat16),
    (dtypes.float32, dtypes.float32),
    (dtypes.float64, dtypes.float64),
  ]
  for src_dtype, want in expectations:
    assert (Tensor([0, 1], dtype=src_dtype)).cumsum(0).dtype == want
|
|
|
|
|
2024-01-20 00:55:26 +08:00
|
|
|
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_matmul(self, dt1, dt2):
  # The matmul result dtype should be the least upper bound of both operand dtypes.
  lhs = Tensor([0, 1], dtype=dt1)
  rhs = Tensor([0, 1], dtype=dt2)
  product = lhs @ rhs
  assert product.dtype == least_upper_dtype(dt1, dt2)
|
|
|
|
|
2024-01-26 01:26:04 +08:00
|
|
|
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):
  """Assert that `Tensor.where` yields `data_type` whichever branch gets `input_` and whichever gets `other`."""
  for on_true, on_false in ((input_, other), (other, input_)):
    assert (Tensor([True, False]).where(on_true, on_false)).dtype == data_type
|
|
|
|
|
|
|
|
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_where_no_scalar(self, dt1, dt2):
  # Both branches are 0-d tensors, so the result promotes to their least upper bound.
  first = Tensor(2, dtype=dt1)
  second = Tensor(3, dtype=dt2)
  self.check_where_alternate_input_other(first, second, least_upper_dtype(dt1, dt2))
|
|
|
|
|
|
|
|
@given(strat.sampled_from(core_dtypes))
def test_where_one_scalar(self, dt):
  # One branch is a tensor of dtype `dt`, the other a python scalar.
  # The expected dtype stays `dt` unless the scalar is a float (or int) and `dt` is not.
  t = Tensor(2, dtype=dt)
  dt_is_float = dtypes.is_float(dt)
  float_result = dt if dt_is_float else dtypes.default_float
  int_result = dt if (dt_is_float or dtypes.is_int(dt)) else dtypes.default_int
  self.check_where_alternate_input_other(t, 3.2, float_result)
  self.check_where_alternate_input_other(t, 3, int_result)
  self.check_where_alternate_input_other(t, True, dt)
|
|
|
|
|
|
|
|
def test_where_two_scalars(self):
  # Both branches are python scalars; each row is (branch a, branch b, expected result dtype).
  cases = (
    (3.1, 3.2, dtypes.default_float),
    (3.1, 3, dtypes.default_float),
    (3.1, True, dtypes.default_float),
    (3, 2, dtypes.default_int),
    (3, True, dtypes.default_int),
    (False, True, dtypes.bool),
  )
  for a, b, expected in cases:
    self.check_where_alternate_input_other(a, b, expected)
|
|
|
|
|
|
|
|
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_maximum(self, dt1, dt2):
  # Elementwise maximum of two tensors should promote to the least upper bound dtype.
  lhs = Tensor([0, 1, 2], dtype=dt1)
  result = lhs.maximum(Tensor([2, 0, 5], dtype=dt2))
  assert result.dtype == least_upper_dtype(dt1, dt2)
|
|
|
|
|
|
|
|
@given(strat.sampled_from(core_dtypes))
def test_maximum_const(self, dt):
  # maximum() against a python scalar: (scalar, expected result dtype) pairs.
  float_result = dtypes.default_float if not dtypes.is_float(dt) else dt
  int_result = dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int
  for scalar, expected in ((3.1, float_result), (3, int_result), (True, dt)):
    assert Tensor([1, 2], dtype=dt).maximum(scalar).dtype == expected
|
|
|
|
|
2024-02-16 08:34:40 +08:00
|
|
|
def test_div(self):
  # True division between tensors: (numerator dtype, denominator dtype, expected result dtype).
  # int / int gives default_float; mixed cases take their dtype from the float operand(s).
  cases = (
    (dtypes.int32, dtypes.int32, dtypes.default_float),
    (dtypes.int16, dtypes.int32, dtypes.default_float),
    (dtypes.float32, dtypes.float16, dtypes.float32),
    (dtypes.int32, dtypes.float16, dtypes.float16),
  )
  for num_dtype, den_dtype, expected in cases:
    assert (Tensor([1, 2], dtype=num_dtype) / Tensor([2, 2], dtype=den_dtype)).dtype == expected
|
|
|
|
|
|
|
|
def test_div_const(self):
  # Division by a python scalar: int32 tensors become default_float, float16 tensors stay float16.
  cases = (
    (dtypes.int32, 2, dtypes.default_float),
    (dtypes.int32, 2.0, dtypes.default_float),
    (dtypes.float16, 2, dtypes.float16),
    (dtypes.float16, 2.0, dtypes.float16),
  )
  for tensor_dtype, divisor, expected in cases:
    assert (Tensor([1, 2], dtype=tensor_dtype) / divisor).dtype == expected
|
|
|
|
|
2024-03-07 02:49:03 +08:00
|
|
|
class TestImplicitFunctionTypeChange(unittest.TestCase):
  def test_functions(self):
    # For each unary op, compare max() of its output against a freshly computed output and count
    # the positions that compare equal; an entry of 0 means max() matched no element.
    result = []
    for op in map(operator.methodcaller, ("exp", "exp2", "log", "log2", "sqrt", "sin")):
      matches = op(Tensor([4.0, 3.0])).max() == op(Tensor([4.0, 3.0]))
      result.append(matches.numpy().sum())
    if Device.DEFAULT in ["PYTHON", "CLANG"]:
      # CLANG and PYTHON function default returns in double, and comparison to float can fail
      # TODO: fix this
      assert not all(result)
    else:
      assert all(result)
|
|
|
|
|
2023-03-11 08:56:07 +08:00
|
|
|
# Run every TestCase in this module when the file is executed directly.
if __name__ == '__main__': unittest.main()
|