mirror of https://github.com/commaai/tinygrad.git
touchup test_dtype.test_gradient_dtype (#3887)
add back bad merge from #3613 and add float.double and float.bfloat16 to test
This commit is contained in:
parent
fc11808a79
commit
2d3ce53348
|
@ -566,21 +566,25 @@ class TestAutoCastType(unittest.TestCase):
|
|||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
|
||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
|
||||
|
||||
@unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16")
|
||||
def test_gradient_dtype(self):
|
||||
for default_dtype in [dtypes.float16, dtypes.float32]:
|
||||
old_default_float = dtypes.default_float
|
||||
try:
|
||||
dtypes.default_float = default_dtype
|
||||
for datatype in [dtypes.float16, dtypes.float32]:
|
||||
a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True)
|
||||
b = (a * 5).sum()
|
||||
b.backward() # if there is dtype mismatch, lazy should assert
|
||||
assert a.grad.dtype == a.dtype
|
||||
np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy())
|
||||
finally:
|
||||
dtypes.default_float = old_default_float
|
||||
old_default_float = dtypes.default_float
|
||||
|
||||
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
if not is_dtype_supported(default_dtype): continue
|
||||
dtypes.default_float = default_dtype
|
||||
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
if not is_dtype_supported(dtype): continue
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
|
||||
b = (a * 5).sum()
|
||||
b.backward() # if there is dtype mismatch, lazy should assert
|
||||
assert a.grad.dtype == a.dtype
|
||||
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
|
||||
|
||||
dtypes.default_float = old_default_float
|
||||
|
||||
class TestImplicitFunctionTypeChange(unittest.TestCase):
|
||||
def test_functions(self):
|
||||
result = []
|
||||
for func in [
|
||||
|
|
Loading…
Reference in New Issue