mirror of https://github.com/commaai/tinygrad.git
minor cleanups from toonygrad (#6990)
This commit is contained in:
parent
f50d0e0ee0
commit
c08521e823
|
@ -4,7 +4,7 @@ from tinygrad import Tensor, Device
|
|||
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
|
||||
from tinygrad.ops import UOps
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||
from tinygrad.engine.search import beam_search, bufs_from_lin
|
||||
|
||||
|
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
|||
else: k.hand_coded_optimizations()
|
||||
kernels.append(k)
|
||||
|
||||
with Timing("***** model lower in "): uops = [ast_to_uop(k.get_optimized_ast(), k.opts) for k in kernels]
|
||||
with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]
|
||||
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
|
||||
with Timing("***** model rewrite in "):
|
||||
rewritten_uops = []
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.dtype import PtrDType
|
|||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
|
@ -50,7 +50,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
|||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
|
||||
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
|
||||
lower_sink = ast_to_uop(sink, Device[Device.DEFAULT].renderer)
|
||||
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
|
||||
cnt = [0]
|
||||
old_init = UOp.__init__
|
||||
def uop_hook(self, *args, **kwargs):
|
||||
|
|
|
@ -16,7 +16,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||
from tinygrad.codegen.lowerer import ast_to_uop, get_contraction
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
@ -715,7 +715,7 @@ class Kernel:
|
|||
print(self.applied_opts)
|
||||
verify_ast(modified_ast)
|
||||
|
||||
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts))
|
||||
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
|
||||
if DEBUG >= 5: print_uops(self.uops)
|
||||
if getenv("GRAPHUOPS"):
|
||||
from tinygrad.engine.graph import graph_uops
|
||||
|
|
|
@ -132,4 +132,4 @@ pm_lowerer = PatternMatcher([
|
|||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
])
|
||||
|
||||
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
|
|
|
@ -3,8 +3,9 @@ from collections import defaultdict
|
|||
from typing import List, Any, DefaultDict
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
|
||||
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, word_wrap
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.viz.serve import uops_colors
|
||||
|
||||
with contextlib.suppress(ImportError): import networkx as nx
|
||||
|
||||
|
@ -70,12 +71,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
|
|||
# realized but unseen?
|
||||
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
|
||||
|
||||
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
|
||||
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
|
||||
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
|
||||
UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
|
||||
graph_uops_cnt = 0
|
||||
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
|
||||
def graph_uops(uops:List[UOp]):
|
||||
global graph_uops_cnt
|
||||
G = nx.DiGraph()
|
||||
|
|
|
@ -67,6 +67,7 @@ def get_child(obj, key):
|
|||
elif isinstance(obj, dict): obj = obj[k]
|
||||
else: obj = getattr(obj, k)
|
||||
return obj
|
||||
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
|
||||
|
|
|
@ -99,6 +99,7 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
|
|||
class UOps(FastEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto()
|
||||
CONTIGUOUS = auto()
|
||||
|
||||
# metaops
|
||||
CUSTOM = auto()
|
||||
|
|
|
@ -4,11 +4,15 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
|
|||
from urllib.parse import parse_qs, urlparse
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from tinygrad.helpers import colored, getenv, to_function_name, tqdm
|
||||
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, word_wrap
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import word_wrap, uops_colors
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
||||
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
|
||||
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
|
||||
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
|
||||
UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
|
||||
|
||||
# ** API spec
|
||||
|
||||
@dataclass
|
||||
|
@ -112,7 +116,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
# ** main loop
|
||||
|
||||
stop_reloader = threading.Event()
|
||||
def reloader():
|
||||
mtime = os.stat(__file__).st_mtime
|
||||
while not stop_reloader.is_set():
|
||||
|
@ -122,6 +125,7 @@ def reloader():
|
|||
time.sleep(0.1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
stop_reloader = threading.Event()
|
||||
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
|
||||
st = time.perf_counter()
|
||||
print("*** viz is starting")
|
||||
|
|
Loading…
Reference in New Issue