Llvm int support (#866)

* added int val support to llvm

* lint fix

* added types

* fix merge issues
This commit is contained in:
Diogo 2023-05-30 20:49:26 -04:00 committed by GitHub
parent 5670123d88
commit 1272d8526a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 109 deletions

View File

@ -1,128 +1,94 @@
import unittest import unittest
import numpy as np import numpy as np
from tinygrad.helpers import getenv from tinygrad.helpers import getenv, DType, DEBUG
from tinygrad.lazy import Device from tinygrad.lazy import Device
from tinygrad.tensor import Tensor, dtypes from tinygrad.tensor import Tensor, dtypes
def _test_to_np(a:Tensor, np_dtype, target):
print(a)
na = a.numpy()
print(na, na.dtype, a.lazydata.realized)
assert na.dtype == np_dtype
np.testing.assert_allclose(na, target)
def _test_op(fxn, target_dtype:DType, target):
c = fxn()
if DEBUG >= 2: print(c.numpy())
assert c.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target)
def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
def _test_add(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
def _test_mul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function # for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends") @unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
class TestDtype(unittest.TestCase): class TestHalfDtype(unittest.TestCase):
def _test_to_np(self, a, np_dtype, target): def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
print(a)
na = a.numpy()
print(na, na.dtype, a.lazydata.realized)
assert na.dtype == np_dtype
np.testing.assert_allclose(na, target)
def test_half_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
def test_int8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
def test_uint8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
def test_int64_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
def _test_cast(self, a, target_dtype, target):
print(a)
b = a.cast(target_dtype)
print(b.numpy())
assert b.dtype == target_dtype
np.testing.assert_allclose(b.numpy(), target)
def test_float_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
def test_float_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
def test_float_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
def test_float_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4])
def test_half_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
def test_half_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
def test_half_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
def test_half_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
def test_int8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
def test_int8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
def test_int8_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
def test_int8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4])
def test_uint8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
def test_uint8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
def test_uint8_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
def test_uint8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
def _test_add(self, a, b, target_dtype, target):
c = a+b
print(c.numpy())
assert c.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target)
def test_half_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
def test_int8_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
def test_int64_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
def _test_mul(self, a, b, target_dtype, target): def test_half_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
c = a*b def test_half_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
print(c.numpy()) def test_half_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
assert c.dtype == target_dtype def test_half_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
np.testing.assert_allclose(c.numpy(), target)
def test_half_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16]) def test_float_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
def test_int8_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16]) def test_int8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
def test_int64_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16]) def test_uint8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
def test_half_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
def test_half_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
def test_half_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
def test_half_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
def test_int8_add_upcast_half(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
def test_int8_mul_upcast_half(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
def test_half_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
def _test_matmul(self, a, b, target_dtype, target): class TestInt8Dtype(unittest.TestCase):
c = a@b def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
print(c.numpy()) def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
assert c.dtype == target_dtype def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
np.testing.assert_allclose(c.numpy(), target)
def test_half_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) def test_float_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
def test_int8_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]]) def test_float_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
def test_int64_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]]) def test_float_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4])
def _test_add_upcast(self, a, b, target_dtype, target): def test_int8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
c = a+b def test_int8_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
print(c.numpy()) def test_int8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4])
assert c.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target)
def test_half_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) def test_uint8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
def test_int8_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8]) def test_uint8_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
def test_int8_add_upcast_half(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8]) def test_uint8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
def test_int8_add_upcast_int64(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
def _test_mul_upcast(self, a, b, target_dtype, target): def test_int8_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
c = a*b def test_int64_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
print(c.numpy())
assert c.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target)
def test_half_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) def test_int8_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
def test_int8_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16]) def test_int64_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
def test_int8_mul_upcast_half(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
def test_int8_mul_upcast_int64(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
def _test_matmul_upcast(self, a, b, target_dtype, target): def test_int8_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
c = a@b def test_int64_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
print(c.numpy())
assert c.dtype == target_dtype
np.testing.assert_allclose(c.numpy(), target)
def test_half_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) def test_int8_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
def test_int8_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) def test_int8_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
def test_int8_matmul_upcast_half(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) def test_int8_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_matmul_upcast_int64(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
def test_int8_to_uint8_negative(self): def test_int8_add_upcast_int64(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
a = Tensor([-1, -2, -3, -4], dtype=dtypes.int8) def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
print(a) def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
b = a.cast(dtypes.uint8)
print(b.numpy())
np.testing.assert_allclose(b.numpy(), [255, 254, 253, 252])
def test_uint8_to_int8_overflow(self): def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
a = Tensor([255, 254, 253, 252], dtype=dtypes.uint8)
print(a) def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
b = a.cast(dtypes.int8)
print(b.numpy())
np.testing.assert_allclose(b.numpy(), [-1, -2, -3, -4])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -38,7 +38,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
module = ir.Module(name=__file__) module = ir.Module(name=__file__)
# create llvm function # create llvm function
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1)}[buf.dtype] for buf in bufs] func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64)}[buf.dtype] for buf in bufs]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec') func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
# force llvmlite to allow us to add function attribute then add the attribute # force llvmlite to allow us to add function attribute then add the attribute
@ -88,13 +88,21 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0)) val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0))
else: else:
val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True)) val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True))
if func_dtypes[args.i] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType()) if func_dtypes[args.i] != ir.FloatType():
if dtypes.is_int(bufs[args.i].dtype):
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
lvars[newvar] = val lvars[newvar] = val
if uop == UOps.STORE: if uop == UOps.STORE:
assert args.valid.min == 1, "store must be valid" assert args.valid.min == 1, "store must be valid"
idx = args.idx.render(render_llvm, bb[-1]) idx = args.idx.render(render_llvm, bb[-1])
element = lvars[vin[0]] element = lvars[vin[0]]
if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0]) if func_dtypes[0] != ir.FloatType():
if dtypes.is_int(bufs[args.i].dtype):
element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0])
else:
element = bb[-1].fptrunc(element, func_dtypes[0])
bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True)) bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True))
if uop == UOps.ALU: if uop == UOps.ALU:
lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])

View File

@ -75,7 +75,9 @@ class dtypes:
int64: Final[DType] = DType(2, 8, "int64", np.int64) int64: Final[DType] = DType(2, 8, "int64", np.int64)
uint8: Final[DType] = DType(0, 1, "uchar", np.uint8) uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
@staticmethod @staticmethod
def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64) def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
@staticmethod
def is_unsigned(x: DType): return x in (dtypes.uint8)
@staticmethod @staticmethod
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name] def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]