fix uneven shard with shrink and pad args on sharded axis (#5131)

it's incorrect to assume all first (len(device)-1) shards would have the same size. e.g. size 2 shard 4 -> (1, 1, 0, 0)
This commit is contained in:
chenyu 2024-06-24 16:55:50 -04:00 committed by GitHub
parent 18e70deec3
commit 7948b05738
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 6 deletions

View File

@ -433,6 +433,14 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
def test_uneven_multiple_zeros(self):
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
for N in (1, 2, 3, 4):
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
# make sure something is computed on each device
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
np.testing.assert_equal(X.numpy(), data)
def test_bn_ast_on_devices(self):
t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
bn = nn.BatchNorm2d(64)

View File

@ -45,7 +45,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
class MultiLazyBuffer:
def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
@ -81,11 +81,11 @@ class MultiLazyBuffer:
for lb in self.real_lbs:
if lb.device == device: return lb
return self.lbs[self.real.index(True)].copy_to_device(device)
sz = self.lbs[0].shape[self.axis]
llbs = []
for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
llbs.append(lb.pad(pad_arg))
llbs:List[LazyBuffer] = []
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
if not real: continue
pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
llbs.append(lb.copy_to_device(device).pad(pad_arg))
return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
# passthroughs