mirror of https://github.com/commaai/tinygrad.git
add RING_ALLREDUCE_THRESHOLD (#5835)
* add RING_ALLREDUCE_THRESHOLD * becnhmark * fixes * fix n_gpus * unused import * remove debug=2
This commit is contained in:
parent
431749dc21
commit
f768935be8
|
@ -5,7 +5,7 @@ from tinygrad.ops import ReduceOps
|
|||
from tinygrad.multi import MultiLazyBuffer, all_reduce
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.helpers import getenv, Context, RING
|
||||
from tinygrad.helpers import getenv, Context, RING, DEBUG
|
||||
from typing import List, Union
|
||||
|
||||
def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
|
||||
|
@ -20,16 +20,15 @@ def test(devs: List[str], N: int, iters:int = 10):
|
|||
|
||||
secs, gflops, gbs = 0, 0, 0
|
||||
for i in range(-2, iters):
|
||||
GlobalCounters.reset()
|
||||
lbs = [Tensor.full((N,), float(1+i), device=d).contiguous().lazydata for i,d in enumerate(devs)]
|
||||
realize(lbs)
|
||||
GlobalCounters.reset()
|
||||
start = time.time()
|
||||
realize(_jitted(ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
|
||||
end = time.time()
|
||||
if i < 0:
|
||||
# First time is slow due to kernel compilation
|
||||
continue
|
||||
i_secs = end-start
|
||||
if i < 0: continue # warm up jit
|
||||
i_secs = time.time() - start
|
||||
|
||||
if DEBUG >= 2: i_secs = GlobalCounters.time_sum_s
|
||||
i_gflops = GlobalCounters.global_ops/i_secs/10**9
|
||||
i_gbs = (N*4)/i_secs/10**9
|
||||
print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
|
||||
|
@ -39,22 +38,34 @@ def test(devs: List[str], N: int, iters:int = 10):
|
|||
|
||||
return (gflops/iters, gbs/iters, secs/iters)
|
||||
|
||||
def run(sz, n_gpus=6, iters=10):
|
||||
dev = Device.DEFAULT
|
||||
devs = tuple([f"{dev}:{x}" for x in range(n_gpus)])
|
||||
N = sz // 4 # float32 is 4 bytes
|
||||
|
||||
with Context(RING=2):
|
||||
(ring_gflops, ring_gbs, ring_secs) = test(devs, N, iters=iters)
|
||||
with Context(RING=0):
|
||||
(naive_gflops, naive_gbs, naive_secs) = test(devs, N, iters=iters)
|
||||
return (ring_gflops, ring_gbs, ring_secs), (naive_gflops, naive_gbs, naive_secs)
|
||||
|
||||
def main():
|
||||
dev, n_gpus = Device.DEFAULT, getenv("GPUS", 6) # number of gpus
|
||||
devs = tuple([f"{dev}:{x}" for x in range(n_gpus)])
|
||||
n_gpus = getenv("GPUS", 6)
|
||||
|
||||
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
|
||||
f32 = 4 # 4 bytes
|
||||
N = sz//f32
|
||||
|
||||
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
|
||||
with Context(RING=2):
|
||||
(ring_gflops, ring_gbs, ring_secs) = test(devs, N)
|
||||
with Context(RING=0):
|
||||
(naive_gflops, naive_gbs, naive_secs) = test(devs, N)
|
||||
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
|
||||
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
|
||||
if getenv("BENCHMARK_SPLIT"):
|
||||
l, r = 0, 512
|
||||
while r - l > 1:
|
||||
m = (l + r) // 2
|
||||
(ring_gflops, ring_gbs, ring_secs), (naive_gflops, naive_gbs, naive_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100)
|
||||
if ring_secs > naive_secs: l = m
|
||||
else: r = m
|
||||
print("Better split", r * 1024, "elements")
|
||||
else:
|
||||
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
|
||||
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
|
||||
(ring_gflops, ring_gbs, ring_secs), (naive_gflops, naive_gbs, naive_secs) = run(sz)
|
||||
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
|
||||
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Optional, Union, Any, Tuple, List, Dict
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING
|
||||
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
|
||||
from tinygrad.dtype import DType, ConstType
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
@ -15,7 +15,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
|||
n_lbs, dim = len(lbs), prod(lbs[0].shape)
|
||||
# Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
|
||||
use_ring = (RING >= 2 or (n_lbs > 2 and dim > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
|
||||
if not use_ring:
|
||||
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
|
|
Loading…
Reference in New Issue