mirror of https://github.com/commaai/tinygrad.git
Llvm int support (#866)
* added int val support to llvm * lint fix * added types * fix merge issues
This commit is contained in:
parent
5670123d88
commit
1272d8526a
|
@ -1,128 +1,94 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv, DType, DEBUG
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
|
||||
|
||||
def _test_to_np(a:Tensor, np_dtype, target):
|
||||
print(a)
|
||||
na = a.numpy()
|
||||
print(na, na.dtype, a.lazydata.realized)
|
||||
assert na.dtype == np_dtype
|
||||
np.testing.assert_allclose(na, target)
|
||||
|
||||
def _test_op(fxn, target_dtype:DType, target):
|
||||
c = fxn()
|
||||
if DEBUG >= 2: print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
|
||||
def _test_cast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.cast(target_dtype), target_dtype, target)
|
||||
def _test_add(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
|
||||
def _test_mul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
|
||||
def _test_matmul(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
|
||||
def _test_add_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a+b, target_dtype, target)
|
||||
def _test_mul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a*b, target_dtype, target)
|
||||
def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_op(lambda: a@b, target_dtype, target)
|
||||
|
||||
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
||||
# for LLVM, it segfaults because it can't link to the casting function
|
||||
@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
|
||||
class TestDtype(unittest.TestCase):
|
||||
def _test_to_np(self, a, np_dtype, target):
|
||||
print(a)
|
||||
na = a.numpy()
|
||||
print(na, na.dtype, a.lazydata.realized)
|
||||
assert na.dtype == np_dtype
|
||||
np.testing.assert_allclose(na, target)
|
||||
class TestHalfDtype(unittest.TestCase):
|
||||
def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
|
||||
|
||||
def test_half_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
|
||||
def test_int8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
|
||||
def test_uint8_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
|
||||
def test_int64_to_np(self): self._test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
|
||||
def test_half_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
|
||||
def test_half_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
|
||||
def test_half_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
|
||||
def test_half_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def _test_cast(self, a, target_dtype, target):
|
||||
print(a)
|
||||
b = a.cast(target_dtype)
|
||||
print(b.numpy())
|
||||
assert b.dtype == target_dtype
|
||||
np.testing.assert_allclose(b.numpy(), target)
|
||||
def test_float_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
|
||||
def test_int8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
|
||||
def test_uint8_to_half(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
|
||||
|
||||
def test_float_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float16, [1,2,3,4])
|
||||
def test_float_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
|
||||
def test_float_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
|
||||
def test_float_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4])
|
||||
def test_half_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_half_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_half_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
|
||||
def test_half_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float32, [1,2,3,4])
|
||||
def test_half_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int8, [1,2,3,4])
|
||||
def test_half_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.uint8, [1,2,3,4])
|
||||
def test_half_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.int64, [1,2,3,4])
|
||||
def test_half_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_add_upcast_half(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_int8_mul_upcast_half(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_half_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
|
||||
def test_int8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
|
||||
def test_int8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float16, [1,2,3,4])
|
||||
def test_int8_to_uint8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
|
||||
def test_int8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4])
|
||||
class TestInt8Dtype(unittest.TestCase):
|
||||
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
|
||||
def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
|
||||
def test_int64_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int64), np.int64, [1,2,3,4])
|
||||
|
||||
def test_uint8_to_float(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
|
||||
def test_uint8_to_half(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float16, [1,2,3,4])
|
||||
def test_uint8_to_int8(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
|
||||
def test_uint8_to_int64(self): self._test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
|
||||
def test_float_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int8, [1,2,3,4])
|
||||
def test_float_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint8, [1,2,3,4])
|
||||
def test_float_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def _test_add(self, a, b, target_dtype, target):
|
||||
c = a+b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
def test_int8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.float32, [1,2,3,4])
|
||||
def test_int8_to_uint8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.uint8, [1,2,3,4])
|
||||
def test_int8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def test_half_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_int8_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
|
||||
def test_int64_add(self): self._test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
def test_uint8_to_float(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.float32, [1,2,3,4])
|
||||
def test_uint8_to_int8(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int8, [1,2,3,4])
|
||||
def test_uint8_to_int64(self): _test_cast(Tensor([1,2,3,4], dtype=dtypes.uint8), dtypes.int64, [1,2,3,4])
|
||||
|
||||
def _test_mul(self, a, b, target_dtype, target):
|
||||
c = a*b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
def test_int8_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [2,4,6,8])
|
||||
def test_int64_add(self): _test_add(Tensor([1,2,3,4], dtype=dtypes.int64),Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
|
||||
def test_half_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_int8_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
|
||||
def test_int64_mul(self): self._test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
def test_int8_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int8), dtypes.int8, [1,4,9,16])
|
||||
def test_int64_mul(self): _test_mul(Tensor([1,2,3,4], dtype=dtypes.int64), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
|
||||
def _test_matmul(self, a, b, target_dtype, target):
|
||||
c = a@b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
def test_int8_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
|
||||
def test_int64_matmul(self): _test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
|
||||
def test_half_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
def test_int8_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int8), dtypes.int8, [[1,2],[3,4]])
|
||||
def test_int64_matmul(self): self._test_matmul(Tensor([[1,2],[3,4]], dtype=dtypes.int64), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
def test_int8_add_upcast_float(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_mul_upcast_float(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_int8_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
|
||||
def _test_add_upcast(self, a, b, target_dtype, target):
|
||||
c = a+b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
def test_int8_add_upcast_int64(self): _test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
|
||||
def test_half_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_add_upcast_float(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [2,4,6,8])
|
||||
def test_int8_add_upcast_half(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [2,4,6,8])
|
||||
def test_int8_add_upcast_int64(self): self._test_add_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [2,4,6,8])
|
||||
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
|
||||
|
||||
def _test_mul_upcast(self, a, b, target_dtype, target):
|
||||
c = a*b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
|
||||
def test_half_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.float16), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_int8_mul_upcast_float(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.float32, [1,4,9,16])
|
||||
def test_int8_mul_upcast_half(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.float16), dtypes.float16, [1,4,9,16])
|
||||
def test_int8_mul_upcast_int64(self): self._test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
|
||||
|
||||
def _test_matmul_upcast(self, a, b, target_dtype, target):
|
||||
c = a@b
|
||||
print(c.numpy())
|
||||
assert c.dtype == target_dtype
|
||||
np.testing.assert_allclose(c.numpy(), target)
|
||||
|
||||
def test_half_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_float(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_half(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
|
||||
def test_int8_matmul_upcast_int64(self): self._test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
|
||||
|
||||
def test_int8_to_uint8_negative(self):
|
||||
a = Tensor([-1, -2, -3, -4], dtype=dtypes.int8)
|
||||
print(a)
|
||||
b = a.cast(dtypes.uint8)
|
||||
print(b.numpy())
|
||||
np.testing.assert_allclose(b.numpy(), [255, 254, 253, 252])
|
||||
|
||||
def test_uint8_to_int8_overflow(self):
|
||||
a = Tensor([255, 254, 253, 252], dtype=dtypes.uint8)
|
||||
print(a)
|
||||
b = a.cast(dtypes.int8)
|
||||
print(b.numpy())
|
||||
np.testing.assert_allclose(b.numpy(), [-1, -2, -3, -4])
|
||||
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -38,7 +38,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
|||
module = ir.Module(name=__file__)
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1)}[buf.dtype] for buf in bufs]
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64)}[buf.dtype] for buf in bufs]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
|
||||
|
||||
# force llvmlite to allow us to add function attribute then add the attribute
|
||||
|
@ -88,13 +88,21 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
|||
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[args.i], [aug_idx], inbounds=True)), ir.Constant(func_dtypes[args[0]], 0))
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(func.args[args.i], [idx], inbounds=True))
|
||||
if func_dtypes[args.i] != ir.FloatType(): val = bb[-1].fpext(val, ir.FloatType())
|
||||
if func_dtypes[args.i] != ir.FloatType():
|
||||
if dtypes.is_int(bufs[args.i].dtype):
|
||||
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType())
|
||||
else:
|
||||
val = bb[-1].fpext(val, ir.FloatType())
|
||||
lvars[newvar] = val
|
||||
if uop == UOps.STORE:
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
idx = args.idx.render(render_llvm, bb[-1])
|
||||
element = lvars[vin[0]]
|
||||
if func_dtypes[0] != ir.FloatType(): element = bb[-1].fptrunc(element, func_dtypes[0])
|
||||
if func_dtypes[0] != ir.FloatType():
|
||||
if dtypes.is_int(bufs[args.i].dtype):
|
||||
element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0])
|
||||
else:
|
||||
element = bb[-1].fptrunc(element, func_dtypes[0])
|
||||
bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True))
|
||||
if uop == UOps.ALU:
|
||||
lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
||||
|
|
|
@ -77,6 +77,8 @@ class dtypes:
|
|||
@staticmethod
|
||||
def is_int(x: DType): return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
|
||||
@staticmethod
|
||||
def is_unsigned(x: DType): return x in (dtypes.uint8)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue