touchup test_dtype.test_gradient_dtype (#3887)

add back bad merge from #3613 and add float.double and float.bfloat16 to test
This commit is contained in:
chenyu 2024-03-22 20:56:45 -04:00 committed by GitHub
parent fc11808a79
commit 2d3ce53348
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 17 additions and 13 deletions

View File

@ -566,21 +566,25 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16 assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16 assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
@unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16")
def test_gradient_dtype(self): def test_gradient_dtype(self):
for default_dtype in [dtypes.float16, dtypes.float32]:
old_default_float = dtypes.default_float old_default_float = dtypes.default_float
try:
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype dtypes.default_float = default_dtype
for datatype in [dtypes.float16, dtypes.float32]: for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True) if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
b = (a * 5).sum() b = (a * 5).sum()
b.backward() # if there is dtype mismatch, lazy should assert b.backward() # if there is dtype mismatch, lazy should assert
assert a.grad.dtype == a.dtype assert a.grad.dtype == a.dtype
np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy()) np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
finally:
dtypes.default_float = old_default_float dtypes.default_float = old_default_float
class TestImplicitFunctionTypeChange(unittest.TestCase):
def test_functions(self): def test_functions(self):
result = [] result = []
for func in [ for func in [