2024-04-03 04:45:58 +08:00
|
|
|
import unittest, math
|
2024-04-03 08:52:05 +08:00
|
|
|
from tinygrad import Tensor, Device, dtypes
|
2024-08-17 06:17:57 +08:00
|
|
|
from tinygrad.ops import UOps
|
2024-04-01 01:09:23 +08:00
|
|
|
from tinygrad.engine.schedule import create_schedule
|
2024-04-02 04:53:43 +08:00
|
|
|
from tinygrad.helpers import CI
|
|
|
|
import numpy as np
|
2024-04-29 23:40:45 +08:00
|
|
|
from test.helpers import is_dtype_supported
|
2024-04-01 01:09:23 +08:00
|
|
|
|
|
|
|
def _check_ast_count(desired_count:int, t:Tensor) -> None:
  """Schedule `t` and assert that exactly `desired_count` schedule items have a `UOps.SINK` AST root.

  Tests in this file pass `desired_count=0` to assert that an expression was const folded away.
  The assertion message reports the actual count so a failing test shows how far off it is.
  """
  # NOTE: this has side effect because everything can be scheduled only once
  schedule = create_schedule(t.lazydata.lbs)
  asts = [s for s in schedule if s.ast.op is UOps.SINK]
  assert len(asts) == desired_count, f"expected {desired_count} SINK ast(s), got {len(asts)}"
