tinygrad/test/test_dtype.py

518 lines
26 KiB
Python
Raw Normal View History

import unittest
import numpy as np
import torch
import operator
from tinygrad.helpers import CI, getenv, DEBUG, OSX, temp
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, settings, strategies as strat
# hypothesis: run 200 examples per property test and disable the per-example deadline
settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")
# every dtype registered in DTYPES_DICT, and the floating-point subset of them
core_dtypes = list(DTYPES_DICT.values())
floats = list(filter(dtypes.is_float, core_dtypes))
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
  """Return whether `dtype` should be exercised by these tests on `device`.

  Note: the `device` default is evaluated once, when this module is imported.
  """
  # numpy doesn't support bf16; it is tested separately in TestBFloat16DType
  if dtype == dtypes.bfloat16: return False
  # on the web backends only float32, int32 and uint32 are treated as supported
  if device in ("WEBGPU", "WEBGL"): return dtype in (dtypes.float, dtypes.int32, dtypes.uint32)
  if dtype == dtypes.half:
    # for CI GPU, cl_khr_fp16 isn't supported
    # for CI LLVM, it segfaults because it can't link to the casting function
    # CUDA in CI uses CUDACPU that does not support half
    # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
    half_broken_in_ci = CI and device in ("GPU", "LLVM", "CUDA")
    return not half_broken_in_ci and device != "PYTHON"
  if dtype == dtypes.float64:
    # float64 is treated as unsupported on METAL, and on GPU when running under macOS
    return device != "METAL" and not (OSX and device == "GPU")
  return True
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
  """List the supported dtypes that `dtype` can be cast to/from in these tests (empty if `dtype` itself is unsupported)."""
  if not is_dtype_supported(dtype): return []
  candidates: List[DType] = []
  for name, other in DTYPES_DICT.items():
    # skip the dtype itself, unsupported targets, and internal ("_"-prefixed) dtypes
    if other == dtype or not is_dtype_supported(other) or name.startswith("_"): continue
    candidates.append(other)
  return candidates
def _test_to_np(a:Tensor, np_dtype, target):
  """Convert `a` to numpy and check both its numpy dtype and its values against `target`.

  Raises:
    AssertionError: chained to the original failure, reporting the array, its actual
      numpy dtype, the target and the expected np_dtype.
  """
  if DEBUG >= 2: print(a)
  na = a.numpy()
  if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized)
  try:
    assert na.dtype == np_dtype
    np.testing.assert_allclose(na, target)
  except AssertionError as e:
    # reuse the already-fetched array (no second device->host conversion) and report its actual dtype,
    # which is otherwise invisible when the dtype check is what failed
    raise AssertionError(f"\ntensor {na} (np dtype {na.dtype}) does not match target {target} with np_dtype {np_dtype}") from e
