fix mean of half tensor if sum is greater than hlaf.max (#4327)

sum of half does acc in float32 already, add an arg to not downcast to half and use that in mean
This commit is contained in:
chenyu 2024-04-28 18:04:54 -04:00 committed by GitHub
parent e027879475
commit c1d8d425eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 20 additions and 13 deletions

View File

@ -621,6 +621,11 @@ class TestAutoCastType(unittest.TestCase):
t.reshape(2, 1).expand(2, 10001).max().backward() t.reshape(2, 1).expand(2, 10001).max().backward()
np.testing.assert_allclose(t.grad.numpy(), [1, 0]) np.testing.assert_allclose(t.grad.numpy(), [1, 0])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision(self):
t = Tensor([60000, 60000, 60000], dtype=dtypes.half)
np.testing.assert_allclose(t.mean().numpy(), 60000)
class TestImplicitFunctionTypeChange(unittest.TestCase): class TestImplicitFunctionTypeChange(unittest.TestCase):
def test_functions(self): def test_functions(self):
result = [] result = []

View File

@ -113,15 +113,15 @@ class MultiLazyBuffer:
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]: def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer: def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None, downcast_half:bool=True) -> MultiLazyBuffer:
if self.axis is not None and self.axis in axis: if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes # all-reduce on sharded axes
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
reduced_parts = [x.r(op, axis, acc_dt) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)] reduced_parts = [x.r(op, axis, acc_dt, downcast_half) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
return MultiLazyBuffer(reduced_parts, None, self.real) return MultiLazyBuffer(reduced_parts, None, self.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real) return MultiLazyBuffer([x.r(op, axis, acc_dt, downcast_half) for x in self.lbs], self.axis, self.real)
# *** movement ops *** # *** movement ops ***

View File

@ -146,14 +146,14 @@ class Where(Function):
# ************* reduce ops ************* # ************* reduce ops *************
class Sum(Function): class Sum(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
self.input_shape = x.shape self.input_shape = x.shape
return x.r(ReduceOps.SUM, axis, acc_dtype) return x.r(ReduceOps.SUM, axis, acc_dtype, downcast_half)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
class Max(Function): class Max(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
return self.ret return self.ret

View File

@ -160,7 +160,7 @@ class LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,)) return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,))
def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer: def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
# TODO: this logic should move to the scheduler # TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape) if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
@ -175,7 +175,7 @@ class LazyBuffer:
least_upper_dtype(self.dtype, dtypes.float) least_upper_dtype(self.dtype, dtypes.float)
if acc_dt is not None and acc_dt != self.dtype: if acc_dt is not None and acc_dt != self.dtype:
# cast back to float16 or bfloat16 to match torch / jax behavior # cast back to float16 or bfloat16 to match torch / jax behavior
return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt) return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if downcast_half and self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt)
# TODO: can we split symbolic shape if the reduce axis is not symbolic? # TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \ if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \

View File

@ -907,21 +907,23 @@ class Tensor:
# ***** reduce ops ***** # ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor: def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None,
keepdim=False, acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> Tensor:
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_) axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype) ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype, downcast_half=downcast_half)
return ret if keepdim else ret.reshape(shape) return ret if keepdim else ret.reshape(shape)
def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype) def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None, downcast_half:bool=True):
return self._reduce(F.Sum, axis, keepdim, acc_dtype, downcast_half)
def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim) def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim)) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
def mean(self, axis=None, keepdim=False): def mean(self, axis=None, keepdim=False):
assert all_int(self.shape), "does not support symbolic shape" assert all_int(self.shape), "does not support symbolic shape"
out = self.sum(axis=axis, keepdim=keepdim) out = self.sum(axis=axis, keepdim=keepdim, downcast_half=False)
return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out return out.div(prod(self.shape) / prod(out.shape)).cast(self.dtype) if 0 not in out.shape else out.cast(self.dtype)
def var(self, axis=None, keepdim=False, correction=1): def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape" assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)