tinygrad/test/test_dtype.py

120 lines
9.7 KiB
Python
Raw Normal View History

import unittest
import numpy as np
from tinygrad.helpers import getenv, DType, DEBUG
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor, dtypes
from extra.utils import OSX
def _test_to_np(a:Tensor, np_dtype, target):
  """Realize `a` as a numpy array and check its dtype and values.

  a: tensor under test; the array returned by `a.numpy()` is what gets checked.
  np_dtype: numpy dtype the realized array is required to have.
  target: expected values, compared with `np.testing.assert_allclose`.

  Raises AssertionError on a dtype mismatch or a value mismatch.
  """
  # Debug output is gated on DEBUG (same threshold _test_op uses) so that
  # ordinary test runs are not flooded with one dump per test case.
  if DEBUG >= 2: print(a)
  na = a.numpy()
  if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized)
  assert na.dtype == np_dtype, f"expected numpy dtype {np_dtype}, got {na.dtype}"
  np.testing.assert_allclose(na, target)
def _test_op(fxn, target_dtype:DType, target):
  """Run `fxn` and check the dtype and values of what it returns.

  fxn: zero-argument callable; its result must expose `.dtype` and `.numpy()`.
  target_dtype: dtype the result is required to report.
  target: expected values, compared with `np.testing.assert_allclose`.

  Raises AssertionError on a dtype mismatch or a value mismatch.
  """
  result = fxn()
  # at DEBUG >= 2 the realized values are dumped before any check runs
  if DEBUG >= 2:
    print(result.numpy())
  assert result.dtype == target_dtype
  np.testing.assert_allclose(result.numpy(), target)
def _test_cast(a:Tensor, target_dtype:DType, target):
  """Check that `a.cast(target_dtype)` has dtype `target_dtype` and values `target`."""
  def do_cast():
    return a.cast(target_dtype)
  _test_op(do_cast, target_dtype, target)
