mirror of https://github.com/commaai/tinygrad.git
fix function.py sum backward without downcast_half (#4353)
Without downcast_half, the sum output dtype can differ from the input dtype, so the backward pass in function.py now casts the gradient back to the input dtype.
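To make the failure mode concrete, here is a minimal sketch using the public tinygrad API; the import path (from tinygrad import Tensor, dtypes) and the final dtype assertion are assumptions added for illustration, not part of this commit. It mirrors the updated test below and checks that a half tensor's gradient comes back as half rather than in the wider accumulation dtype:

import numpy as np
from tinygrad import Tensor, dtypes  # assumed import path for the public API

t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
# the sum inside mean() accumulates in a wider dtype, so 3 * 60000 does not overflow float16 (max ~65504)
np.testing.assert_allclose(t.mean().numpy(), 60000)

# d/dx mean(x*x) = 2*x/3 = 40000 per element, which fits in float16
t.square().mean().backward()
np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3)

# the point of this commit: the gradient should carry the input dtype (half),
# not the accumulation dtype used by the reduce
assert t.grad.dtype == t.dtype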
This commit is contained in:
parent 18c61ce077
commit 93abcd3113
@@ -623,8 +623,10 @@ class TestAutoCastType(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision(self):
-    t = Tensor([60000, 60000, 60000], dtype=dtypes.half)
+    t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
     np.testing.assert_allclose(t.mean().numpy(), 60000)
+    t.square().mean().backward()
+    np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3)

 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
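The expected values in the new assertion follow from d/dx mean(x*x) = 2x/3: with x = 60000, each element's gradient is 40000, which still fits in float16 (largest finite value 65504), so the comparison is meaningful at half precision.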
@@ -147,10 +147,11 @@ class Where(Function):
 class Sum(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
-    self.input_shape = x.shape
+    self.input_shape, self.input_dtype = x.shape, x.dtype
     return x.r(ReduceOps.SUM, axis, acc_dtype, downcast_half)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
+  # if downcast_half is False, the forward output can have different dtype, and backward needs to cast back to input dtype
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype).expand(self.input_shape)

 class Max(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
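To see why the backward needs an explicit cast, here is a small numpy analogy (numpy stands in for the LazyBuffer ops; this is not tinygrad internals): when the reduce accumulates in float32 and the result is not downcast back to half, the gradient derived from it stays float32 unless it is cast back to the input dtype, which is what the new backward does.

import numpy as np

x = np.array([60000, 60000, 60000], dtype=np.float16)
s = x.sum(dtype=np.float32)                    # accumulate in float32, like downcast_half=False
grad_out = np.ones_like(s)                     # the upstream gradient inherits the output dtype: float32
grad_in = np.broadcast_to(grad_out, x.shape)   # "expand" alone keeps float32 for a float16 input
assert grad_in.dtype == np.float32 and x.dtype == np.float16

grad_in = np.broadcast_to(grad_out.astype(x.dtype), x.shape)  # cast back to the input dtype first, as the fix does
assert grad_in.dtype == x.dtype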