2023-12-13 09:34:47 +08:00
|
|
|
# ruff: noqa: E501
|
2023-09-02 03:53:07 +08:00
|
|
|
from typing import Optional, Tuple, Any, List, Iterable
|
2023-08-06 03:35:56 +08:00
|
|
|
import unittest, math
|
|
|
|
import numpy as np
|
2023-09-05 00:58:33 +08:00
|
|
|
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
|
2023-12-01 09:07:16 +08:00
|
|
|
from tinygrad.device import Buffer, Device
|
2023-11-28 03:34:37 +08:00
|
|
|
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
|
|
|
from tinygrad.device import CompiledASTRunner, Compiled
|
2023-09-05 00:58:33 +08:00
|
|
|
from tinygrad.codegen.linearizer import UOps, UOp
|
2023-12-15 11:41:51 +08:00
|
|
|
from test.test_dtype import is_dtype_supported
|
2023-08-06 03:35:56 +08:00
|
|
|
|
2023-12-02 16:32:25 +08:00
|
|
|
def _uops_to_prg(uops):
  """Render `uops` with the default device's renderer and build a runnable program named "test".

  When the device's linearizer options report `has_local`, both the global and the local
  launch size are [1]; otherwise both are None.
  """
  device = Device[Device.DEFAULT]
  src, runtime_args = device.renderer("test", uops)
  # two distinct [1] lists so the global and local sizes never share one list object
  global_size, local_size = ([1], [1]) if device.linearizer_opts.has_local else (None, None)
  runner = CompiledASTRunner(None, "test", src, global_size, local_size, runtime_args=runtime_args)
  return runner.build(device.compiler, device.runtime)
|
2023-08-06 03:35:56 +08:00
|
|
|
|
2023-09-02 03:53:07 +08:00
|
|
|
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Iterable[UOp], arg:Any=None) -> UOp:
  """Append a new UOp to `uops` and return it.

  Args:
    uops: list the new UOp is appended to (mutated in place).
    uop: the UOps kind of the new node.
    dtype: result dtype; replaced by dtypes.bool when `arg` is BinaryOps.CMPLT.
    vin: input UOps. Any iterable is accepted (callers pass tuples, lists and generators).
      It is materialized with tuple() *before* the new node is appended, so UOps created
      lazily while iterating `vin` land in `uops` ahead of the node that consumes them.
    arg: optional payload stored on the UOp (e.g. a buffer name, a constant, or an ALU op).

  Returns:
    The newly appended UOp, i.e. uops[-1].
  """
  # a comparison always yields a boolean, whatever dtype the caller passed in
  out_dtype = dtype if arg != BinaryOps.CMPLT else dtypes.bool
  uops.append(UOp(uop, out_dtype, tuple(vin), arg))
  return uops[-1]