def _test_add(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a + b` has dtype `target_dtype` and values `target`."""
  def do_add():
    return a + b
  _test_op(do_add, target_dtype, target)
def _test_mul(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a * b` has dtype `target_dtype` and values `target`."""
  def do_mul():
    return a * b
  _test_op(do_mul, target_dtype, target)
def _test_matmul(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a @ b` has dtype `target_dtype` and values `target`."""
  def do_matmul():
    return a @ b
  _test_op(do_matmul, target_dtype, target)
def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a + b` has dtype `target_dtype` and values `target`.

  Used by the tests whose two operands have different dtypes.
  """
  def add_operands():
    return a + b
  _test_op(add_operands, target_dtype, target)
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a * b` has dtype `target_dtype` and values `target`.

  Used by the tests whose two operands have different dtypes.
  """
  def mul_operands():
    return a * b
  _test_op(mul_operands, target_dtype, target)
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target):
  """Check that `a @ b` has dtype `target_dtype` and values `target`.

  Used by the tests whose two operands have different dtypes.
  """
  def matmul_operands():
    return a @ b
  _test_op(matmul_operands, target_dtype, target)
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
Webgpu support (#1077) * initial commit * 81 passing * 105 passing tests * 148 passing * CI tests * install dep on ci * try opencl pkgs * try using vulkan * down to only 6 failing * refactor * cleaning up * another test skipped due to buffer limit * linter * segfault * indent fix * another segfault found * small touchups * Fix max and maxpool tests * Add constant folding * Add javascript export script * better asserts in codegen * manual upcasting * reverted token type change * skip safetensor test due to unsupported type * FIx efficientnet and all other model tests * Remove np copy * fixed indent and missing import * manually destroy the buffer * revert back to length * linter errors * removed extra val * skip broken tests * skipping more tests * Make the page pretty * Save model weights as safetensor * Fix imagenet to c test * Fix second imagenet to c bug * Async and paralel kernel compilation * workgroup support * reversed local size * fixed non local bug * correct local groups * ci experiment * removed typo * Fix define local by using shared memory * Refactor * try running on mac * match metal tests * add more workers * scope down tests * trying windows runner * fixed windows env * see how many it can do * merged master * refactor * missed refactor * increase test suite coverage * missing import * whitespace in test_efficientnet.py * getting there * fixed reset * fixed bufs * switched to cstyle * cleanup * min/max rename * one more linter issue * fixed demo * linter * testing ci chrome * add unsafe webgpu arg * add build step * remove WEBGPU from cmd line * use module * try forcing directx * trying forced metal backend * temp disable conv2d for CI * disable conv_trasnpose2d --------- Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2023-07-13 03:52:06 +08:00
@unittest.skipIf(
  (getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU",
  "float16 broken in some CI backends",
)
class TestHalfDtype(unittest.TestCase):
  """float16 conversion and arithmetic checks.

  The whole class is skipped when the default device is WEBGPU, and also when
  the CI environment variable is non-empty and the default device is LLVM.
  """

  @staticmethod
  def _vec(dtype):
    # 1-D operand [1, 2, 3, 4] in the requested dtype
    return Tensor([1,2,3,4], dtype=dtype)

  @staticmethod
  def _mat(dtype):
    # 2x2 operand [[1, 2], [3, 4]] in the requested dtype
    return Tensor([[1,2],[3,4]], dtype=dtype)

  @staticmethod
  def _eye(dtype):
    # 2x2 identity, so a matmul against it reproduces the other operand
    return Tensor.eye(2, dtype=dtype)

  # --- conversion to numpy ---
  def test_half_to_np(self):
    _test_to_np(self._vec(dtypes.float16), np.float16, [1,2,3,4])

  # --- casts out of float16 ---
  def test_half_to_float(self):
    _test_cast(self._vec(dtypes.float16), dtypes.float32, [1,2,3,4])

  def test_half_to_int8(self):
    _test_cast(self._vec(dtypes.float16), dtypes.int8, [1,2,3,4])

  def test_half_to_uint8(self):
    _test_cast(self._vec(dtypes.float16), dtypes.uint8, [1,2,3,4])

  def test_half_to_int32(self):
    _test_cast(self._vec(dtypes.float16), dtypes.int32, [1,2,3,4])

  def test_half_to_int64(self):
    _test_cast(self._vec(dtypes.float16), dtypes.int64, [1,2,3,4])

  # --- casts into float16 ---
  def test_float_to_half(self):
    _test_cast(self._vec(dtypes.float32), dtypes.float16, [1,2,3,4])

  def test_int8_to_half(self):
    _test_cast(self._vec(dtypes.int8), dtypes.float16, [1,2,3,4])

  def test_uint8_to_half(self):
    _test_cast(self._vec(dtypes.uint8), dtypes.float16, [1,2,3,4])

  # --- float16 with float16 ---
  def test_half_add(self):
    _test_add(self._vec(dtypes.float16), self._vec(dtypes.float16), dtypes.float16, [2,4,6,8])

  def test_half_mul(self):
    _test_mul(self._vec(dtypes.float16), self._vec(dtypes.float16), dtypes.float16, [1,4,9,16])

  def test_half_matmul(self):
    _test_matmul(self._mat(dtypes.float16), self._eye(dtypes.float16), dtypes.float16, [[1,2],[3,4]])

  # --- mixed-dtype operands ---
  def test_half_add_upcast_float(self):
    _test_add_upcast(self._vec(dtypes.float16), self._vec(dtypes.float32), dtypes.float32, [2,4,6,8])

  def test_int8_add_upcast_half(self):
    _test_add_upcast(self._vec(dtypes.int8), self._vec(dtypes.float16), dtypes.float16, [2,4,6,8])

  def test_int8_mul_upcast_half(self):
    _test_mul_upcast(self._vec(dtypes.int8), self._vec(dtypes.float16), dtypes.float16, [1,4,9,16])

  def test_half_mul_upcast_float(self):
    _test_mul_upcast(self._vec(dtypes.float16), self._vec(dtypes.float32), dtypes.float32, [1,4,9,16])

  def test_half_matmul_upcast_float(self):
    _test_matmul_upcast(self._mat(dtypes.float16), self._eye(dtypes.float32), dtypes.float32, [[1,2],[3,4]])

  def test_int8_matmul_upcast_half(self):
    _test_matmul_upcast(self._mat(dtypes.int8), self._eye(dtypes.float16), dtypes.float16, [[1,2],[3,4]])
Webgpu support (#1077) * initial commit * 81 passing * 105 passing tests * 148 passing * CI tests * install dep on ci * try opencl pkgs * try using vulkan * down to only 6 failing * refactor * cleaning up * another test skipped due to buffer limit * linter * segfault * indent fix * another segfault found * small touchups * Fix max and maxpool tests * Add constant folding * Add javascript export script * better asserts in codegen * manual upcasting * reverted token type change * skip safetensor test due to unsupported type * FIx efficientnet and all other model tests * Remove np copy * fixed indent and missing import * manually destroy the buffer * revert back to length * linter errors * removed extra val * skip broken tests * skipping more tests * Make the page pretty * Save model weights as safetensor * Fix imagenet to c test * Fix second imagenet to c bug * Async and paralel kernel compilation * workgroup support * reversed local size * fixed non local bug * correct local groups * ci experiment * removed typo * Fix define local by using shared memory * Refactor * try running on mac * match metal tests * add more workers * scope down tests * trying windows runner * fixed windows env * see how many it can do * merged master * refactor * missed refactor * increase test suite coverage * missing import * whitespace in test_efficientnet.py * getting there * fixed reset * fixed bufs * switched to cstyle * cleanup * min/max rename * one more linter issue * fixed demo * linter * testing ci chrome * add unsafe webgpu arg * add build step * remove WEBGPU from cmd line * use module * try forcing directx * trying forced metal backend * temp disable conv2d for CI * disable conv_trasnpose2d --------- Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2023-07-13 03:52:06 +08:00
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
  """int8 / uint8 / int64 conversion and arithmetic checks.

  The whole class is skipped when the default device is WEBGPU.
  """

  @staticmethod
  def _vec(dtype):
    # 1-D operand [1, 2, 3, 4] in the requested dtype
    return Tensor([1,2,3,4], dtype=dtype)

  @staticmethod
  def _mat(dtype):
    # 2x2 operand [[1, 2], [3, 4]] in the requested dtype
    return Tensor([[1,2],[3,4]], dtype=dtype)

  @staticmethod
  def _eye(dtype):
    # 2x2 identity, so a matmul against it reproduces the other operand
    return Tensor.eye(2, dtype=dtype)

  # --- conversion to numpy ---
  def test_int8_to_np(self):
    _test_to_np(self._vec(dtypes.int8), np.int8, [1,2,3,4])

  def test_uint8_to_np(self):
    _test_to_np(self._vec(dtypes.uint8), np.uint8, [1,2,3,4])

  def test_int64_to_np(self):
    _test_to_np(self._vec(dtypes.int64), np.int64, [1,2,3,4])

  # --- casts out of float32 ---
  def test_float_to_int8(self):
    _test_cast(self._vec(dtypes.float32), dtypes.int8, [1,2,3,4])

  def test_float_to_uint8(self):
    _test_cast(self._vec(dtypes.float32), dtypes.uint8, [1,2,3,4])

  def test_float_to_int64(self):
    _test_cast(self._vec(dtypes.float32), dtypes.int64, [1,2,3,4])

  # --- casts out of int8 ---
  def test_int8_to_float(self):
    _test_cast(self._vec(dtypes.int8), dtypes.float32, [1,2,3,4])

  def test_int8_to_uint8(self):
    _test_cast(self._vec(dtypes.int8), dtypes.uint8, [1,2,3,4])

  def test_int8_to_int32(self):
    _test_cast(self._vec(dtypes.int8), dtypes.int32, [1,2,3,4])

  def test_int8_to_int64(self):
    _test_cast(self._vec(dtypes.int8), dtypes.int64, [1,2,3,4])

  # --- casts out of uint8 ---
  def test_uint8_to_float(self):
    _test_cast(self._vec(dtypes.uint8), dtypes.float32, [1,2,3,4])

  def test_uint8_to_int8(self):
    _test_cast(self._vec(dtypes.uint8), dtypes.int8, [1,2,3,4])

  def test_uint8_to_int64(self):
    _test_cast(self._vec(dtypes.uint8), dtypes.int64, [1,2,3,4])

  # --- same-dtype arithmetic ---
  def test_int8_add(self):
    _test_add(self._vec(dtypes.int8), self._vec(dtypes.int8), dtypes.int8, [2,4,6,8])

  def test_int64_add(self):
    _test_add(self._vec(dtypes.int64), self._vec(dtypes.int64), dtypes.int64, [2,4,6,8])

  def test_int8_mul(self):
    _test_mul(self._vec(dtypes.int8), self._vec(dtypes.int8), dtypes.int8, [1,4,9,16])

  def test_int64_mul(self):
    _test_mul(self._vec(dtypes.int64), self._vec(dtypes.int64), dtypes.int64, [1,4,9,16])

  def test_int8_matmul(self):
    _test_matmul(self._mat(dtypes.int8), self._eye(dtypes.int8), dtypes.int8, [[1,2],[3,4]])

  def test_int64_matmul(self):
    _test_matmul(self._mat(dtypes.int64), self._eye(dtypes.int64), dtypes.int64, [[1,2],[3,4]])

  # --- int8 combined with float32 ---
  def test_int8_add_upcast_float(self):
    _test_add_upcast(self._vec(dtypes.int8), self._vec(dtypes.float32), dtypes.float32, [2,4,6,8])

  def test_int8_mul_upcast_float(self):
    _test_mul_upcast(self._vec(dtypes.int8), self._vec(dtypes.float32), dtypes.float32, [1,4,9,16])

  def test_int8_matmul_upcast_float(self):
    _test_matmul_upcast(self._mat(dtypes.int8), self._eye(dtypes.float32), dtypes.float32, [[1,2],[3,4]])

  # --- int8 combined with int64 ---
  def test_int8_add_upcast_int64(self):
    _test_add_upcast(self._vec(dtypes.int8), self._vec(dtypes.int64), dtypes.int64, [2,4,6,8])

  def test_int8_mul_upcast_int64(self):
    _test_mul_upcast(self._vec(dtypes.int8), self._vec(dtypes.int64), dtypes.int64, [1,4,9,16])

  def test_int8_matmul_upcast_int64(self):
    _test_matmul_upcast(self._mat(dtypes.int8), self._eye(dtypes.int64), dtypes.int64, [[1,2],[3,4]])

  # --- out-of-range casts between int8 and uint8 ---
  # Only the negative-int8 case carries the CUDA skip.
  @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
  def test_int8_to_uint8_negative(self):
    def cast_negative_int8():
      return Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8)
    _test_op(cast_negative_int8, dtypes.uint8, [255, 254, 253, 252])

  def test_uint8_to_int8_overflow(self):
    def cast_large_uint8():
      return Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8)
    _test_op(cast_large_uint8, dtypes.int8, [-1, -2, -3, -4])
Added Test Coverage to Int32 and Make Sure Tests Succeed (#1174) * Added test coverage for int32 in `test/test_dtype.py` Tests for int32 include: - testing that int32 can be converted into a numpy array - testing that float and int64 can be cast into int32 - testing that int32 can be cast into float and int64 - testing addition, multiplication, and matrix multiplication with int32 - testing that addition, multiplication, and matrix multiplication with int32 and either float or int64 gets successfully cast into float and int64, respectively Additional changes include testing that int8 casts into int32 and testing that float16 casts into int32 * Added type casting to the add, subtract, and divide binary operations * Added automatic type casting when types differ to FusedOps.MULACC I moved the match_types function back so that I could call it in einsum_mulacc where it would cast the types of the MULACC to be the same * Added unit test for match_types and added type hints to the parameters * Added tests for ops_cpu.match_types * Changed ops_cpu.einsum logic to play nicely with PyTorch Changed `tinygrad.runtime.ops_cpu.einsum_mulacc` logic to not perform type matching. Type matching was instead moved to the numpy_fxn_for_op dictionary in the ops_cpu file. Since ops_torch uses the same einsum_mulacc function, this should fix all the broken pytorch tests. 
* empty commit to rerun ci * reverting PR#1213 in attempt to fix broken test * Removed all tests I added to see if they are causing CI issues * Added back type matching tests * removed type matching tests and added back int tests * added back part of the type matching tests * removed braking type matching tests * empty commit for testing * added test back but inside comment * removed a test from the comment to see if it breaks CI * removed another function * more testing * emptied test comment * cleaned up comments * Added optimize=True flag to einsum_mullac in cpu_ops.py * Removed unnecessary imports from tests * optimized match_types by removing unnecessary array copying
2023-07-13 01:29:15 +08:00
class TestInt32Dtype(unittest.TestCase):
  """int32 conversion and arithmetic checks, alone and mixed with float32 / int64."""

  @staticmethod
  def _vec(dtype):
    # 1-D operand [1, 2, 3, 4] in the requested dtype
    return Tensor([1,2,3,4], dtype=dtype)

  @staticmethod
  def _mat(dtype):
    # 2x2 operand [[1, 2], [3, 4]] in the requested dtype
    return Tensor([[1,2],[3,4]], dtype=dtype)

  @staticmethod
  def _eye(dtype):
    # 2x2 identity, so a matmul against it reproduces the other operand
    return Tensor.eye(2, dtype=dtype)

  # --- conversion to numpy ---
  def test_int32_to_np(self):
    _test_to_np(self._vec(dtypes.int32), np.int32, [1,2,3,4])

  # --- casts into int32 ---
  def test_float_to_int32(self):
    _test_cast(self._vec(dtypes.float32), dtypes.int32, [1,2,3,4])

  def test_int64_to_int32(self):
    _test_cast(self._vec(dtypes.int64), dtypes.int32, [1,2,3,4])

  # --- casts out of int32 ---
  def test_int32_to_float(self):
    _test_cast(self._vec(dtypes.int32), dtypes.float32, [1,2,3,4])

  def test_int32_to_int64(self):
    _test_cast(self._vec(dtypes.int32), dtypes.int64, [1,2,3,4])

  # --- int32 with int32 ---
  def test_int32_add(self):
    _test_add(self._vec(dtypes.int32), self._vec(dtypes.int32), dtypes.int32, [2,4,6,8])

  def test_int32_mul(self):
    _test_mul(self._vec(dtypes.int32), self._vec(dtypes.int32), dtypes.int32, [1,4,9,16])

  def test_int32_matmul(self):
    _test_matmul(self._mat(dtypes.int32), self._eye(dtypes.int32), dtypes.int32, [[1,2],[3,4]])

  # --- int32 combined with float32 ---
  def test_int32_add_upcast_float(self):
    _test_add_upcast(self._vec(dtypes.int32), self._vec(dtypes.float32), dtypes.float32, [2,4,6,8])

  def test_int32_mul_upcast_float(self):
    _test_mul_upcast(self._vec(dtypes.int32), self._vec(dtypes.float32), dtypes.float32, [1,4,9,16])

  def test_int32_matmul_upcast_float(self):
    _test_matmul_upcast(self._mat(dtypes.int32), self._eye(dtypes.float32), dtypes.float32, [[1,2],[3,4]])

  # --- int32 combined with int64 ---
  def test_int32_add_upcast_int64(self):
    _test_add_upcast(self._vec(dtypes.int32), self._vec(dtypes.int64), dtypes.int64, [2,4,6,8])

  def test_int32_mul_upcast_int64(self):
    _test_mul_upcast(self._vec(dtypes.int32), self._vec(dtypes.int64), dtypes.int64, [1,4,9,16])

  def test_int32_matmul_upcast_int64(self):
    _test_matmul_upcast(self._mat(dtypes.int32), self._eye(dtypes.int64), dtypes.int64, [[1,2],[3,4]])
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()