fix function.py sum backward without downcast_half (#4353)

Without downcast_half, the sum output dtype can differ from the input dtype; cast the gradient back to the input dtype in function.py.
chenyu 2024-04-29 17:53:02 -04:00 committed by GitHub
parent 18c61ce077
commit 93abcd3113
2 changed files with 6 additions and 3 deletions
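A minimal repro in the spirit of the updated test, using tinygrad's public Tensor API. The assumption here is that mean() over half inputs routes through the downcast_half=False sum path, so the reduce accumulates in float32; that is exactly the path whose backward this commit patches.

from tinygrad import Tensor, dtypes

# 60000 + 60000 already overflows float16 (max ~65504), so the reduce must
# accumulate in a wider dtype for t.mean() to return 60000 at all
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
t.square().mean().backward()
# before this fix, Sum.backward propagated the accumulator dtype (float32)
# instead of casting back, so t.grad could end up with a different dtype than t
assert t.grad.dtype == dtypes.half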

test/test_dtype.py

@@ -623,8 +623,10 @@ class TestAutoCastType(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision(self):
-    t = Tensor([60000, 60000, 60000], dtype=dtypes.half)
+    t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
     np.testing.assert_allclose(t.mean().numpy(), 60000)
+    t.square().mean().backward()
+    np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3)
 
 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
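The expected gradient in the new assertion follows from differentiating the mean of squares: for f(t) = (1/n) * sum_i t_i**2 with n = 3, df/dt_i = 2*t_i/3, which at t_i = 60000 gives 60000 * 2 / 3 = 40000.0 for each of the three elements.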

tinygrad/function.py

@@ -147,10 +147,11 @@ class Where(Function):
 class Sum(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
-    self.input_shape = x.shape
+    self.input_shape, self.input_dtype = x.shape, x.dtype
     return x.r(ReduceOps.SUM, axis, acc_dtype, downcast_half)
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
+  # if downcast_half is False, the forward output can have different dtype, and backward needs to cast back to input dtype
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype).expand(self.input_shape)
 
 class Max(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
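A note on the design choice, as I read it: casting in backward (rather than forcing the forward output back to half) lets the reduce keep its wider accumulator for precision, while gradients still land in the leaf tensor's dtype. A quick NumPy illustration of why the wider accumulator matters in the first place:

import numpy as np

# float16 tops out at 65504, so summing even two of the test's values in half
# precision overflows; hence the float32 accumulator whose dtype this fix
# casts away on the backward pass
print(np.float16(60000) + np.float16(60000))  # inf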