mirror of https://github.com/commaai/tinygrad.git
fix uneven shard with shrink and pad args on sharded axis (#5131)

It is incorrect to assume that the first len(devices)-1 shards all have the same size. e.g. a tensor of size 2 sharded across 4 devices splits as (1, 1, 0, 0), so the trailing shards are empty.
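A minimal repro of the case being fixed, mirroring the new test added below (a sketch only: it assumes the default backend can expose four virtual devices as f"{Device.DEFAULT}:{i}" and that tinygrad's top-level Tensor/Device imports are available):

import numpy as np
from tinygrad import Tensor, Device

# a length-2 tensor sharded over 4 devices splits as (1, 1, 0, 0): the last two shards are empty
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
# the add/sub round trip forces something to be computed on each device
X = ((Tensor([1, 2]).shard(devices, axis=0) + 1).realize() - 1).realize()
np.testing.assert_equal(X.numpy(), [1, 2])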
This commit is contained in:
parent 18e70deec3
commit 7948b05738
@@ -433,6 +433,14 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
     np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
 
+  def test_uneven_multiple_zeros(self):
+    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
+      for N in (1, 2, 3, 4):
+        devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
+        # make sure something is computed on each device
+        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
+        np.testing.assert_equal(X.numpy(), data)
+
   def test_bn_ast_on_devices(self):
     t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
     bn = nn.BatchNorm2d(64)
@@ -45,7 +45,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
 def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
   if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
   sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
-  return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
+  return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
 
 class MultiLazyBuffer:
   def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
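To see why the single-argument change in to_sharded matters, here is a standalone sketch (plain Python, not tinygrad code; round_up is re-implemented here just for illustration) of the shrink bounds the old and new expressions produce for a size-2 tensor split across 4 devices:

def round_up(n, amt): return ((n + amt - 1) // amt) * amt

s, ndev = 2, 4                  # size on the sharded axis, number of devices
sz = round_up(s, ndev) // ndev  # rounded-up per-shard size: 1

old = [(sz*i, min(s, sz*(i+1))) for i in range(ndev)]
new = [(min(s, sz*i), min(s, sz*(i+1))) for i in range(ndev)]
print(old)  # [(0, 1), (1, 2), (2, 2), (3, 2)]  <- (3, 2) has start > end, an invalid shrink
print(new)  # [(0, 1), (1, 2), (2, 2), (2, 2)]  <- trailing shards become empty instead

Clamping the start with min(s, sz*i) is what yields the (1, 1, 0, 0) split from the commit message instead of an out-of-range shrink on the last devices.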
@@ -81,11 +81,11 @@ class MultiLazyBuffer:
       for lb in self.real_lbs:
         if lb.device == device: return lb
       return self.lbs[self.real.index(True)].copy_to_device(device)
-    sz = self.lbs[0].shape[self.axis]
-    llbs = []
-    for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
-      pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
-      llbs.append(lb.pad(pad_arg))
+    llbs:List[LazyBuffer] = []
+    for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
+      if not real: continue
+      pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
+      llbs.append(lb.copy_to_device(device).pad(pad_arg))
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
 
   # passthroughs
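On the gather side, copy_to_device now pads each real shard using its (start, end) entry from self.bounds rather than assuming a uniform shard size. A rough numpy sketch of that reassembly for the same size-2-over-4-devices split (the bounds values are written out by hand here, assumed to match what MultiLazyBuffer.bounds would report for a (1, 1, 0, 0) split):

import numpy as np

data = np.array([1, 2])
bounds = [(0, 1), (1, 2), (2, 2), (2, 2)]  # per-shard (start, end) on the sharded axis
total = bounds[-1][1]

shards = [data[start:end] for start, end in bounds]
# pad each shard back to full length: start zeros in front, total - end zeros behind
padded = [np.pad(shard, (start, total - end)) for shard, (start, end) in zip(shards, bounds)]
assert np.array_equal(sum(padded), data)  # summing the padded shards reconstructs the tensor

This is the same (start, self.bounds[-1][1]-end) arithmetic as the new pad_arg line, just applied to numpy arrays; in the real code, shards whose real flag is False are skipped before padding.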