|
|
|
|
|
|
|
|
def _test_single_value(vals, op, dtype):
  """Run ALU `op` on the device over `vals`, each loaded from its own 1-element input buffer.

  The kernel loads element 0 of inputs data1..dataN, applies `op` to them, and stores the
  result to element 0 of the output data0. That stored element is returned.
  """
  uops = []
  out_ptr = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
  in_ptrs = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), f'data{n}') for n in range(1, len(vals)+1)]
  # each LOAD reads index 0; its CONST index uop is appended immediately before the LOAD
  loaded = [uop(uops, UOps.LOAD, dtype, [ptr, uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for ptr in in_ptrs]
  result = uop(uops, UOps.ALU, dtype, loaded, op)
  uop(uops, UOps.STORE, None, (out_ptr, uop(uops, UOps.CONST, dtypes.int32, (), 0), result))
  out_buf = Buffer(Device.DEFAULT, 1, dtype)
  in_bufs = [Buffer.fromCPU(Device.DEFAULT, np.array([v], dtype=dtype.np)) for v in vals]
  program = _uops_to_prg(uops)
  program.exec([out_buf, *in_bufs])
  return out_buf.toCPU()[0]
|
|
|
|
|
2023-09-02 03:53:07 +08:00
|
|
|
def _test_single_value_const(vals, op, dtype):
  """Run ALU `op` on the device with `vals` baked into the kernel as CONST uops.

  Only the output buffer data0 is allocated; the result is stored to its element 0,
  which is returned.
  """
  uops = []
  out_ptr = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
  consts = [uop(uops, UOps.CONST, dtype, [], v) for v in vals]
  result = uop(uops, UOps.ALU, dtype, consts, op)
  uop(uops, UOps.STORE, None, (out_ptr, uop(uops, UOps.CONST, dtypes.int32, (), 0), result))
  out_buf = Buffer(Device.DEFAULT, 1, dtype)
  _uops_to_prg(uops).exec([out_buf])
  return out_buf.toCPU()[0]
|
|
|
|
|
|
|
|
class TestUOps(unittest.TestCase):
  """Shared helpers that run one ALU op through both the load-based and the const-based
  kernel builders and compare the device result against a Python reference function."""

  def _equal(self, v1, v2):
    # two NaNs count as a match; booleans compare exactly, everything else to 5 places
    if math.isnan(v1) and math.isnan(v2): return
    if v1.dtype == np.bool_: self.assertEqual(v1, v2)
    else: self.assertAlmostEqual(v1, v2, places=5)

  def _test_uop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
    """Check unary op `bop` against `fxn` for each a in (-2.0, 0.0, 1.0)."""
    for run in (_test_single_value, _test_single_value_const):
      for a in (-2.0, 0.0, 1.0):
        self._equal(run([a], bop, dt), fxn(a))

  def _test_bop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32), no_b_zero=False):
    """Check binary op `bop` against `fxn`; the b == 0.0 case is skipped when `no_b_zero` is set."""
    b_vals = (-3.0, 1.0) if no_b_zero else (-3.0, 1.0, 0.0)
    for run in (_test_single_value, _test_single_value_const):
      for a in (-2.0, 0.0, 1.0):
        for b in b_vals:
          self._equal(run([a, b], bop, dt), fxn(a, b))

  def _test_top_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
    """Check ternary op `bop` against `fxn` over a small grid of (a, b, c) values."""
    for run in (_test_single_value, _test_single_value_const):
      for a in (-2.0, 0, 1):
        for b in (-3.0, 3.0):
          for c in (-4.0, 4.0):
            self._equal(run([a, b, c], bop, dt), fxn(a, b, c))
|
2023-08-16 00:07:26 +08:00
|
|
|
|
|
|
|
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
  """float32 checks of unary, binary and ternary ALU ops against Python/NumPy references."""

  # --- unary ops ---
  def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda x: -x)
  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, np.exp2)
  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda x: math.log2(x) if x > 0 else (float('-inf') if x == 0 else float('nan')))
  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, math.sin)
  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda x: float('nan') if x < 0 else math.sqrt(x))
  # RECIP is not implemented on most backends, so it stays disabled:
  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))

  # --- binary ops (MOD is not tested on floats) ---
  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda x,y: x+y)
  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda x,y: x-y)
  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda x,y: x*y)
  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda x,y: x*float('inf') if y == 0 else x/y)
  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, max)
  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda x,y: x<y)

  # --- ternary ops ---
  def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda x,y,z: x*y + z)
  def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda x,y,z: z if x == 0 else y)
|
|
|
|
|
2023-08-16 00:07:26 +08:00
|
|
|
# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
  """int32, bool and float16 checks of ALU ops.

  The int32 references truncate division toward zero and give the remainder the sign
  of the dividend.
  """
  def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda x: -x, PtrDType(dtypes.int32))
  def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda x,y: int(x)+int(y), PtrDType(dtypes.int32))
  def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda x,y: int(x)-int(y), PtrDType(dtypes.int32))
  def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda x,y: int(x)*int(y), PtrDType(dtypes.int32))
  def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda x,y: int(x/y), PtrDType(dtypes.int32), no_b_zero=True)
  def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda x,y: (-1 if x < 0 else 1) * (abs(int(x)) % abs(int(y))), PtrDType(dtypes.int32), no_b_zero=True)
  def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda x,y: float(x<y), PtrDType(dtypes.int32))

  @unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
  def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda x,y: bool(x) and bool(y), PtrDType(dtypes.bool))

  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
  def test_where_float16(self): self._test_top_fxn(TernaryOps.WHERE, lambda x,y,z: z if x == 0 else y, PtrDType(dtypes.float16))
|
2023-08-16 00:07:26 +08:00
|
|
|
|
2023-08-06 03:35:56 +08:00
|
|
|
if __name__ == '__main__':
  # run every test case in this module, printing each test's name and result
  unittest.main(verbosity=2)
|