minor cleanups from toonygrad (#6990)

This commit is contained in:
George Hotz 2024-10-11 14:19:10 +08:00 committed by GitHub
parent f50d0e0ee0
commit c08521e823
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 18 additions and 16 deletions

View File

@ -4,7 +4,7 @@ from tinygrad import Tensor, Device
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
from tinygrad.ops import UOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.lowerer import ast_to_uop
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.engine.search import beam_search, bufs_from_lin
@ -37,7 +37,7 @@ if __name__ == "__main__":
else: k.hand_coded_optimizations()
kernels.append(k)
with Timing("***** model lower in "): uops = [ast_to_uop(k.get_optimized_ast(), k.opts) for k in kernels]
with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
with Timing("***** model rewrite in "):
rewritten_uops = []

View File

@ -6,7 +6,7 @@ from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.codegen.lowerer import ast_to_uop
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding
from tinygrad.shape.shapetracker import ShapeTracker, View
@ -50,7 +50,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
lower_sink = ast_to_uop(sink, Device[Device.DEFAULT].renderer)
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
cnt = [0]
old_init = UOp.__init__
def uop_hook(self, *args, **kwargs):

View File

@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.codegen.lowerer import ast_to_uop, get_contraction
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
@ -715,7 +715,7 @@ class Kernel:
print(self.applied_opts)
verify_ast(modified_ast)
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts))
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
if DEBUG >= 5: print_uops(self.uops)
if getenv("GRAPHUOPS"):
from tinygrad.engine.graph import graph_uops

View File

@ -132,4 +132,4 @@ pm_lowerer = PatternMatcher([
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))

View File

@ -3,8 +3,9 @@ from collections import defaultdict
from typing import List, Any, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, word_wrap
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.viz.serve import uops_colors
with contextlib.suppress(ImportError): import networkx as nx
@ -70,12 +71,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
graph_uops_cnt = 0
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
def graph_uops(uops:List[UOp]):
global graph_uops_cnt
G = nx.DiGraph()

View File

@ -67,6 +67,7 @@ def get_child(obj, key):
elif isinstance(obj, dict): obj = obj[k]
else: obj = getattr(obj, k)
return obj
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
@functools.lru_cache(maxsize=None)
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])

View File

@ -99,6 +99,7 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
class UOps(FastEnum):
# uops that aren't rendered
SINK = auto()
CONTIGUOUS = auto()
# metaops
CUSTOM = auto()

View File

@ -4,11 +4,15 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import word_wrap, uops_colors
from tinygrad.codegen.kernel import Kernel
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
# ** API spec
@dataclass
@ -112,7 +116,6 @@ class Handler(BaseHTTPRequestHandler):
# ** main loop
stop_reloader = threading.Event()
def reloader():
mtime = os.stat(__file__).st_mtime
while not stop_reloader.is_set():
@ -122,6 +125,7 @@ def reloader():
time.sleep(0.1)
if __name__ == "__main__":
stop_reloader = threading.Event()
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
st = time.perf_counter()
print("*** viz is starting")