def _assert_eq(tensor:Tensor, target_dtype:DType, target):
  """Check that `tensor` has dtype `target_dtype` and values close to `target`.

  float16 results are compared with rtol=1e-3; everything else with rtol=1e-7.

  Raises:
    AssertionError: chained to the original failure, with the tensor values and both dtypes.
  """
  # convert once; the result is reused for the debug print, the comparison and the error message
  result = tensor.numpy()
  if DEBUG >= 2: print(result)
  try:
    assert tensor.dtype == target_dtype
    np.testing.assert_allclose(result, target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
  except AssertionError as e:
    raise AssertionError(f"\ntensor {result} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def _test_op(fxn, target_dtype:DType, target):
  """Call the zero-argument `fxn` and check the Tensor it returns against target_dtype/target."""
  out = fxn()
  _assert_eq(out, target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType):
  """Cast `a` to `target_dtype` and compare with numpy's `astype` of the same data."""
  expected = list(a.numpy().astype(target_dtype.np))
  _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
  """Bitcast `a` to `target_dtype` and compare the result with `target`.

  When `target` is None, the expected values come from reinterpreting a's numpy
  buffer with `ndarray.view(target_dtype.np)`.
  """
  # compare against None explicitly: a falsy but valid target (e.g. []) must not be silently replaced
  expected = target if target is not None else a.numpy().view(target_dtype.np).tolist()
  _test_op(lambda: a.bitcast(target_dtype), target_dtype, expected)
class TestDType(unittest.TestCase):
  """Base suite run once per concrete dtype.

  Subclasses set DTYPE, and setUpClass fills DATA with 10 random values. The base class
  itself (DTYPE is None) and dtypes unsupported on the current device are skipped.
  """
  DTYPE: Any = None
  DATA: Any = None
  @classmethod
  def setUpClass(cls):
    if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
    # small non-negative values: ints in [0, 100), bools, or floats in [0, 1)
    if dtypes.is_int(cls.DTYPE): cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist()
    elif cls.DTYPE == dtypes.bool: cls.DATA = np.random.choice([True, False], size=10).tolist()
    else: cls.DATA = np.random.uniform(0, 1, size=10).tolist()
  def setUp(self):
    if self.DTYPE is None: raise unittest.SkipTest("base class")
  def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
  def test_casts_to(self):
    # every other supported dtype -> DTYPE
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
  def test_casts_from(self):
    # DTYPE -> every other supported dtype
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
  def test_same_size_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize == self.DTYPE.itemsize: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize > self.DTYPE.itemsize: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_to_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize < self.DTYPE.itemsize: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
  def test_bitcast(self):
    if Device.DEFAULT == "WEBGL": raise unittest.SkipTest("no bitcast in WebGL GLSL")
    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
    # only same-itemsize, non-bool targets are valid bitcast destinations
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool:
        _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
  def test_dtypes_fields(self):
    fields = dtypes.fields()
    self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
    self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
  def test_resulting_and_init_dtypes_match(self):
    # named np_dtypes so the tinygrad `dtypes` namespace is not shadowed inside this method
    np_dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
    data = [1., 2., 0., 0.5, -1.5, 5.25]
    for dt in np_dtypes:
      arr = np.asarray(data, dtype=dt)
      tin = Tensor(arr).numpy()
      tor = torch.as_tensor(arr).detach().numpy()
      # compare np.dtype objects by equality, not identity
      assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
      np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
  """Check add, mul, matmul, a cast round-trip and broadcasting between `a_dtype` and `b_dtype` tensors.

  Results must have `target_dtype`, which defaults to least_upper_dtype(a_dtype, b_dtype).
  Returns without checking anything when any involved dtype is unsupported or either operand is bool.
  """
  target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
  if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
  if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
  _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
  _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2])
  _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
  _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
  # the expected value is built with numpy, independently of the library under test
  _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, np.full((4, 4), 2))
@unittest.skipUnless(Device.DEFAULT == "LLVM", "bfloat16 not supported")
class TestBFloat16DType(unittest.TestCase):
  """bfloat16 checks (LLVM only); bf16 results are read back through float32."""
  def test_bf16_to_float(self):
    # the numpy-based cast helper is expected to fail for a bf16 source
    with self.assertRaises(AssertionError):
      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
  def test_float_to_bf16(self):
    # ...and likewise for a bf16 destination
    with self.assertRaises(AssertionError):
      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
  def test_bf16(self):
    # reference: torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
    src = [10000, -1, -1000, -10000, 20]
    bf = Tensor(src).cast(dtypes.bfloat16)
    bf.realize()
    as_f32 = bf.cast(dtypes.float32)
    assert tuple(as_f32.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
  def test_bf16_disk_write_read(self):
    src = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
    src.to(f"disk:{temp('f32')}").realize()
    # hack to "cast" f32 -> bf16: keep bytes 2..3 of every 4-byte float (the high half, assuming little-endian)
    with open(temp('f32'), "rb") as f: raw = f.read()
    truncated = b''.join(raw[i+2:i+4] for i in range(0, len(raw), 4))
    with open(temp('bf16'), "wb") as f: f.write(truncated)
    loaded = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
    as_f32 = loaded.cast(dtypes.float32)
    assert tuple(as_f32.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
# half/float suites: every test is inherited from TestDType
class TestHalfDtype(TestDType):
  DTYPE = dtypes.half

class TestFloatDType(TestDType):
  DTYPE = dtypes.float
class TestDoubleDtype(TestDType):
  DTYPE = dtypes.double
  @unittest.skipIf(getenv("CUDACPU",0)==1, "conversion not supported on CUDACPU")
  @unittest.skipIf(getenv("HIP",0)==1, "HIP renderer does not support f64 precision")
  def test_float64_increased_precision(self):
    # each unary op runs on the same inputs in tinygrad (self.DTYPE) and in torch (float64);
    # the two results must agree to within 1e-12
    op_names = ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid")
    for name in op_names:
      op = operator.methodcaller(name)
      inputs = [2, 3, 4]
      np.testing.assert_allclose(op(Tensor(inputs, dtype=self.DTYPE)).numpy(), op(torch.tensor(inputs, dtype=torch.float64)), rtol=1e-12, atol=1e-12)
class TestInt8Dtype(TestDType):
  DTYPE = dtypes.int8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_int8_to_uint8_negative(self):
    # negative int8 values are expected to wrap modulo 256 when cast to uint8
    expected = [255, 254, 253, 252]
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, expected)
class TestUint8Dtype(TestDType):
  DTYPE = dtypes.uint8
  @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
  def test_uint8_to_int8_overflow(self):
    # uint8 values above 127 are expected to wrap to negative int8 values
    expected = [-1, -2, -3, -4]
    _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, expected)