|
|
|
|
|
2024-04-04 02:39:28 +08:00
|
|
|
class TestUnaryOpsConstFolding(unittest.TestCase):
|
Approximations for SIN/LOG2/EXP2 passing all tests. (#5187)
* [WIP] Added an approximated implementation of Sin(FP32, FP64) passing all tests on Clang runtime
* Map nan/-inf/inf to 1.0 in order to avoid doing as_const(math.inf)
* [WIP] Added a support for LLVM IR
* cleaned up the code for the mypy and linter
* [WIP] Updated fp64 support (bitwise shift causes a compilation error), fixed linter issue.
* [Add] added fast=true mode which disables the payne-hanek reduction which is slow
* [Fix] fails to compute elements when shape includes zero
* [WIP] Added BinaryOps.ADD/BinaryOps.OR to assembly
* [wip] update the assembly for ptx
* Enables fast=True when device is one of PTX, NV, CUDA, to avoid slow bitwise ops (as lv3 reduction is not required).
* [WIP] Added an approximation of LOG2/EXP2 (FP32, FP64)
* [Fix] Cyclic dependencies existing in xlog2
* [Fix] Cycle dependency in the graph of exp2, and log2. (passing test_symbolic_ops.py)
* [Fix] keep using higher precision for exp2, but the cycle graph issue remains to be fixed...
* [Refactor] removed is_metal option. xsin does not rely on fp64 when fp32 mode.
* [WIP] fp16 xsin implementation passing all tests. (still needs to be refactored)
* [WIP] Added fp16 exp2 implementation
* [WIP] Increased the precision of Log2 from 3.5 ULP to 1.0 ULP, and added FP16 Log2 approximation.
* stashed the changes for FP16 sin
* [Fix] Patch for FP16 Sin/Exp2. (updated the dtype_via, fp32_p, and lower)
* [Refactor] migration to fastmath.py, some code simplification, renamed apis in fastmath, et al.
* [Refactor] Added the function polyN to clean-up N-terms polynomial approximation.
* [Patch] Increase fp64 precision when ldexp3k if possible, and patch for fp16 exp2
* [Patch] added bitcast_forward option
* [Patch] resolved cycle graph
* patch fix cycle graph
* set bitcast_forward=True in ilogb2k
* bitcast_forward for multi.py
* E501
* Break into multiple small PRs
* [Patch] FP16 -> FP64 upcast is no longer required since xlog2 uses quad-precision polyN
* [Patch] NV still required FP64 for xlog2
* updated schedule test
* updated the count of kernels
* [Update] Removed all bitwise ops (SHL/SHR), tweaked the nan manipulation of log2, passing all tests except for AMD.
* Bitcast: make them api-compatible
* [update] force to use bitcast
* updated the count of constant folding
* [Patch] Creating a mask for exp2 using x <= Inf evaluates to True as long as x is a real value
* [Update] isNaN(x)-free log2 algorithm, passing PTX tests; METAL with fastmath enabled handles nan well, and the amd backend will not crash.
* xsin is reluctant to call payne_hanek_reduction which is slow to compile, passing stable diffusion compilation in a realistic time
* some minor simplification to payne hanek reduction
* [refactor] refactored some redundant parts of payne hanek
* [refactor] more readable payne hanek impl
* [refactor] improved the code consistency of payne hanek
* [experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)
* Revert "[experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)"
This reverts commit 0eee08b87c9e46da8aec0a8edec5316634031a49.
* use allow_buffer_view
* let's support multilazytensor
* updated the count of kernels
* [test] added the jit tests for approx ops
* keep the failing constant folding tests in the suite, marked with expectedFailure
* make the timeout deadline explicit when testing approx jit timeout
* [WIP] Simplified the implementation of xsin, which never times out
* [Refactor] Improved the consistency of approx sin implementation, passing time out tests
* integrated xexp2_base into xexp2
* Set switch_over=39800.0
* delete: is_buffer_fastmath_supported
* sin: compute against abs(x)
* some cleanups
* fix typo
* removed the space between param and dtype
* allow 514 kernels on CI for sd
* [refactor] no need to upcast at ldexp3k
* [refactor] added some comments, references to help understanding the code.
* [Fix] 1.0 ULP Sine Approximation for FP16
* [update] assume e != 0
* use pow2if instead of ldexp3k to fuse payne_hanek reduction into one
* check if approximated sin/log2/exp are fused into one
* clean up changes
* test amd exp
* some code cleanup and test sigmoid
* fix: enabled payne_hanek for fp16 to achieve higher acc
* fix: payne_hanek always accumulates the value with uint64, and fp16 sin is fused into a single kernel
* [Refactor] Rename: fastmath -> transcendental
* [Refactor] Added TRANSCENDENTAL, Moved the gate function to function.py
* updated const folding tests
* TRANSCENDENTAL as a ContextVar, removed old test of cody waite reduction, added assertions, et al.
* Add: unittest.main()
* Import TRANSCENDENTAL instead of getenv
* Refactor: Added dtype check when TRANSCENDENTAL=2, more context var
* Patch: xlog2, break expt(2, 32) x 2 -> expt(2, 16) x 4 for fp16 math
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
2024-07-11 07:44:58 +08:00
|
|
|
def test_all_consts_ops(self):
  """exp/sqrt/add/div applied only to const tensors should leave nothing to schedule."""
  # each expression is built lazily and checked immediately, in the same order as before
  builders = (
    lambda: Tensor.ones(4).exp(),
    lambda: Tensor.ones(4).sqrt(),
    lambda: Tensor.ones(4) + Tensor.ones(4),
    lambda: Tensor.ones(4) / Tensor.ones(4),
  )
  for build in builders:
    _check_ast_count(0, build())
|
|
|
|
|
|
|
|
def test_cast(self):
  # a cast applied directly to a const tensor is expected to fold (no kernel)
  ones_as_i16 = Tensor.ones(4).cast(dtypes.int16)
  _check_ast_count(0, ones_as_i16)
  neg_ones_as_u16 = Tensor.full(4, fill_value=-1).cast(dtypes.uint16)
  _check_ast_count(0, neg_ones_as_u16)
|
|
|
|
|
2024-08-20 06:41:28 +08:00
|
|
|
@unittest.expectedFailure # no two level fold at lazybuffer
def test_neg_folding(self):
  # two stacked negations (neg / mul(-1)) cancel out; ideally nothing would be scheduled
  def small(): return Tensor([1, 2, 3])
  _check_ast_count(0, small().mul(-1).neg())
  _check_ast_count(0, small().neg().mul(-1))
  _check_ast_count(0, small().neg().neg())
|
|
|
|
|
2024-06-25 07:40:37 +08:00
|
|
|
def test_neg_realized_no_fold(self):
  # the input is a realized (non-const) buffer, so neg() must cost exactly one kernel
  realized = Tensor.randn(32, 32).clip(0, 1).realize()
  _check_ast_count(1, realized.neg())
|
|
|
|
|
2024-04-04 02:39:28 +08:00
|
|
|
class TestBinaryOpsConstFolding(unittest.TestCase):
  """Binary ops with an identity or absorbing const operand should fold away (zero kernels).

  Covered: x+0, x-0, x*0, x*1, x/1, x//1, x**0, x**1, 1**x, with the const given either as a
  Python literal or as a const tensor, on either side of the operator where applicable.
  """
  def test_add_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + 0)
  def test_add_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) + Tensor.zeros(4))
  def test_literal_zero_add(self):
    _check_ast_count(0, 0 + Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_add(self):
    _check_ast_count(0, Tensor.zeros(4) + Tensor([1.0, 2, 3, 4]))

  def test_sub_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - 0)
  def test_sub_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) - Tensor.zeros(4))

  def test_mul_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 0)
  def test_mul_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.zeros(4))
  def test_literal_zero_mul(self):
    # only `literal * tensor` is under test here: a trailing `* 0` would let the outer multiply
    # fold on its own and hide a failure to fold the left-literal case
    _check_ast_count(0, 0 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_zero_mul(self):
    _check_ast_count(0, Tensor.zeros(4) * Tensor([1.0, 2, 3, 4]))

  def test_mul_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * 1)
  def test_mul_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) * Tensor.ones(4))
  def test_literal_one_mul(self):
    _check_ast_count(0, 1 * Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_mul(self):
    _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))

  def test_bool_tensor_mul_bool(self):
    _check_ast_count(0, Tensor([True, False]) * True)
    _check_ast_count(0, Tensor([True, False]) * False)
  def test_bool_mul_bool_tensor(self):
    _check_ast_count(0, True * Tensor([True, False]))
    _check_ast_count(0, False * Tensor([True, False]))

  def test_div_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / 1)
  def test_div_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))

  def test_idiv_literal_one(self):
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
  def test_idiv_tensor_one(self):
    # the divisor is an int32 const tensor so integer division applies
    _check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))

  def test_pow_literal_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
  def test_pow_tensor_zero(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4))

  def test_pow_literal_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  def test_pow_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
