remove graph [pr] (#7085)

George Hotz 2024-10-16 11:40:07 +08:00 committed by GitHub
parent 53586eac56
commit 3169cb386d
17 changed files with 22 additions and 175 deletions

View File

@@ -40,9 +40,6 @@ METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
CLANG | [1] | enable Clang backend
LLVM | [1] | enable LLVM backend
BEAM | [#] | number of beams in kernel beam search
GRAPH | [1] | create a graph of all operations (requires graphviz)
GRAPHUOPS | [1] | create a graph of uops (requires graphviz and saves at /tmp/uops.{svg,dot})
GRAPHPATH | [/path/to] | where to put the generated graph
DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32
IMAGE | [1-2] | enable 2d specific optimizations
FLOAT16 | [1] | use float16 for images instead of float32
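
These flags are plain environment variables, mostly read once when tinygrad is imported, so a minimal sketch of exercising the surviving ones looks like the following (the matmul is only a placeholder workload):

import os
os.environ["BEAM"] = "2"              # kernel beam search width
os.environ["DEFAULT_FLOAT"] = "HALF"  # default float dtype

from tinygrad import Tensor
print((Tensor.rand(256, 256) @ Tensor.rand(256, 256)).mean().item())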

View File

@@ -302,7 +302,4 @@ You can find a full list and their descriptions in [env_vars.md](env_vars.md).
### Visualizing the Computation Graph
It is possible to visualize the computation graph of a neural network using [graphviz](https://graphviz.org/).
This is easily done by running a single pass (forward or backward!) of the neural network with the environment variable `GRAPH` set to `1`.
The graph will be saved to `/tmp/net.svg` by default.
It is possible to visualize the computation graph of a neural network using VIZ=1.
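
A minimal sketch of the new workflow, assuming only that VIZ=1 is set in the environment (e.g. VIZ=1 python3 script.py) while any forward or backward pass runs; the two-layer net is a placeholder:

from tinygrad import Tensor, nn

class TinyNet:
  def __init__(self):
    self.l1 = nn.Linear(784, 128)
    self.l2 = nn.Linear(128, 10)
  def __call__(self, x: Tensor) -> Tensor:
    return self.l2(self.l1(x).relu())

TinyNet()(Tensor.rand(4, 784)).realize()  # with VIZ=1 set, tinygrad captures the graph for its viewer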

View File

@@ -4,7 +4,7 @@ if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
from tinygrad import Device, nn, Tensor, dtypes, Variable
Device.DEFAULT = "CLANG"
from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_kernel, run_schedule
from tinygrad.engine.memory import memory_planner
@@ -27,7 +27,6 @@ if __name__ == "__main__":
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
warmup_count = getenv("WARMUP", 3)
for i in range(warmup_count): # TODO: why does it take three and not two to stablize
if i == warmup_count-1: GRAPH.value = getenv("LATEGRAPH")
GlobalCounters.reset()
X = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
Y = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)

View File

@@ -1,6 +1,5 @@
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, DEBUG
from tinygrad.engine.graph import print_globalcounters
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem
from dataclasses import replace
@@ -36,4 +35,3 @@ if __name__ == "__main__":
prg = replace(prg, src=new_src)
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
for i in range(5): ei.run(wait=True)
if DEBUG < 2: print_globalcounters()
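
With print_globalcounters removed along with tinygrad/engine/graph.py, a script that wants the same end-of-run summary can read GlobalCounters directly; a minimal sketch, loosely following the formatting of the removed helper:

from tinygrad.helpers import GlobalCounters

def print_counters():
  if GlobalCounters.time_sum_s == 0: return
  print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS "
        f"{GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s, "
        f"total: {GlobalCounters.kernel_count} kernels in {GlobalCounters.time_sum_s*1e3:8.2f} ms")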

View File

@@ -1,4 +1,3 @@
# TODO: move the GRAPH and DEBUG stuff to here
import gc
from tinygrad.helpers import prod
from tinygrad.engine.lazy import LazyBuffer

View File

@@ -148,7 +148,14 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
if DEBUG>=2: print()
if getenv("MCTSGRAPH"):
from tinygrad.engine.graph import nx, save_graph, GRAPHPATH
import networkx as nx
import os
GRAPHPATH = "/tmp/net"
def save_graph(G, fn, opt=""):
print("saving", G, f"to {fn}.svg")
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
G = nx.DiGraph()
def add_node(node:MCTSNode):
if node.n == 0: return
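
The inlined helper above is just the generic networkx-to-graphviz pattern; a standalone sketch of the same idea (requires networkx, pydot and the dot binary; the node names are made up):

import os
import networkx as nx

G = nx.DiGraph()
G.add_edge("LOAD", "MUL")
G.add_edge("MUL", "SUM")

nx.drawing.nx_pydot.write_dot(G, "/tmp/mcts_example.dot")  # dump the graph as DOT
os.system("dot -Grankdir=LR -Tsvg /tmp/mcts_example.dot -o /tmp/mcts_example.svg")  # render to SVG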

View File

@@ -160,5 +160,5 @@ if __name__ == "__main__":
jmodel = TinyJit(model)
jmodel(Tensor.rand(1, 3, 224, 224)).realize()
GlobalCounters.reset()
with Context(GRAPH=1): jmodel(Tensor.rand(1, 3, 224, 224)).realize()
jmodel(Tensor.rand(1, 3, 224, 224)).realize()
for i in range(10): jmodel(Tensor.rand(1, 3, 224, 224))

View File

@@ -50,10 +50,6 @@ if __name__ == "__main__":
if getenv("LINEARIZE", 1):
with Timing("***** model linearize in "): uops = [linearize_uop(u) for u in uops]
print(sum(len(u) for u in uops))
if getenv("GRAPHUOPS", 0):
for u in uops:
from tinygrad.engine.graph import graph_uops
graph_uops(u)
if getenv("SRC", 0):
renderer = Device[Device.DEFAULT].renderer
for k,u in zip(kernels, uops): print(renderer.render(k.name, u))

View File

@@ -1,45 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad.tensor import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.engine.graph import graph_uops
from tinygrad.engine.schedule import create_schedule
from tinygrad.nn import Conv2d
class TestUopsGraph(unittest.TestCase):
def test_matmul(self):
N = 1024
a = Tensor.rand(N,N)
b = Tensor.rand(N,N)
si = create_schedule([(a@b).lazydata])[-1]
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
print(lin.colored_shape())
uops = lin.linearize().uops
graph_uops(uops)
for u in uops: print(u)
print(OpenCLRenderer("matmul", uops)[0])
def test_reduce(self):
a = Tensor.rand(1024*1024)
si = create_schedule([a.sum().lazydata])[-1]
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)
#print(OpenCLRenderer("reduce", uops)[0])
def test_conv(self):
x = Tensor.rand(1,3,16,16)
c = Conv2d(3, 16, (3,3))
si = create_schedule([c(x).elu().lazydata])[-1]
lin = Kernel(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
graph_uops(uops)
print(lin.colored_shape())
print(OpenCLRenderer("conv", uops)[0])
if __name__ == '__main__':
unittest.main()

View File

@@ -8,13 +8,12 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Context, CI, OSX, getenv, colored
from tinygrad.helpers import CI, OSX, getenv, colored
def derandomize_model(model):
with Context(GRAPH=0):
for p in get_parameters(model):
p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
p.realize()
for p in get_parameters(model):
p.lazydata = Tensor.empty(p.shape, device=p.device, dtype=p.dtype).lazydata
p.realize()
def assert_jit_cache_len(fxn, expected_len):
if not fxn.jit_cache:

View File

@@ -12,12 +12,12 @@ from tinygrad.dtype import DType, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, is_dtype_supported, Context, timeit
from test.helpers import ast_const, is_dtype_supported, timeit
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass
@@ -1587,7 +1587,7 @@ class TestIndexing(unittest.TestCase):
X = Tensor.randn(2,3,4,4).numpy()
with Context(FUSE_ARANGE=1):
compare = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
with Context(FUSE_ARANGE=0, GRAPH=0, SAVE_SCHEDULE=1):
with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1):
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)

View File

@@ -63,10 +63,6 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
et = time.perf_counter() - st
UOp.__init__ = old_init
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.sparents)} -> {len(new_sink.sparents)}, creating {cnt[0]} uops")
#from collections import Counter
#print(Counter(x.op for x in new_sink.sparents))
#from tinygrad.engine.graph import graph_uops
#graph_uops(linearize_uop(new_sink))
class TestGraphRewriteConst(unittest.TestCase):
def test_gep_const(self):

View File

@@ -718,9 +718,6 @@ class Kernel:
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
if DEBUG >= 5: print_uops(self.uops)
if getenv("GRAPHUOPS"):
from tinygrad.engine.graph import graph_uops
graph_uops(self.uops)
return self
def to_program(self, name_override:Optional[str]=None) -> Program:

View File

@@ -1,85 +0,0 @@
import os, atexit, functools, contextlib
from collections import defaultdict
from typing import List, Any, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, word_wrap
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.viz.serve import uops_colors
with contextlib.suppress(ImportError): import networkx as nx
# **** debugging and graphing ****
def print_globalcounters():
if GlobalCounters.time_sum_s == 0: return
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
if DEBUG >= 2: atexit.register(print_globalcounters)
def save_graph(G, fn, opt=""):
print("saving", G, f"to {fn}.svg")
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
G:Any = None
def init_graph():
global G
if G is not None: return
G = nx.DiGraph()
atexit.register(functools.partial(save_graph, G, GRAPHPATH)) # -Gnslimit=100 can make it finish, but you won't like results
counts: DefaultDict[type, int] = defaultdict(int)
def nm(x):
if not hasattr(x, 'node_id'):
setattr(x, 'node_id', counts[type(x)])
counts[type(x)] += 1
return x.node_id
def realized_lazybuffer(lb:'LazyBuffer', num):
init_graph()
G.nodes[nm(lb)]['style'] = '"filled,bold"'
G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
top_colors = {MetaOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", TernaryOps: "#c0c0c0"}
def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
init_graph()
if lb.base.realized is None and lb.base.op is MetaOps.CONST: return
if lb.base != lb:
offset = tuple(x.offset for x in lb.st.views if x.offset != 0)
label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if len(offset) else "")
G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
G.add_edge(nm(lb.base), nm(lb), color='#00000060')
lb = lb.base
if lb.realized is None:
label_append = []
for idx,x in enumerate(lb.srcs):
if nm(x) not in G.nodes: log_lazybuffer(x)
if x.base.realized is None and x.base.op is MetaOps.CONST:
label_append.append(f"\nCONST{idx} {x.base.arg:g}")
else:
G.add_edge(nm(x), nm(lb), color='#a0a0a0')
label = '"' + \
(str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
(f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
(f"\n{lb.device[:15]}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
else:
if nm(lb) not in G.nodes:
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
graph_uops_cnt = 0
def graph_uops(uops:List[UOp]):
global graph_uops_cnt
G = nx.DiGraph()
for u in uops:
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}",
style="filled", fillcolor=uops_colors.get(u.op, "#ffffff"))
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
save_graph(G, f'{GRAPHPATH}.{graph_uops_cnt}.uops', '-Grankdir=LR')
graph_uops_cnt += 1
def graph_uop(uop:UOp): return graph_uops(list(uop.sparents))

View File

@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
import functools, itertools, collections
from tinygrad.tensor import Tensor
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
@@ -237,7 +237,7 @@ class TinyJit(Generic[ReturnType]):
self._jit_cache: List[ExecItem] = []
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
# TODO: should we always disable the memory planner here? it must be off for prune
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
capturing.append(self)
try:
ret = self.fxn(*args, **kwargs)
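
For reference, the capture path above only runs while a TinyJit function is being traced; a minimal usage sketch (the step function is a placeholder, and after this change only JITBEAM and the memory-planner flag are overridden during capture):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()

for _ in range(4):          # the first couple of calls capture kernels, later calls replay them
  step(Tensor.rand(8, 8))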

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
graph_rewrite, track_rewrites, sint
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
from tinygrad.helpers import DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Metadata, all_same, \
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -204,9 +204,6 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs: return None
if buf.base.realized is not None: return realizes.setdefault(buf.base)
if GRAPH:
from tinygrad.engine.graph import log_lazybuffer
log_lazybuffer(buf, scheduled)
# check if we need to realize views
if buf is not buf.base:
# fuse some pads
@@ -414,13 +411,8 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
graph, in_degree, var_vals = _graph_schedule(outs)
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
schedule: List[ScheduleItem] = []
kernel_number = GlobalCounters.kernel_count
while queue:
lsi = queue.popleft()
if GRAPH:
kernel_number += 1
from tinygrad.engine.graph import realized_lazybuffer
for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
for out in lsi.outputs: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
if (m:=BUF_LIMIT.get(device:=si.outputs[0].device)) and len(si.bufs) >= m:

View File

@@ -100,7 +100,7 @@ class ContextVar:
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
SAVE_SCHEDULE, RING = ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
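
A short sketch of how these ContextVars behave, assuming nothing beyond what tinygrad.helpers already exposes: each one reads its environment variable once, and Context temporarily overrides the value inside a with-block:

from tinygrad.helpers import Context, DEBUG, BEAM

print(DEBUG.value, BEAM.value)    # whatever $DEBUG / $BEAM were at import time (default 0)
with Context(DEBUG=3, BEAM=2):
  print(DEBUG.value, BEAM.value)  # 3 2 inside the block
print(DEBUG.value, BEAM.value)    # restored on exit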