mirror of https://github.com/commaai/tinygrad.git
distributed collectives (#1519)
* feat: world
* feat: tests
* feat: no more backwards
* feat: recv into
* feat: whoops
* feat: test in ci
* feat: some debug logging
* feat: workflow naming
* feat: need to set pythonpath
* feat: just send to same device
* feat: allreduce
* feat: test
* feat: need contiguous
* feat: test in ci
* feat: exit with correct code
* feat: don't need that
* feat: opencl wait_for just doesn't work
* feat: synchronize on out
* feat: try?
* feat: try again?
* feat: add extra realizes
* feat: print
* feat: seed
* feat: tol
* feat: test ones and zeros
* feat: remove print
* feat: are you just flaky
* feat: seperate scatter and gather?
* feat: just try synchronizing
* feat: remove print again
* feat: bring back difference
* feat: no sync
* feat: revert that
* feat: back to wait_for
* fix: typo
parent 2e85fce068
commit 29d5801387
CI workflow:
@@ -158,6 +158,7 @@ jobs:
    - name: Test multigpu
      run: |
        PYTHONPATH="." python test/external/dist/test_world.py
+        PYTHONPATH="." python test/external/dist/test_collectives.py

  testmetalwebgpu:
    name: Metal and WebGPU Tests

extra/dist/collectives.py (new file):
@@ -0,0 +1,43 @@
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv

from extra.dist import world

def allreduce(t:Tensor, cache_id=None) -> Tensor:
  RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
  cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None

  # flatten
  flattened = t.flatten()

  # pad to evenly divide
  if flattened.shape[0] % WORLD_SIZE != 0:
    flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))

  # chunk
  chunks = flattened.chunk(WORLD_SIZE, dim=0)
  reduced = chunks[RANK]

  next_rank = (RANK + 1) % WORLD_SIZE
  prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE

  # scatter reduce
  current_chunk_index = RANK
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current_chunk_index] + recv_buf

  # gather
  chunks[current_chunk_index] = reduced
  current_chunk_index = (RANK + 1) % WORLD_SIZE
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current_chunk_index] = recv_buf

  return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)

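The new allreduce is a ring allreduce: a reduce-scatter pass in which each rank accumulates one chunk while forwarding partial sums to its neighbor, followed by an allgather pass in which the fully reduced chunks circulate once around the ring. As a rough illustration only (not part of the commit), here is a single-process sketch of that schedule, with NumPy arrays standing in for tinygrad tensors and list indexing standing in for world.send/world.recv; the world size W and the chunk size are arbitrary values chosen for the example.

# single-process sketch of the ring schedule (reduce-scatter, then allgather)
import numpy as np

W = 4  # hypothetical world size
# each simulated rank starts with its own data, already split into W equal chunks
chunks = [[np.full(2, float(r + 1)) for _ in range(W)] for r in range(W)]

# reduce-scatter: after W-1 ring steps, rank r holds the full sum of chunk (r+1) % W
acc = [chunks[r][r] for r in range(W)]   # the partial sum each rank is accumulating
idx = list(range(W))                     # which chunk index that partial sum is for
for _ in range(W - 1):
  sent = [(acc[r], idx[r]) for r in range(W)]   # snapshot of what each rank sends to rank (r+1) % W
  for r in range(W):
    buf, j = sent[(r - 1) % W]                  # receive from the previous rank in the ring
    idx[r] = j
    acc[r] = chunks[r][j] + buf                 # add the local copy of that chunk

# allgather: circulate the fully reduced chunks once around the ring
for r in range(W): chunks[r][idx[r]] = acc[r]
for _ in range(W - 1):
  sent = [(chunks[r][idx[r]], idx[r]) for r in range(W)]
  for r in range(W):
    buf, j = sent[(r - 1) % W]
    idx[r] = j
    chunks[r][j] = buf

# every rank now holds the elementwise sum 1 + 2 + ... + W in every chunk
assert all(np.allclose(c, W * (W + 1) / 2) for per_rank in chunks for c in per_rank)
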
extra/dist/world.py:
@@ -59,7 +59,7 @@ def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None

# receive a lazybuffer from the target rank
def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
-  _recv_rb(x.realize().realized, target_rank)
+  _recv_rb(x.contiguous().realize().realized, target_rank)
  return x

class Send(Function):

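The collectives are built on these point-to-point send/recv primitives. Purely as an illustration (not from the commit), this is the per-step pattern the ring loop uses: send your current chunk to the next rank, receive into an empty buffer from the previous one. The function name, shape, and Tensor.full values are invented, and it would have to run under the dist.preinit()/dist.init_oob()/dist.spawn() harness shown in the tests below.

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.dist import world

def pass_right():
  rank, world_size = getenv("RANK"), getenv("WORLD_SIZE")
  next_rank, prev_rank = (rank + 1) % world_size, (rank - 1 + world_size) % world_size
  # every rank sends a tensor to its right neighbor and receives one from its left
  world.send(Tensor.full((4, 4), float(rank)).contiguous().realize(), next_rank)
  buf = Tensor.empty(4, 4)
  world.recv(buf, prev_rank)
  return buf
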
test/external/dist/test_collectives.py (new file):
@@ -0,0 +1,56 @@
from extra import dist
from tinygrad.jit import TinyJit
if __name__ == "__main__":
  dist.preinit()

from extra.dist import collectives
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import Tensor
import numpy as np

@TinyJit
def allreduce_jit(t:Tensor, cache_id=None) -> Tensor:
  return collectives.allreduce(t, cache_id=cache_id).realize()

SIZE = 2048 if not CI else 2
SIZE_2 = 255 if not CI else 3

def run():
  # set a deterministic seed so that both ranks generate the same random tensor
  Tensor.manual_seed(42)

  rank = getenv("RANK")

  # loop 3 times to make sure it works with the jit
  for _ in range(3):
    # create a tensor to send
    t = Tensor.zeros(SIZE, SIZE) if rank == 0 else Tensor.ones(SIZE, SIZE)
    t2 = allreduce_jit(t.contiguous().realize(), cache_id="test")
    assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy())

  # reset jit
  allreduce_jit.cnt = 0

  # test uneven chunk sizes
  for _ in range(3):
    # create a tensor to send
    t = Tensor.ones(SIZE_2, SIZE_2, SIZE_2) if rank == 0 else Tensor.zeros(SIZE_2, SIZE_2, SIZE_2)
    t2 = allreduce_jit(t.contiguous().realize(), cache_id="test2")
    assert np.allclose(np.ones((SIZE_2, SIZE_2, SIZE_2)), t2.numpy())

  print(f"rank {rank} passed")

if __name__ == "__main__":
  devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
  world_size = len(devices)

  dist.init_oob(world_size)

  processes = []
  for rank, device in enumerate(devices):
    processes.append(dist.spawn(rank, device, fn=run, args=()))
  for p in processes: p.join()

  # exit with error code if any of the processes failed
  for p in processes:
    if p.exitcode != 0: exit(p.exitcode)

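The test only checks the sum (zeros + ones = ones). A common data-parallel use of an allreduce is gradient averaging; as a hypothetical sketch only (the helper name and the division by WORLD_SIZE are not part of this commit), the new primitive could be wrapped like this:

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.dist import collectives

# hypothetical helper: sum a tensor across all ranks, then divide by the world size
def average_across_ranks(grad: Tensor, cache_id=None) -> Tensor:
  world_size = getenv("WORLD_SIZE")
  summed = collectives.allreduce(grad.contiguous().realize(), cache_id=cache_id)
  return (summed / world_size).realize()
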
test/external/dist/test_world.py:
@@ -59,3 +59,7 @@ if __name__ == "__main__":
  for rank, device in enumerate(devices):
    processes.append(dist.spawn(rank, device, fn=run, args=()))
  for p in processes: p.join()
+
+  # exit with error code if any of the processes failed
+  for p in processes:
+    if p.exitcode != 0: exit(p.exitcode)

OpenCL runtime (CLBuffer):
@@ -48,7 +48,7 @@ class CLBuffer(RawBufferCopyInOut):
    assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
    buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
    mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
-    with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event])
+    with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))

class CLProgram:
  def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):