|
2024-04-01 01:09:23 +08:00
|
|
|
|
2024-04-04 22:46:28 +08:00
|
|
|
# folds advanced indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):
  """Indexing a realized tensor with const index tensors should need no kernel."""
  def test_scalar_index(self):
    src = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    # scalar const index tensors, including one produced by arithmetic on a const
    _check_ast_count(0, src[:,:,Tensor(1),:])
    _check_ast_count(0, src[:,:,Tensor(1)+2,:])
    _check_ast_count(0, src[:,:,Tensor(1),Tensor(0)])

  @unittest.expectedFailure
  def test_const_tensor_index(self):
    # TODO: implement const tensor folded indexing
    src = Tensor.arange(16).float().reshape(1,1,4,4).realize()
    # non-scalar const index tensors are not folded yet
    _check_ast_count(0, src[:,:,Tensor.ones(2,1),:])
    _check_ast_count(0, src[:,:,Tensor.ones(1,2)+2,:])
    _check_ast_count(0, src[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
|
|
|
|
|
2024-04-01 01:09:23 +08:00
|
|
|
class TestMovedConstFolding(unittest.TestCase):
  """Const tensors that pass through a movement op (shrink / pad) before being combined."""
  def test_add_shrunk_zero(self):
    lhs = Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, lhs + Tensor.zeros(6).shrink(((1, 5),)))

  def test_add_padded_zero(self):
    # TODO: it's 1 now, this might be possible to fold
    lhs = Tensor([1.0, 2, 3, 4])
    _check_ast_count(1, lhs + Tensor.zeros(2).pad(((1, 1),)))

  def test_mul_shrunk_one(self):
    lhs = Tensor([1.0, 2, 3, 4])
    _check_ast_count(0, lhs * Tensor.ones(6).shrink(((1, 5),)))

  def test_add_padded_one(self):
    # NOTE(review): despite the name this is a multiply. The pad inserts zeros ([0, 1, 1, 0]),
    # so the product is not the identity and one kernel is expected.
    lhs = Tensor([1.0, 2, 3, 4])
    _check_ast_count(1, lhs * Tensor.ones(2).pad(((1, 1),)))

  def test_cast_padded(self):
    # fresh tensors are built for every check because scheduling has side effects
    def ones_padded(): return Tensor.ones(4).pad(((1, 1),))
    def neg_ones_padded(): return Tensor.full(4, fill_value=-1).pad(((1, 1),))
    # NOTE: this is folded due to CAST_BEFORE_VIEW
    _check_ast_count(0, ones_padded().cast(dtypes.int16))
    np.testing.assert_equal(ones_padded().cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
    _check_ast_count(0, neg_ones_padded().cast(dtypes.uint16))
    np.testing.assert_equal(neg_ones_padded().cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
    # not folded: casting to int64 still costs one kernel
    _check_ast_count(1, ones_padded().cast(dtypes.int64))
    np.testing.assert_equal(ones_padded().cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
|
|
|
|
2024-04-04 02:39:28 +08:00
|
|
|
class TestReduceOpsConstFolding(unittest.TestCase):
  """Reductions (sum / prod / max) over const tensors should fold; a padded const still costs one kernel."""
  def test_const_sum(self):
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum().numpy(), 4 * 5 * 6)
    _check_ast_count(0, Tensor.ones(4, 5, 6).sum(axis=0))
    np.testing.assert_equal(Tensor.ones(4, 5, 6).sum(axis=0).numpy(), np.full((5, 6), 4))
    _check_ast_count(0, Tensor(4).sum())
    np.testing.assert_equal(Tensor(4).sum().numpy(), 4)

  def test_padded_const_sum(self):
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum())
    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4)

    # NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0.
    # here the two padded zeros become exp(0) = 1, giving 4*e + 2
    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
    np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

  def test_const_prod(self):
    _check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
    np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))
    _check_ast_count(0, Tensor.full((4, 5, 6), fill_value=2).prod(axis=0))
    np.testing.assert_equal(Tensor.full((4, 5, 6), fill_value=2).prod(axis=0).numpy(), np.full((5, 6), 2**4))
    _check_ast_count(0, Tensor(4).prod())
    np.testing.assert_equal(Tensor(4).prod().numpy(), 4)

  def test_const_max(self):
    _check_ast_count(0, Tensor.ones(4, 5, 6).max())
    np.testing.assert_equal(Tensor.ones(4, 5, 6).max().numpy(), 1)
    _check_ast_count(0, Tensor(4).max())
    np.testing.assert_equal(Tensor(4).max().numpy(), 4)

  def test_sum_output_dtype(self):
    # sum output dtype can be different from input; the sum of a const tensor must report the
    # same dtype as the sum of a contiguous copy of it.
    # subTest reports every mismatching dtype by name instead of stopping at the first bare assert.
    for dt in dtypes.fields().values():
      if not is_dtype_supported(dt): continue
      with self.subTest(dtype=dt):
        t = Tensor.ones(16, dtype=dt).reshape(4, 4)
        self.assertEqual(t.sum().dtype, t.contiguous().sum().dtype)
