mirror of https://github.com/commaai/tinygrad.git
support Int64 type in cstyle gen (#860)
* added metal int64 and some simple tests * removed bool return type def * typo in test * also missing in clang and gpu runtimes * switched order for opencl * increased atol and removed new line in kernel prefix
This commit is contained in:
parent
0fc4cf72a2
commit
0dab8edc97
|
@ -18,6 +18,7 @@ class TestDtype(unittest.TestCase):
|
|||
def test_half_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
|
||||
def test_int8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
|
||||
def test_uint8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
|
||||
def test_int64_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
|
||||
|
||||
def _test_cast(self, a, target_dtype, target):
|
||||
print(a)
|
||||
|
@ -29,18 +30,22 @@ class TestDtype(unittest.TestCase):
|
|||
def test_float_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
|
||||
def test_float_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
|
||||
def test_float_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
|
||||
def test_float_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def test_half_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
|
||||
def test_half_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
|
||||
def test_half_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
|
||||
def test_half_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def test_int8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
|
||||
def test_int8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
|
||||
def test_int8_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
|
||||
def test_int8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def test_uint8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
|
||||
def test_uint8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
|
||||
def test_uint8_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
|
||||
def test_uint8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def _test_add(self, a, b, target_dtype, target):
|
||||
c = a+b
|
||||
|
@ -50,6 +55,7 @@ class TestDtype(unittest.TestCase):
|
|||
|
||||
def test_half_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_int8_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
|
||||
def test_int64_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
|
||||
def _test_mul(self, a, b, target_dtype, target):
|
||||
c = a*b
|
||||
|
@ -59,6 +65,7 @@ class TestDtype(unittest.TestCase):
|
|||
|
||||
def test_half_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_int8_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
|
||||
def test_int64_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
|
||||
def _test_matmul(self, a, b, target_dtype, target):
|
||||
c = a@b
|
||||
|
@ -68,6 +75,7 @@ class TestDtype(unittest.TestCase):
|
|||
|
||||
def test_half_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
def test_int8_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
|
||||
def test_int64_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
|
||||
def _test_add_upcast(self, a, b, target_dtype, target):
|
||||
c = a+b
|
||||
|
@ -78,6 +86,7 @@ class TestDtype(unittest.TestCase):
|
|||
def test_half_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_add_upcast_half(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_int8_add_upcast_int64(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
|
||||
def _test_mul_upcast(self, a, b, target_dtype, target):
|
||||
c = a*b
|
||||
|
@ -88,6 +97,7 @@ class TestDtype(unittest.TestCase):
|
|||
def test_half_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_int8_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_int8_mul_upcast_half(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_int8_mul_upcast_int64(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
|
||||
def _test_matmul_upcast(self, a, b, target_dtype, target):
|
||||
c = a@b
|
||||
|
@ -98,6 +108,7 @@ class TestDtype(unittest.TestCase):
|
|||
def test_half_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_half(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_int64(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
|
||||
def test_int8_to_uint8_negative(self):
|
||||
a = Tensor([-1, -2, -3, -4], dtype=dtypes.int8)
|
||||
|
|
|
@ -46,7 +46,7 @@ class TestNN(unittest.TestCase):
|
|||
|
||||
np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
def test_batchnorm2d_training(self):
|
||||
self.test_batchnorm2d(True)
|
||||
|
|
|
@ -132,7 +132,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
|||
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
|
||||
assert newvar.ltype == LocalTypes.float, "const can't be float4"
|
||||
# nan? inf?
|
||||
val = f"{bufs[args.i].realized._buf}" + ("f" if bufs[args.i].dtype not in (dtypes.int8, dtypes.uint8) else "")
|
||||
val = f"{bufs[args.i].realized._buf}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
|
||||
elif isinstance(bufs[args.i].dtype, ImageDType):
|
||||
assert newvar.ltype == LocalTypes.float4, "image must be float4"
|
||||
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
|
||||
|
|
|
@ -75,6 +75,8 @@ class dtypes:
|
|||
int64: Final[DType] = DType(2, 8, "int64", np.int64)
|
||||
uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
|
||||
@staticmethod
|
||||
def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
|
||||
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class RawBufferMapped(RawBufferCopyIn):
|
|||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8}[dtype] * size)())
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int64: ctypes.c_int64}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
|
|||
|
||||
class ClangProgram:
|
||||
def __init__(self, name:str, prg:str):
|
||||
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n" + prg
|
||||
prg = "#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n" + prg
|
||||
# TODO: is there a way to not write this to disk?
|
||||
fn = f"/tmp/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{'dylib' if platform.system() == 'Darwin' else 'so'}"
|
||||
# NOTE: --rtlib=compiler-rt fixes float16 on Linux, it defines __gnu_h2f_ieee and __gnu_f2h_ieee
|
||||
|
|
|
@ -87,7 +87,7 @@ class CLProgram:
|
|||
|
||||
class CLCodegen(CStyleCodegen):
|
||||
lang = CStyleLanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
kernel_prefix = "#define int64 long\n__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
|
||||
|
|
|
@ -78,7 +78,7 @@ class MetalProgram:
|
|||
|
||||
class MetalCodegen(CStyleCodegen):
|
||||
lang = CStyleLanguage(
|
||||
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
|
||||
kernel_prefix = "#include <metal_stdlib>;\n#define int64 long\nusing namespace metal;\nkernel", buffer_prefix = "device ", smem_prefix = "threadgroup ",
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);", float4 = "float4",
|
||||
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
|
||||
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
|
||||
|
|
Loading…
Reference in New Issue