multitensor start (#2676)

* multitensor work

* early gen fixes the tests

* atol for flaky test
George Hotz 2023-12-07 17:07:05 -08:00 committed by GitHub
parent 4b01839774
commit 4164d0ebbd
8 changed files with 118 additions and 37 deletions

View File

@ -37,7 +37,7 @@ Try a matmul. See how, despite the style, it is fused into one kernel with the power of laziness.
```sh
DEBUG=3 python3 -c "from tinygrad import Tensor;
N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N);
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).sum(axis=2);
c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2);
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
```

extra/multitensor.py Normal file (+59 lines)
View File

@ -0,0 +1,59 @@
import numpy as np
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
N = 256
FLOPS = N*N*N*2

# LazyBuffer should make three fields lists: self.st (all must have the same shape), self.realized, and self.device

def explicit_shard_W_axis_1(X, W):
  Xs = [X.to(d0), X.to(d1)]
  Ws = [W[:, :N//2].to(d0), W[:, N//2:].to(d1)]  # TODO: these shouldn't make copies on the original device
  # pad them to form the correct size
  Ws = [Ws[0].pad((None, (0,N//2))), Ws[1].pad((None, (N//2,0)))]
  for x in Xs: assert x.shape == X.shape
  for w in Ws: assert w.shape == W.shape
  # TODO: it shouldn't be faster with these realize
  for x in Xs+Ws: x.realize()

  def lm(x:Tensor, w:Tensor):
    # these are movement ops on the local device
    x = x.reshape(N, 1, N).expand(N, N, N)
    w = w.T.reshape(1, N, N).expand(N, N, N)
    m = x*w
    assert m.lazydata.st.views[0].mask is not None
    ret = m.sum(2)
    return ret

  #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
  Os = [Xs[0] @ Ws[0], Xs[1] @ Ws[1]]
  for x in Os: x.realize()
  return Os[0].to(Device.DEFAULT) + Os[1].to(Device.DEFAULT)
  #return Tensor.cat(*[x.to(Device.DEFAULT) for x in Os], dim=1)  # TODO: someday we can remove this copy too

def matmul(X, W):
  return explicit_shard_W_axis_1(X, W)
  #return X@W

if __name__ == "__main__":
  with Timing("init devices: "):
    Device[d0], Device[d1]

  with Timing("create tensors: "):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()

  #with Timing("warmup: "):
  #  O = matmul(X, W).numpy()

  GlobalCounters.reset()
  print("******** multiply start")
  with Timing("******** multiply done: ", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
    O = matmul(X, W).realize()
    Device[Device.DEFAULT].synchronize()

  with Timing("testing: "):
    val = X.numpy() @ W.numpy()
    np.testing.assert_allclose(val, O.numpy(), atol=1e-5)
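Why the final add works: each W shard is zero-padded back to the full (N, N) shape, so the two device outputs occupy disjoint column blocks and summing them reconstructs the unsharded product. A minimal numpy sketch of that identity (the variable names below are illustrative, not part of the commit):

```python
import numpy as np

N = 8
X = np.random.rand(N, N).astype(np.float32)
W = np.random.rand(N, N).astype(np.float32)

# shard W along axis 1, then zero-pad each shard back to (N, N),
# mirroring Ws[0].pad((None, (0, N//2))) and Ws[1].pad((None, (N//2, 0)))
W0 = np.pad(W[:, :N//2], ((0, 0), (0, N//2)))
W1 = np.pad(W[:, N//2:], ((0, 0), (N//2, 0)))

# each "device" does a full-shape matmul against its padded shard; the padded
# columns contribute only zeros, so the sum stitches the two blocks together
np.testing.assert_allclose(X @ W0 + X @ W1, X @ W, rtol=1e-5)
```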

View File

@ -166,7 +166,7 @@ class TestOpt(unittest.TestCase):
with CLCache(allowed=1):
d = a * b + c
d.realize()
np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5)
np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5, atol=1e-7)
def test_fold_reduce_elementwise(self):
img = Tensor.ones(32)

View File

@ -17,11 +17,12 @@ from examples.stable_diffusion import UNetModel
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
tms = []
for _ in range(4):
early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
GlobalCounters.reset()
GlobalCounters.mem_used = 0
Device[Device.DEFAULT].synchronize()
st = time.perf_counter_ns()
train(*gen())
train(*early_gen)
Device[Device.DEFAULT].synchronize()
tms.append(time.perf_counter_ns() - st)
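The point of early_gen is to realize the generated inputs before GlobalCounters.reset() and the timer, so data generation and host-to-device copies are not billed to train(). A generic sketch of that pattern (timed_step is a hypothetical helper, not this test's API):

```python
import time
from tinygrad import Tensor, Device

def timed_step(step, *inputs):
  # realize Tensor inputs up front so their generation isn't part of the measurement
  inputs = [x.realize() if isinstance(x, Tensor) else x for x in inputs]
  Device[Device.DEFAULT].synchronize()
  st = time.perf_counter_ns()
  step(*inputs)
  Device[Device.DEFAULT].synchronize()
  return time.perf_counter_ns() - st
```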

View File

@ -3,19 +3,27 @@ from tinygrad.tensor import Tensor
class TestMaskedShapeTracker(unittest.TestCase):
def test_mul_masked(self):
a = Tensor([1,1,1,1])
b = Tensor([1,1]).pad(((0,2),))
a = Tensor([1,1,1,1,1])
b = Tensor([1,1]).pad(((0,3),))
c = a*b
# TODO: make this true
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0]
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
def test_mul_both_masked(self):
a = Tensor([1,1]).pad(((0,3),))
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.lazydata.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
def test_add_masked(self):
a = Tensor([1,1]).pad(((0,2),))
b = Tensor([1,1]).pad(((0,2),))
c = a+b
# TODO: make this true
#assert c.lazydata.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]
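The expected values in these tests are plain zero-padded elementwise arithmetic; the commented-out assertions are about eventually carrying the pad mask in the ShapeTracker views, not about the numbers. An equivalent numpy check (illustrative only):

```python
import numpy as np

a = np.pad(np.array([1., 1.]), (0, 3))   # [1., 1., 0., 0., 0.]
b = np.pad(np.array([1., 1.]), (0, 3))
assert (a * b).tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]

c = np.pad(np.array([1., 1.]), (0, 2))   # [1., 1., 0., 0.]
d = np.pad(np.array([1., 1.]), (0, 2))
assert (c + d).tolist() == [2.0, 2.0, 0.0, 0.0]
```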

View File

@ -48,7 +48,7 @@ class JITRunner:
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")
def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str=""):
if var_vals is None: var_vals = {}
op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
GlobalCounters.kernel_count += num_kernels
@ -56,7 +56,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
GlobalCounters.global_mem += mem_estimate
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device:7s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
# **************** Buffer / Allocator ****************
@ -89,32 +89,41 @@ class Buffer:
if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
return ret
def _internal_buffer_copy(dest, src):
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
# fast path, used on HIP between GPUs
# NOTE: it's important we use the dest device here to ensure the transfer is ready
Device[src.device].synchronize() # TODO: async this
dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
return
if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
# fast path, used on Metal in OS X Sonoma
# NOTE: this is *only* faster if the pages from disk are already loaded into memory
fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
if fb:
dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
return
if hasattr(dest.allocator, 'as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
elif hasattr(src.allocator, 'as_buffer'):
dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
else:
# slow path, allocates a CPU buffer
dest.copyin(src.toCPU().data)
class _BufferCopy(JITRunner):
# TODO: make wait work
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
dest, src = rawbufs
assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
# fast path, used on HIP between GPUs
# NOTE: it's important we use the dest device here to ensure the transfer is ready
dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
return
if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
# fast path, used on Metal in OS X Sonoma
# NOTE: this is *only* faster if the pages from disk are already loaded into memory
fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
if fb:
dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
return
if hasattr(dest.allocator, 'as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
elif hasattr(src.allocator, 'as_buffer'):
dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
else:
# slow path, allocates a CPU buffer
dest.copyin(src.toCPU().data)
st = time.perf_counter()
_internal_buffer_copy(dest, src)
et = None
if wait or DEBUG >= 2:
Device[dest.device].synchronize()
et = time.perf_counter() - st
update_stats(colored(f"copy {dest.device:7s} <- {src.device:7s}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, lra={"global_size": dest.size}, device=dest.device)
BufferCopy = _BufferCopy()
# TODO: size, dest, src are the same type. can we enforce this?
@ -167,7 +176,7 @@ class InterpretedASTRunner(JITRunner):
st = time.perf_counter()
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
et = time.perf_counter() - st
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
return et
class Interpreted:
@ -257,7 +266,7 @@ class CompiledASTRunner(JITRunner):
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device)
return et
class Compiled:
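_internal_buffer_copy chooses a path by capability (hasattr) rather than by device name, and _BufferCopy.__call__ now only times it and reports the copy through update_stats with the new device field. A stripped-down sketch of that dispatch order, using dummy bytearray-backed allocators rather than tinygrad's classes:

```python
class HostAllocator:
  def copyin(self, dst_buf, src_mv): dst_buf[:] = src_mv
  def copyout(self, dst_mv, src_buf): dst_mv[:] = src_buf
  def as_buffer(self, buf): return memoryview(buf)

class DeviceAllocator(HostAllocator):
  # stands in for a peer-to-peer copy like HIPAllocator.transfer
  def transfer(self, dst_buf, src_buf, sz): dst_buf[:sz] = src_buf[:sz]

def copy_buffer(dest_alloc, dest_buf, src_alloc, src_buf, sz):
  if hasattr(dest_alloc, 'transfer') and type(dest_alloc) is type(src_alloc):
    dest_alloc.transfer(dest_buf, src_buf, sz)                    # fast path
  elif hasattr(dest_alloc, 'as_buffer'):
    src_alloc.copyout(dest_alloc.as_buffer(dest_buf), src_buf)    # readinto-style
  elif hasattr(src_alloc, 'as_buffer'):
    dest_alloc.copyin(dest_buf, src_alloc.as_buffer(src_buf))
  else:
    raise NotImplementedError("slow path: bounce through a host buffer")

src, dst = bytearray(b"hello world"), bytearray(11)
copy_buffer(DeviceAllocator(), dst, DeviceAllocator(), src, len(src))
assert dst == src
```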

View File

@ -2,6 +2,7 @@ import os, atexit, functools
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.shapetracker import ShapeTracker
@ -21,7 +22,7 @@ if GRAPH:
G = nx.DiGraph()
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
print("saving", G)
print("saving", G, f"to {GRAPHPATH}.svg")
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
# -Gnslimit=100 can make it finish, but you won't like results
os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
@ -68,7 +69,7 @@ def log_schedule_item(si: ScheduleItem):
cnts[optype] += 1
if GRAPH:
assert si.out.base == si.out, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
# get inputs for shapetrackers
input_to_st = defaultdict(list)
@ -88,7 +89,7 @@ def log_schedule_item(si: ScheduleItem):
if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
G.nodes[nm(si.out)]['label'] = '"' + (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps or optype is BufferOps else "")+(f"\n{si.out.device}" if si.out.device != Device.DEFAULT else "") + '"'
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
G.nodes[nm(si.out)]['color'] = 'black'
G.nodes[nm(si.out)]['style'] = 'filled'

View File

@ -55,6 +55,7 @@ class HIPAllocator(LRUAllocator):
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
def transfer(self, dest:T, src:T, sz:int):
check(hip.hipSetDevice(self.device))
# TODO: hipMemcpyAsync, but you have to track the "src" buffer to not free it
check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice))
class HIPDevice(Compiled):
@ -65,4 +66,6 @@ class HIPDevice(Compiled):
from tinygrad.features.graph.hip import HIPGraph
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
def synchronize(self): hip.hipDeviceSynchronize()
def synchronize(self):
check(hip.hipSetDevice(self.device))
check(hip.hipDeviceSynchronize())