|
|
|
|
|
2024-04-02 04:53:43 +08:00
|
|
|
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestMultiConstFolding(unittest.TestCase):
  """Const folding for tensors placed on four devices of the default backend."""
  @staticmethod
  def _on_four_devices(x:Tensor) -> Tensor:
    # realize on the default device, then move to DEFAULT:0 .. DEFAULT:3
    return x.realize().to(tuple(f"{Device.DEFAULT}:{i}" for i in range(4)))

  def test_multi_const_folding_literal(self):
    t = self._on_four_devices(Tensor.arange(16).float())

    # non const folding case creates one ast on each shard
    _check_ast_count(4, t + 1)
    _check_ast_count(4, 1 + t)
    _check_ast_count(4, t * 2)
    _check_ast_count(4, 2 * t)

    # const folded
    _check_ast_count(0, t + 0)
    _check_ast_count(0, 0 + t)
    _check_ast_count(0, t * 0)
    _check_ast_count(0, 0 * t)
    _check_ast_count(0, t * 1)
    _check_ast_count(0, 1 * t)
    np.testing.assert_equal((t + 0).numpy(), np.arange(16))
    np.testing.assert_equal((t * 0).numpy(), [0] * 16)
    np.testing.assert_equal((t * 1).numpy(), np.arange(16))

    _check_ast_count(0, t ** 0)
    _check_ast_count(0, t ** 1)
    _check_ast_count(0, 1 ** t)

  def test_multi_const_folding_tensor(self):
    t = self._on_four_devices(Tensor.arange(16).float())
    zero = self._on_four_devices(Tensor.zeros(16))
    one = self._on_four_devices(Tensor.ones(16))

    # const folded
    _check_ast_count(0, t + zero)
    _check_ast_count(0, zero + t)
    _check_ast_count(0, t * zero)
    _check_ast_count(0, zero * t)
    _check_ast_count(0, t * one)
    _check_ast_count(0, one * t)
    np.testing.assert_equal((t + zero).numpy(), np.arange(16))
    np.testing.assert_equal((t * zero).numpy(), [0] * 16)
    np.testing.assert_equal((t * one).numpy(), np.arange(16))

  @unittest.expectedFailure
  def test_multi_todo_pow(self):
    t = self._on_four_devices(Tensor.arange(16).float())
    zero = self._on_four_devices(Tensor.zeros(16))
    one = self._on_four_devices(Tensor.ones(16))

    # TODO: fix pow folding
    _check_ast_count(0, t ** zero)
    _check_ast_count(0, t ** one)
    _check_ast_count(0, one ** t)
