fix abs of diff of uint (#4411)

This commit is contained in:
chenyu 2024-05-15 18:39:11 -04:00 committed by GitHub
parent 2119e0456d
commit 04f2327ca3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 1 deletions

View File

@ -660,8 +660,15 @@ class TestImplicitFunctionTypeChange(unittest.TestCase):
]:
t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
result.append(t.numpy().sum())
assert all(result)
class TestTensorMethod(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_abs_diff(self, dt):
if dt == dtypes.bool or not is_dtype_supported(dt): return
a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
ret = (a - b).abs()
np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))
if __name__ == '__main__':
unittest.main()