mirror of https://github.com/commaai/tinygrad.git
touchup test_dtype.test_gradient_dtype (#3887)
add back bad merge from #3613 and add float.double and float.bfloat16 to test
This commit is contained in:
parent
fc11808a79
commit
2d3ce53348
|
@ -566,21 +566,25 @@ class TestAutoCastType(unittest.TestCase):
|
||||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
|
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
|
||||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
|
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
|
||||||
|
|
||||||
@unittest.skipIf(not is_dtype_supported(dtypes.float16), "need float16")
|
|
||||||
def test_gradient_dtype(self):
|
def test_gradient_dtype(self):
|
||||||
for default_dtype in [dtypes.float16, dtypes.float32]:
|
|
||||||
old_default_float = dtypes.default_float
|
old_default_float = dtypes.default_float
|
||||||
try:
|
|
||||||
|
for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||||
|
if not is_dtype_supported(default_dtype): continue
|
||||||
dtypes.default_float = default_dtype
|
dtypes.default_float = default_dtype
|
||||||
for datatype in [dtypes.float16, dtypes.float32]:
|
for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||||
a = Tensor([1, 2, 3], dtype=datatype, requires_grad=True)
|
if not is_dtype_supported(dtype): continue
|
||||||
|
if DEBUG >= 2:
|
||||||
|
print(f"testing {default_dtype=}, {dtype=}")
|
||||||
|
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
|
||||||
b = (a * 5).sum()
|
b = (a * 5).sum()
|
||||||
b.backward() # if there is dtype mismatch, lazy should assert
|
b.backward() # if there is dtype mismatch, lazy should assert
|
||||||
assert a.grad.dtype == a.dtype
|
assert a.grad.dtype == a.dtype
|
||||||
np.testing.assert_allclose(a.grad.numpy(), Tensor([5, 5, 5], dtype=datatype).numpy())
|
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
|
||||||
finally:
|
|
||||||
dtypes.default_float = old_default_float
|
dtypes.default_float = old_default_float
|
||||||
|
|
||||||
|
class TestImplicitFunctionTypeChange(unittest.TestCase):
|
||||||
def test_functions(self):
|
def test_functions(self):
|
||||||
result = []
|
result = []
|
||||||
for func in [
|
for func in [
|
||||||
|
|
Loading…
Reference in New Issue