|
2024-04-02 07:22:14 +08:00
|
|
|
|
2024-04-03 04:45:58 +08:00
|
|
|
class TestTautologicalCompare(unittest.TestCase):
  # without const folding, these would have triggered -Wtautological-compare in clang
  def test_lt_false(self):
    # no bool is strictly less than False
    bools = Tensor([True, False])
    np.testing.assert_equal((bools < False).numpy(), [False, False])

  def test_true_lt(self):
    # True is never strictly less than a bool
    bools = Tensor([True, False])
    np.testing.assert_equal((True < bools).numpy(), [False, False])

  def test_truth_table(self):
    # rows are (lhs, rhs, expected lhs < rhs), checked in order
    table = ((False, False, False), (False, True, True), (True, False, False), (True, True, False))
    for lhs, rhs, expected in table:
      np.testing.assert_equal((Tensor(lhs) < Tensor(rhs)).numpy(), expected)

  @unittest.skip("not implemented yet")
  def test_a_eq_a(self):
    # x == x holds for every int or bool element
    ints = Tensor([1, 2, 3])
    np.testing.assert_equal((ints == ints).numpy(), [True, True, True])
    # ...but not for nan
    floats = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((floats == floats).numpy(), [False, True, True])

  @unittest.skip("not implemented yet")
  def test_a_ne_a(self):
    # x != x is False for every int or bool element
    ints = Tensor([1, 2, 3])
    np.testing.assert_equal((ints != ints).numpy(), [False, False, False])
    # ...but not for nan
    floats = Tensor([math.nan, 1.0, 2.0])
    np.testing.assert_equal((floats != floats).numpy(), [True, False, False])
|
|
|
|
|
2024-04-02 07:22:14 +08:00
|
|
|
if __name__ == '__main__':
  unittest.main()  # run every TestCase in this module when executed as a script
|