@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL")
class TestBitCast(unittest.TestCase):
  def test_shape_change_bitcast(self):
    # float32 -> uint8 changes the itemsize; this is expected to raise AssertionError
    with self.assertRaises(AssertionError):
      _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
  def test_bitcast_float_to_int32(self):
    # 0x3f800000 is the IEEE-754 single-precision bit pattern of 1.0
    as_ints = Tensor([1.,2,3]).bitcast(dtypes.int32)
    assert as_ints.numpy()[0] == 0x3f800000
  def test_bitcast_upcasted(self):
    # fill a (100, 4) int32 tensor with the bits of 1.0, then reinterpret it as float32
    bits = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
    reinterpreted = bits.bitcast(dtypes.float32)
    assert reinterpreted.numpy()[0,0] == 1.
# 16-bit integer suites: every test is inherited from TestDType
class TestInt16Dtype(TestDType):
  DTYPE = dtypes.int16

class TestUint16Dtype(TestDType):
  DTYPE = dtypes.uint16
Added Test Coverage to Int32 and Make Sure Tests Succeed (#1174) * Added test coverage for int32 in `test/test_dtype.py` Tests for int32 include: - testing that int32 can be converted into a numpy array - testing that float and int64 can be cast into int32 - testing that int32 can be cast into float and int64 - testing addition, multiplication, and matrix multiplication with int32 - testing that addition, multiplication, and matrix multiplication with int32 and either float or int64 gets successfully cast into float and int64, respectively Additional changes include testing that int8 casts into int32 and testing that float16 casts into int32 * Added type casting to the add, subtract, and divide binary operations * Added automatic type casting when types differ to FusedOps.MULACC I moved the match_types function back so that I could call it in einsum_mulacc where it would cast the types of the MULACC to be the same * Added unit test for match_types and added type hints to the parameters * Added tests for ops_cpu.match_types * Changed ops_cpu.einsum logic to play nicely with PyTorch Changed `tinygrad.runtime.ops_cpu.einsum_mulacc` logic to not perform type matching. Type matching was instead moved to the numpy_fxn_for_op dictionary in the ops_cpu file. Since ops_torch uses the same einsum_mulacc function, this should fix all the broken pytorch tests. 
* empty commit to rerun ci * reverting PR#1213 in attempt to fix broken test * Removed all tests I added to see if they are causing CI issues * Added back type matching tests * removed type matching tests and added back int tests * added back part of the type matching tests * removed braking type matching tests * empty commit for testing * added test back but inside comment * removed a test from the comment to see if it breaks CI * removed another function * more testing * emptied test comment * cleaned up comments * Added optimize=True flag to einsum_mullac in cpu_ops.py * Removed unnecessary imports from tests * optimized match_types by removing unnecessary array copying
2023-07-13 01:29:15 +08:00
# 32-bit integer suites: every test is inherited from TestDType
class TestInt32Dtype(TestDType):
  DTYPE = dtypes.int32

class TestUint32Dtype(TestDType):
  DTYPE = dtypes.uint32
Added Test Coverage to Int32 and Make Sure Tests Succeed (#1174) * Added test coverage for int32 in `test/test_dtype.py` Tests for int32 include: - testing that int32 can be converted into a numpy array - testing that float and int64 can be cast into int32 - testing that int32 can be cast into float and int64 - testing addition, multiplication, and matrix multiplication with int32 - testing that addition, multiplication, and matrix multiplication with int32 and either float or int64 gets successfully cast into float and int64, respectively Additional changes include testing that int8 casts into int32 and testing that float16 casts into int32 * Added type casting to the add, subtract, and divide binary operations * Added automatic type casting when types differ to FusedOps.MULACC I moved the match_types function back so that I could call it in einsum_mulacc where it would cast the types of the MULACC to be the same * Added unit test for match_types and added type hints to the parameters * Added tests for ops_cpu.match_types * Changed ops_cpu.einsum logic to play nicely with PyTorch Changed `tinygrad.runtime.ops_cpu.einsum_mulacc` logic to not perform type matching. Type matching was instead moved to the numpy_fxn_for_op dictionary in the ops_cpu file. Since ops_torch uses the same einsum_mulacc function, this should fix all the broken pytorch tests. 
* empty commit to rerun ci * reverting PR#1213 in attempt to fix broken test * Removed all tests I added to see if they are causing CI issues * Added back type matching tests * removed type matching tests and added back int tests * added back part of the type matching tests * removed braking type matching tests * empty commit for testing * added test back but inside comment * removed a test from the comment to see if it breaks CI * removed another function * more testing * emptied test comment * cleaned up comments * Added optimize=True flag to einsum_mullac in cpu_ops.py * Removed unnecessary imports from tests * optimized match_types by removing unnecessary array copying
2023-07-13 01:29:15 +08:00
# 64-bit integer suites: every test is inherited from TestDType
class TestInt64Dtype(TestDType):
  DTYPE = dtypes.int64

class TestUint64Dtype(TestDType):
  DTYPE = dtypes.uint64
Added Test Coverage to Int32 and Make Sure Tests Succeed (#1174) * Added test coverage for int32 in `test/test_dtype.py` Tests for int32 include: - testing that int32 can be converted into a numpy array - testing that float and int64 can be cast into int32 - testing that int32 can be cast into float and int64 - testing addition, multiplication, and matrix multiplication with int32 - testing that addition, multiplication, and matrix multiplication with int32 and either float or int64 gets successfully cast into float and int64, respectively Additional changes include testing that int8 casts into int32 and testing that float16 casts into int32 * Added type casting to the add, subtract, and divide binary operations * Added automatic type casting when types differ to FusedOps.MULACC I moved the match_types function back so that I could call it in einsum_mulacc where it would cast the types of the MULACC to be the same * Added unit test for match_types and added type hints to the parameters * Added tests for ops_cpu.match_types * Changed ops_cpu.einsum logic to play nicely with PyTorch Changed `tinygrad.runtime.ops_cpu.einsum_mulacc` logic to not perform type matching. Type matching was instead moved to the numpy_fxn_for_op dictionary in the ops_cpu file. Since ops_torch uses the same einsum_mulacc function, this should fix all the broken pytorch tests. 
* empty commit to rerun ci * reverting PR#1213 in attempt to fix broken test * Removed all tests I added to see if they are causing CI issues * Added back type matching tests * removed type matching tests and added back int tests * added back part of the type matching tests * removed braking type matching tests * empty commit for testing * added test back but inside comment * removed a test from the comment to see if it breaks CI * removed another function * more testing * emptied test comment * cleaned up comments * Added optimize=True flag to einsum_mullac in cpu_ops.py * Removed unnecessary imports from tests * optimized match_types by removing unnecessary array copying
2023-07-13 01:29:15 +08:00
# bool suite: every test is inherited from TestDType
class TestBoolDtype(TestDType):
  DTYPE = dtypes.bool
class TestImageDType(unittest.TestCase):
  """Both image dtypes (imagef, imageh) report float32 as their scalar / vector element type."""
  def test_image_scalar(self):
    for make_image in (dtypes.imagef, dtypes.imageh):
      assert make_image((10,10)).scalar() == dtypes.float32
  def test_image_vec(self):
    for make_image in (dtypes.imagef, dtypes.imageh):
      assert make_image((10,10)).vec(4) == dtypes.float32.vec(4)
class TestEqStrDType(unittest.TestCase):
  """Equality and str() behaviour of ImageDType and PtrDType."""
  def test_image_ne(self):
    assert dtypes.float == dtypes.float32, "float doesn't match?"
    assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
    assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
    assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
    assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
  def test_ptr_ne(self):
    # TODO: is this the wrong behavior? arguably PtrDType(dtypes.float32) != dtypes.float32 should hold
    assert PtrDType(dtypes.float32) == dtypes.float32
    assert not (PtrDType(dtypes.float32) != dtypes.float32)
    assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
    assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32))
  def test_strs(self):
    self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
    self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
class TestHelpers(unittest.TestCase):
  """Property tests for dtypes.is_int / is_float / is_unsigned and DType.scalar()."""
  signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
  uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  floats = (dtypes.float16, dtypes.float32, dtypes.float64)
  # lane counts 1..8; a count of 1 means "use the scalar dtype itself"
  _lanes = strat.integers(min_value=1, max_value=8)
  @staticmethod
  def _maybe_vec(dtype, amt): return dtype.vec(amt) if amt > 1 else dtype
  @given(strat.sampled_from(signed_ints+uints), _lanes)
  def test_is_int(self, dtype, amt):
    dt = self._maybe_vec(dtype, amt)
    assert dtypes.is_int(dt)
    assert not dtypes.is_float(dt)
  @given(strat.sampled_from(uints), _lanes)
  def test_is_unsigned_uints(self, dtype, amt):
    assert dtypes.is_unsigned(self._maybe_vec(dtype, amt))
  @given(strat.sampled_from(signed_ints), _lanes)
  def test_is_unsigned_signed_ints(self, dtype, amt):
    assert not dtypes.is_unsigned(self._maybe_vec(dtype, amt))
  @given(strat.sampled_from(floats), _lanes)
  def test_is_float(self, dtype, amt):
    dt = self._maybe_vec(dtype, amt)
    assert dtypes.is_float(dt)
    assert not dtypes.is_int(dt)
    assert not dtypes.is_unsigned(dt)
  def test_bf16_is_float(self):
    assert dtypes.is_float(dtypes.bfloat16)
  @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
  def test_scalar(self, dtype, amt):
    assert dtype.vec(amt).scalar() == dtype
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
def test_set_dtype_default(self):
dtypes.default_int = dtypes.int16
assert dtypes.default_int == dtypes.int16
dtypes.default_int = dtypes.int64
assert dtypes.default_int == dtypes.int64
dtypes.default_int = dtypes.int32
assert dtypes.default_int == dtypes.int32
dtypes.default_float = dtypes.float16
assert dtypes.default_float == dtypes.float16
dtypes.default_float = dtypes.float64
assert dtypes.default_float == dtypes.float64
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
2023-12-17 08:04:08 +08:00
assert Tensor(True).dtype == dtypes.bool
assert Tensor(None).dtype == dtypes.default_float
assert Tensor(2).dtype == dtypes.default_int
assert Tensor(2.34).dtype == dtypes.default_float
assert Tensor([]).dtype == dtypes.default_float
assert Tensor([1]).dtype == dtypes.default_int
assert Tensor([1.1]).dtype == dtypes.default_float
assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
assert Tensor.eye(0).dtype == dtypes.default_float
assert Tensor.eye(3).dtype == dtypes.default_float
assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_full(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor.ones([2,3]).dtype == dtypes.default_float
assert Tensor.zeros([2,3]).dtype == dtypes.default_float
assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float
assert Tensor.full([2,3], 3).dtype == dtypes.default_int
assert Tensor.full([2,3], True).dtype == dtypes.bool
assert Tensor.zeros(3, 3).dtype == dtypes.default_float
assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.ones(3, 3).dtype == dtypes.default_float
assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64
def test_reduce_0d_default(self):
  """Summing over a zero-length axis keeps the input dtype (default float, or an explicit int)."""
  empty_axis_shape = [2,3,0]
  float_result = Tensor.ones(empty_axis_shape).sum(2)
  int_result = Tensor.ones(empty_axis_shape, dtype=dtypes.int).sum(2)
  assert float_result.dtype == dtypes.default_float
  assert int_result.dtype == dtypes.int
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_arange(self, default_int, default_float):
  """Tensor.arange: int-only arguments give default_int, the float cases below give default_float, dtype= overrides."""
  dtypes.default_int, dtypes.default_float = default_int, default_float
  # integer arguments only
  assert Tensor.arange(5).dtype == dtypes.default_int
  # explicit dtype overrides inference
  for requested in (dtypes.int16, dtypes.int64, dtypes.float16):
    assert Tensor.arange(5, dtype=requested).dtype == requested
  # a float stop or a float step
  for args in ((5.0,), (3, 9, 0.7), (3, 8.5, 3)):
    assert Tensor.arange(*args).dtype == dtypes.default_float
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
def test_bool_ops(self, dtype, op):
  """Every comparison operator yields a bool tensor, whatever the operand dtype."""
  lhs, rhs = Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)
  comparison = op(lhs, rhs)
  assert comparison.dtype == dtypes.bool
@given(strat.sampled_from(core_dtypes),
  strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_functions_return_index(self, dtype, default_int, default_float):
  """Index-producing ops (argmax, argmin, multinomial) return default_int regardless of the input dtype."""
  dtypes.default_int, dtypes.default_float = default_int, default_float
  for index_fn in (Tensor.argmax, Tensor.argmin, Tensor.multinomial):
    # calling the unbound method with only the tensor == calling t.<fn>() with default arguments
    assert index_fn(Tensor([0, 1], dtype=dtype)).dtype == dtypes.default_int
class TestTypePromotion(unittest.TestCase):
  """Checks for the promotion helpers least_upper_dtype and least_upper_float."""

  @given(strat.sampled_from(core_dtypes))
  def test_self_promo_to_self(self, dtype):
    # promoting a dtype with itself one, two or three times is the identity
    for count in range(1, 4):
      assert least_upper_dtype(*([dtype] * count)) == dtype

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
    # the promoted dtype must compare >= each of its operands
    promoted = least_upper_dtype(dtype1, dtype2)
    assert all(promoted >= operand for operand in (dtype1, dtype2))

  def test_dtype_promo(self):
    # ((operand, operand), expected promoted dtype)
    table = [
      ((dtypes.bool, dtypes.int8), dtypes.int8),
      ((dtypes.int8, dtypes.uint8), dtypes.int16),
      ((dtypes.uint8, dtypes.int16), dtypes.int16),
      ((dtypes.int16, dtypes.uint16), dtypes.int32),
      ((dtypes.uint16, dtypes.int32), dtypes.int32),
      ((dtypes.int32, dtypes.uint32), dtypes.int64),
      ((dtypes.uint32, dtypes.int64), dtypes.int64),
      # similar to jax but we don't use weak type
      ((dtypes.int64, dtypes.uint64), dtypes.float16),
      ((dtypes.float16, dtypes.float32), dtypes.float32),
      ((dtypes.float32, dtypes.float64), dtypes.float64),
      ((dtypes.bool, dtypes.float32), dtypes.float32),
      ((dtypes.bool, dtypes.float64), dtypes.float64),
      ((dtypes.float16, dtypes.int64), dtypes.float16),
      ((dtypes.float16, dtypes.uint64), dtypes.float16),
    ]
    for operands, expected in table:
      assert least_upper_dtype(*operands) == expected

  @given(strat.sampled_from(floats))
  def test_float_to_float(self, dt):
    # a float dtype is already its own least upper float
    assert dt == least_upper_float(dt)
class TestAutoCastType(unittest.TestCase):
  """Checks the result dtype chosen when ops mix tensor dtypes, or tensors and Python scalars."""

  # input dtype -> dtype produced by sum()/cumsum(), in the order the checks run
  ACCUMULATE_DTYPES = [
    (dtypes.bool, dtypes.int32),
    (dtypes.int8, dtypes.int32),
    (dtypes.int16, dtypes.int32),
    (dtypes.int32, dtypes.int32),
    (dtypes.int64, dtypes.int64),
    (dtypes.uint8, dtypes.uint32),
    (dtypes.uint16, dtypes.uint32),
    (dtypes.uint32, dtypes.uint32),
    (dtypes.uint64, dtypes.uint64),
    (dtypes.float16, dtypes.float16),
    (dtypes.bfloat16, dtypes.bfloat16),
    (dtypes.float32, dtypes.float32),
    (dtypes.float64, dtypes.float64),
  ]

  def setUp(self):
    # remember the process-wide defaults so tearDown can put them back
    self.old_default_int = dtypes.default_int
    self.old_default_float = dtypes.default_float

  def tearDown(self):
    dtypes.default_int = self.old_default_int
    dtypes.default_float = self.old_default_float

  @staticmethod
  def _promoted_with_scalar(dt, scalar):
    """Expected dtype when a tensor of dtype `dt` is combined with the Python scalar `scalar`."""
    # bool must be checked before int: bool is a subclass of int
    if isinstance(scalar, bool): return dt
    if isinstance(scalar, float): return dt if dtypes.is_float(dt) else dtypes.default_float
    return dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int

  @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
  def test_int_to_float_unary_func(self, dtype):
    values = [2, 3, 4]
    # call the same method on a tinygrad Tensor and a torch tensor built from the same ints, then compare
    for method in ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid"):
      got = getattr(Tensor(values, dtype=dtype), method)().numpy()
      reference = getattr(torch.tensor(values), method)()
      # float16 can have larger precision errors
      np.testing.assert_allclose(got, reference, rtol=1e-3, atol=1e-3)

  @given(strat.sampled_from(core_dtypes))
  def test_broadcast_scalar(self, dt):
    for scalar in (2.3, 2):
      assert (Tensor.rand(4, 4, dtype=dt) + scalar).dtype == self._promoted_with_scalar(dt, scalar)
    # the bool-scalar case is only checked off WEBGPU and for non-bool tensors
    if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
      assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt

  def test_sum(self):
    for source, accumulated in self.ACCUMULATE_DTYPES:
      assert Tensor([0, 1], dtype=source).sum().dtype == accumulated

  def test_cumsum(self):
    for source, accumulated in self.ACCUMULATE_DTYPES:
      assert Tensor([0, 1], dtype=source).cumsum(0).dtype == accumulated

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_matmul(self, dt1, dt2):
    lhs, rhs = Tensor([0, 1], dtype=dt1), Tensor([0, 1], dtype=dt2)
    assert (lhs @ rhs).dtype == least_upper_dtype(dt1, dt2)

  @staticmethod
  def check_where_alternate_input_other(input_, other, data_type):
    """Assert Tensor.where yields `data_type` whichever branch holds `input_` and whichever holds `other`."""
    for when_true, when_false in ((input_, other), (other, input_)):
      assert Tensor([True, False]).where(when_true, when_false).dtype == data_type

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_where_no_scalar(self, dt1, dt2):
    expected = least_upper_dtype(dt1, dt2)
    self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), expected)

  @given(strat.sampled_from(core_dtypes))
  def test_where_one_scalar(self, dt):
    t = Tensor(2, dtype=dt)
    for scalar in (3.2, 3, True):
      self.check_where_alternate_input_other(t, scalar, self._promoted_with_scalar(dt, scalar))

  def test_where_two_scalars(self):
    # (first scalar, second scalar, expected dtype)
    cases = [
      (3.1, 3.2, dtypes.default_float),
      (3.1, 3, dtypes.default_float),
      (3.1, True, dtypes.default_float),
      (3, 2, dtypes.default_int),
      (3, True, dtypes.default_int),
      (False, True, dtypes.bool),
    ]
    for first, second, expected in cases:
      self.check_where_alternate_input_other(first, second, expected)

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_maximum(self, dt1, dt2):
    lhs = Tensor([0, 1, 2], dtype=dt1)
    rhs = Tensor([2, 0, 5], dtype=dt2)
    assert lhs.maximum(rhs).dtype == least_upper_dtype(dt1, dt2)

  @given(strat.sampled_from(core_dtypes))
  def test_maximum_const(self, dt):
    for scalar in (3.1, 3, True):
      assert Tensor([1, 2], dtype=dt).maximum(scalar).dtype == self._promoted_with_scalar(dt, scalar)

  def test_div(self):
    i16, i32, f16, f32 = dtypes.int16, dtypes.int32, dtypes.float16, dtypes.float32
    # (numerator dtype, denominator dtype, expected result dtype)
    cases = [(i32, i32, dtypes.default_float), (i16, i32, dtypes.default_float), (f32, f16, f32), (i32, f16, f16)]
    for num_dt, den_dt, expected in cases:
      assert (Tensor([1, 2], dtype=num_dt) / Tensor([2, 2], dtype=den_dt)).dtype == expected

  def test_div_const(self):
    # an int32 tensor divided by an int or float constant gives default_float; float16 stays float16
    for dt, expected in ((dtypes.int32, dtypes.default_float), (dtypes.float16, dtypes.float16)):
      for divisor in (2, 2.0):
        assert (Tensor([1, 2], dtype=dt) / divisor).dtype == expected
# run every TestCase in this module when the file is executed directly
if __name__ == '__main__': unittest.main()