diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index c778b79c..a3e6b0f8 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -39,8 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import PtrDType, dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # allocate some buffers + load in values
diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index ebd7182a..3d6ab580 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
diff --git a/examples/openpilot/compile2.py b/examples/openpilot/compile2.py
index 01e9f69f..31615b3c 100644
--- a/examples/openpilot/compile2.py
+++ b/examples/openpilot/compile2.py
@@ -18,7 +18,7 @@ from tinygrad.device import Buffer
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner, memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import _to_np_dtype
 
 Device.DEFAULT = "GPU"
diff --git a/extra/ops.py b/extra/ops.py
index c9f58a43..b9680634 100644
--- a/extra/ops.py
+++ b/extra/ops.py
@@ -3,9 +3,8 @@ from typing import Dict, Union, Tuple, Any, List
 import functools, hashlib
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
 from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker
diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py
index 73c1a2fa..607669b3 100644
--- a/test/external/external_benchmark_schedule.py
+++ b/test/external/external_benchmark_schedule.py
@@ -1,7 +1,7 @@
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor
-from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import Profiling, Timing, getenv
+from tinygrad.ops import UOps
 from tinygrad.codegen.kernel import Kernel
 
 if __name__ == "__main__":
diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 2a869598..c4ed2d4f 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -7,12 +7,11 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOp, UOps
 from test.helpers import is_dtype_supported
 
 def tuplize_uops(uops:List[UOp]) -> Tuple:
diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py
index 4daeb6cc..9e780bed 100644
--- a/test/external/fuzz_uops.py
+++ b/test/external/fuzz_uops.py
@@ -3,7 +3,7 @@ from collections import defaultdict
 import numpy as np
 from dataclasses import replace
 from typing import DefaultDict, Dict, List, Tuple
-from tinygrad.codegen.uops import END_FOR_UOP, UOp
+from tinygrad.ops import END_FOR_UOP, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner
diff --git a/test/helpers.py b/test/helpers.py
index b90b9844..9d1b051d 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -2,7 +2,7 @@ import sys, unittest
 from typing import Optional, Set, Tuple
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOp
+from tinygrad.ops import UOp
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType
diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 0925f813..30451c20 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.helpers import CI
 import numpy as np
diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py
index f1eff3da..159c6d05 100644
--- a/test/test_conv_shapetracker.py
+++ b/test/test_conv_shapetracker.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.engine.schedule import create_schedule
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
index 533347d9..0f3d2062 100644
--- a/test/test_dtype_alu.py
+++ b/test/test_dtype_alu.py
@@ -4,12 +4,11 @@ from tinygrad import Tensor, dtypes, Device
 import operator
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from tinygrad.codegen.uops import UOps
 from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOps
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported
diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py
index 4c7b72dd..2d025d33 100644
--- a/test/test_lazybuffer.py
+++ b/test/test_lazybuffer.py
@@ -2,7 +2,7 @@ import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.lazy import LazyBuffer, MetaOps
 from tinygrad.engine.schedule import create_schedule
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 71aa1ae4..d1d6eb71 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -5,7 +5,7 @@ from dataclasses import replace
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import UOp, UOps
 from tinygrad.device import Device, Buffer
 from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
 from tinygrad.shape.shapetracker import ShapeTracker
diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py
index 636bcbcb..f2ead072 100644
--- a/test/test_linearizer_dumb.py
+++ b/test/test_linearizer_dumb.py
@@ -4,9 +4,9 @@ import unittest
 from tinygrad import Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import getenv
-from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps # noqa: F401 # pylint: disable=unused-import
+from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 81680b71..58393ffc 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -3,7 +3,7 @@ import unittest, random
 import numpy as np
 from tinygrad.codegen.kernel import KernelOptError
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 3bf668d2..a1d0859b 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1,8 +1,7 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.codegen.uops import UOps
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps
+from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps
 from tinygrad.helpers import CI, getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
diff --git a/test/test_nn.py b/test/test_nn.py
index 6b9b4550..a82a5e20 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch
 from tinygrad import Tensor, Device, TinyJit
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
 from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py
index 216cf38c..845801f6 100644
--- a/test/test_pattern_matcher.py
+++ b/test/test_pattern_matcher.py
@@ -1,8 +1,7 @@
 import unittest, itertools
 from test.helpers import TestUOps
 from tinygrad.dtype import dtypes
-from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
-from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
 from tinygrad.codegen.uopgraph import constant_folder
 
 class TestPatternMatcher(TestUOps):
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index 04044eb9..76be7195 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -1,13 +1,12 @@
 import unittest
 from typing import List, cast
 import numpy as np
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.device import Buffer, Device
 from tinygrad.dtype import PtrDType, DType, dtypes
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import dedup, flatten
 from tinygrad.renderer.cstyle import CStyleLanguage
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 from tinygrad.renderer import Program
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.lazy import LazyBuffer
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 289a468c..e989f69d 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -8,10 +8,9 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps, verify_ast
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
 from test.helpers import is_dtype_supported, Context
diff --git a/test/test_search.py b/test/test_search.py
index c13e5dcf..cb946fd5 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -2,7 +2,7 @@ import unittest
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
 from tinygrad.device import Device, Buffer
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 0989f3e2..fb8bfe18 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -3,8 +3,7 @@ from test.helpers import TestUOps
 from tinygrad import dtypes, Variable
 from tinygrad.dtype import PtrDType
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
-from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
+from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
 from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
 
 simple_pm = PatternMatcher([
diff --git a/test/test_uops.py b/test/test_uops.py
index 5e70a0ef..7ceba9a8 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -5,11 +5,10 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Context
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
+from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
-from tinygrad.codegen.uops import UOps, NOp, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 42aed760..00f44dcd 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -3,9 +3,8 @@ from tinygrad import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule_item
-from tinygrad.codegen.uops import flops_mem, UOps, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
diff --git a/test/test_winograd.py b/test/test_winograd.py
index 22ebb81e..f0b4dc22 100644
--- a/test/test_winograd.py
+++ b/test/test_winograd.py
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, GlobalCounters
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 55ba5aaa..91010627 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -9,9 +9,8 @@ from typing import Tuple
 from tinygrad.helpers import DEBUG
 from tinygrad.dtype import dtypes, PtrDType, ConstType
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 import functools
 
 def render(self) -> Tuple[str, ConstType, ConstType]:
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 019e73c6..895767eb 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -4,8 +4,7 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps, verify_ast
-from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo
+from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import DType, ImageDType, PtrDType
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index fda9a283..c265fb38 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -5,8 +5,7 @@ import functools
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps
-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps
+from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py
index 455cdea4..95ebb7c8 100644
--- a/tinygrad/codegen/transcendental.py
+++ b/tinygrad/codegen/transcendental.py
@@ -1,7 +1,7 @@
 import math, functools
 from typing import Tuple, List
 from tinygrad.dtype import dtypes, DType
-from tinygrad.codegen.uops import UOp
+from
tinygrad.ops import UOp TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64} diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 34e14681..b1c2b525 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,9 +3,8 @@ from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE import functools, itertools, heapq, math, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType -from tinygrad.ops import UnaryOps, BinaryOps, exec_alu +from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition -from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES if TYPE_CHECKING: from tinygrad.renderer import Renderer diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py deleted file mode 100644 index 038c46cd..00000000 --- a/tinygrad/codegen/uops.py +++ /dev/null @@ -1,338 +0,0 @@ -from __future__ import annotations -from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict -import functools, itertools, math, hashlib -from collections import defaultdict -from enum import Enum, auto -from dataclasses import dataclass -from tinygrad.dtype import ConstType, dtypes, DType -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.symbolic import sint, Variable -from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, reduce_st -from tinygrad.helpers import merge_dicts, prod, pretty_print, dedup - -# the order of these UOps controls the order of the toposort -class UOps(Enum): - # ops that aren't rendered - SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702 - DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702 - CONST = auto(); SPECIAL = auto() # noqa: E702 - NOOP = auto(); GEP = auto() # noqa: E702 - # math ops - CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702 - ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702 - # memory/assignment ops - LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702 - # control flow ops - BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702 - # these two are not graph nodes - ENDRANGE = auto(); ENDIF = auto() # noqa: E702 - -BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} - -END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)} - -@dataclass(frozen=True, eq=False) -class UOp: - op: UOps - dtype: Optional[DType] = None - src: Tuple[UOp, ...] = tuple() - arg: Any = None - def commutative(self) -> bool: - return (self.op is UOps.ALU and \ - self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}) - @functools.cached_property - def cmp_tuple(self): - # NOTE: this sort of DEFINE_VAR shouldn't have to be here. 
only for PTX - return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ - self.arg.value, self.dtype, self.src) - def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple - @functools.cached_property - def key(self) -> bytes: - return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest() - def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") - def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg - # *** uop syntactic sugar - def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x - def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) - def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) - def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) - def __neg__(self): return self.alu(UnaryOps.NEG) - def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) - def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) - def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x)) - def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x)) - def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self) - def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x)) - def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP)) - def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x)) - def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x)) - def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x)) - def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x)) - def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x)) - def eq(self, x): return -self.ne(x) - def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x)) - def ge(self, x): return (-self).lt(-x+1) - def max(self, x): return self.alu(BinaryOps.MAX, x) - def min(self, x): return -(-self).max(-x) - def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y) - def recip(self): return self.alu(UnaryOps.RECIP) - def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b) - def sconst(self:Union[UOp, DType, None], b:ConstType|Variable): - return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b) - @staticmethod - @functools.lru_cache(maxsize=None) - def _const(dtype:Optional[DType], b:ConstType|Variable): - # TODO: fix dtype of b.max after Variable is just an UOp - if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b) - if dtype is not None and dtype != (sdtype := dtype.scalar()): - return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) - return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) - def alu(self, arg, *src:UOp): - return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg) - @staticmethod - def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, 
tuple(src)+tuple(kwargs.values())) - @staticmethod - def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src) - @functools.cached_property - def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src]) - @property # parents with self - def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} - @staticmethod - def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st) - @functools.cached_property - def full_shape(self) -> Tuple[sint, ...]: - if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape - # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape - return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}])) - # TODO: these two should merge - def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) - def variables(self) -> List[Variable]: - st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] - return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr) - def const_factor(self) -> int: - """largest known int that divides self""" - if self.op is UOps.CONST: return self.arg - if self.op is UOps.ALU: - if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor()) - if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1 - return 1 - def divides(self, v) -> Optional[UOp]: - if v==1: return self - if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None - if self.op is UOps.ALU: - if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None - if self.arg is BinaryOps.MUL: - if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] - if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 - return None # generic None if we aren't sure - @functools.cached_property - def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype))) - @functools.cached_property - def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype))) - @functools.cached_property - def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: - # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None - if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None - # TODO: UOps.SPECIAL is UOps.DEFINE_VAR - if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None - if self.op is UOps.CONST: return self, self - if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: - s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] - if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)): - return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg) - if self.arg is BinaryOps.ADD: 
return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg) - if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0): - # handle at lease one is non-negative - Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg) - Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg) - assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}" - return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax) - if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1) - if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: - if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg) - if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg)) - if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg)) - if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \ - (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None) - return None, None - -@dataclass(frozen=True, repr=False) # reuse repr from UOp -class NOp(UOp): - name:Optional[str] = None - src:Tuple[NOp, ...] = tuple() - allow_any_len:bool = False - @staticmethod - def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name) - @staticmethod - def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name) - def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg) - - def compile(self: NOp, name:Optional[str]=None) -> UPat: - return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative() - else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len) - -class UPat: - def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, - name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False): - self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) - self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) - self.arg, self.name = arg, name - self.src: Any = None - # try all permutations if it's a list - if isinstance(src, list): self.src = list(itertools.permutations(src)) - # only one if it's a tuple - elif isinstance(src, tuple): self.src = [src] - # repeat if it's a UPat - elif isinstance(src, UPat): self.src = [itertools.repeat(src)] - - self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src) - - def __repr__(self): - def rep(x): - form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)" - return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name), - set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)") - return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else 
x.src[0]) - -def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]: - if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \ - (pat.dtype is not None and uop.dtype not in pat.dtype) or \ - (pat.arg is not None and pat.arg != uop.arg) or \ - (pat.op is not None and uop.op not in pat.op): return [] - if pat.src is None: return [store] - res: List[Dict[str, UOp]] = [] - for vp in pat.src: - if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return [] - new_stores = [store.copy()] - for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)] - res.extend(new_stores) - return res - -class PatternMatcher: - def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]): - self.patterns = patterns - self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list) - # uop is required, arg is optional - for p,fxn in self.patterns: - if isinstance(p, NOp): p = p.compile() - assert p.op is not None - for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn)) - - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) - - def rewrite(self, uop:UOp) -> Optional[UOp]: - for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): - if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match - return None - -def type_verify(uops): - for u in uops: - uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype - if uop in {UOps.CONST, UOps.DEFINE_ACC}: - if uop is UOps.CONST: - assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}" - assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" - if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}" - if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg - if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1 - if uop is UOps.VECTORIZE: - assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}" - assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}" - if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype - if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}" - if uop is UOps.STORE: - assert dtype is None, f"{uop} dtype must be None, got {dtype}" - if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}" - if uop is UOps.ALU: - if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" - elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}: - assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}" - assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}" - elif arg is BinaryOps.IDIV: - assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}" - 
assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}" - elif arg in {BinaryOps.SHL, BinaryOps.SHR}: - # the distance to shift isn't typechecked - assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" - elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}" - elif arg == TernaryOps.WHERE: - assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \ - f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}" - assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}" - -def uop_alu_resolve(u:UOp) -> sint: - if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg - if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src))) - raise RuntimeError(f"ALU resolve fail @ {u.op}") - -def print_uops(uops:List[UOp]): - for i,u in enumerate(uops): - formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src] - print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}") - -def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: - flops: sint = 0 - mem: sint = 0 - mults: sint = 1 - mult_stack: List[sint] = [] - dont_count: Set[UOp] = set() - if ignore_indexing: - for u in uops: - if u.op is UOps.LOAD: - dont_count = dont_count.union(u.src[1].sparents) - if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents) - elif u.op is UOps.STORE: - dont_count = dont_count.union(u.src[1].sparents) - if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents) - elif u.op is UOps.IF: - dont_count = dont_count.union(u.src[0].sparents) - for u in uops: - if u.op is UOps.RANGE: - mult_stack.append(mults) - mults *= uop_alu_resolve(u.src[1] - u.src[0]) - elif u.op is UOps.ENDRANGE: - mults = mult_stack.pop(-1) - elif u.op is UOps.SPECIAL: - mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these - elif u.op is UOps.LOAD: - assert u.dtype is not None - mem += u.dtype.itemsize * mults - elif u.op is UOps.STORE: - assert u.src[2].dtype is not None - mem += u.src[2].dtype.itemsize * mults - elif u.op is UOps.ALU and u not in dont_count: - assert u.dtype is not None - flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count - elif u.op is UOps.WMMA and u not in dont_count: - assert u.arg[1] is not None - flops += 2 * prod(u.arg[1]) // u.arg[5] * mults - return flops, mem - -# the living definition of UOps.ST_IDX and UOps.ST_VALID -def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]: - assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK" - sts: Dict[UOp, ShapeTracker] = {} - def assert_valid(op:UOp, st:ShapeTracker): - if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return - # restore globals from the two stage reduce - if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL: - assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg) - return sts.setdefault(op, sts[local_reduce]) - for x in op.src: assert_valid(x, st) - # only reduceop is allowed to change shape, limited to turning n to 1 - if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1])) - else: - # movementops are pushed to the edges with ST_IDX, ST_VALID - # elementwise 
inherits shape - st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]] - for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src): - if sts[x].shape != st.shape: - if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}") - raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}") - sts[op] = st - for out in ast.src: assert_valid(out, out.src[-1].arg) - shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])] - assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}" - return sts diff --git a/tinygrad/engine/graph.py b/tinygrad/engine/graph.py index d0aa4bca..9fc94288 100644 --- a/tinygrad/engine/graph.py +++ b/tinygrad/engine/graph.py @@ -1,10 +1,9 @@ import os, atexit, functools, contextlib from collections import defaultdict from typing import List, Any, DefaultDict -from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps +from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp from tinygrad.device import Device from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters -from tinygrad.codegen.uops import UOps, UOp from tinygrad.shape.symbolic import NumNode from tinygrad.lazy import LazyBuffer diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index f82ab508..0a49a2ea 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -3,8 +3,7 @@ import time, pprint from collections import defaultdict from dataclasses import dataclass, replace from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup -from tinygrad.codegen.uops import UOps, UOp -from tinygrad.ops import MetaOps +from tinygrad.ops import MetaOps, UOps, UOp from tinygrad.dtype import dtypes from tinygrad.device import Device, Buffer from tinygrad.shape.symbolic import Variable, sym_infer, sint diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e3df45db..d681238a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,8 +2,7 @@ import sys, pickle, atexit, importlib, contextlib from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args -from tinygrad.codegen.uops import UOp, UOps -from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st +from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \ GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 1aff8a2c..8d0201dc 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable import itertools, functools, random, math, time, multiprocessing, traceback, signal from collections import defaultdict from dataclasses import replace -from tinygrad.codegen.uops import UOp, UOps +from tinygrad.ops import UOp, UOps from tinygrad.device import Device, Buffer, Compiler from 
tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name from tinygrad.dtype import DType, ImageDType diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fccd6f9c..2cbf1d3e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import Union, Tuple, Dict, Callable -import math, operator, ctypes, struct +from collections import defaultdict +from typing import Any, DefaultDict, List, Optional, Set, Union, Tuple, Dict, Callable, cast +import math, operator, ctypes, struct, functools, hashlib, itertools from enum import Enum, auto from dataclasses import dataclass -from tinygrad.dtype import dtypes, DType -from tinygrad.shape.symbolic import sint +from tinygrad.dtype import ConstType, dtypes, DType +from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod +from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.shapetracker import ShapeTracker # these are the llops your accelerator must implement, along with toCpu @@ -73,3 +75,329 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool, def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands)) def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape)) + +# the order of these UOps controls the order of the toposort +class UOps(Enum): + # ops that aren't rendered + SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702 + DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702 + CONST = auto(); SPECIAL = auto() # noqa: E702 + NOOP = auto(); GEP = auto() # noqa: E702 + # math ops + CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702 + ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702 + # memory/assignment ops + LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702 + # control flow ops + BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702 + # these two are not graph nodes + ENDRANGE = auto(); ENDIF = auto() # noqa: E702 + +BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} + +END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)} + +@dataclass(frozen=True, eq=False) +class UOp: + op: UOps + dtype: Optional[DType] = None + src: Tuple[UOp, ...] = tuple() + arg: Any = None + def commutative(self) -> bool: + return (self.op is UOps.ALU and \ + self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}) + @functools.cached_property + def cmp_tuple(self): + # NOTE: this sort of DEFINE_VAR shouldn't have to be here. 
only for PTX + return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ + self.arg.value, self.dtype, self.src) + def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple + @functools.cached_property + def key(self) -> bytes: + return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest() + def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") + def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg + # *** uop syntactic sugar + def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x + def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) + def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) + def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i) + def __neg__(self): return self.alu(UnaryOps.NEG) + def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) + def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x)) + def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x)) + def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x)) + def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self) + def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x)) + def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP)) + def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x)) + def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x)) + def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x)) + def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x)) + def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x)) + def eq(self, x): return -self.ne(x) + def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x)) + def ge(self, x): return (-self).lt(-x+1) + def max(self, x): return self.alu(BinaryOps.MAX, x) + def min(self, x): return -(-self).max(-x) + def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y) + def recip(self): return self.alu(UnaryOps.RECIP) + def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b) + def sconst(self:Union[UOp, DType, None], b:ConstType|Variable): + return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b) + @staticmethod + @functools.lru_cache(maxsize=None) + def _const(dtype:Optional[DType], b:ConstType|Variable): + # TODO: fix dtype of b.max after Variable is just an UOp + if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b) + if dtype is not None and dtype != (sdtype := dtype.scalar()): + return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) + return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) + def alu(self, arg, *src:UOp): + return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg) + @staticmethod + def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, 
tuple(src)+tuple(kwargs.values())) + @staticmethod + def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src) + @functools.cached_property + def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src]) + @property # parents with self + def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} + @staticmethod + def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st) + @functools.cached_property + def full_shape(self) -> Tuple[sint, ...]: + if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape + # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape + return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}])) + def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) + def variables(self) -> List[Variable]: + st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] + return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr) + def const_factor(self) -> int: + """largest known int that divides self""" + if self.op is UOps.CONST: return self.arg + if self.op is UOps.ALU: + if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor()) + if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1 + return 1 + def divides(self, v) -> Optional[UOp]: + if v==1: return self + if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None + if self.op is UOps.ALU: + if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None + if self.arg is BinaryOps.MUL: + if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] + if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 + return None # generic None if we aren't sure + @functools.cached_property + def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype))) + @functools.cached_property + def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype))) + @functools.cached_property + def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: + # NOTE: returned UOp is assumed to be CONST + if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None + if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None + # TODO: UOps.SPECIAL is UOps.DEFINE_VAR + if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None + if self.op is UOps.CONST: return self, self + if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: + s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] + if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)): + return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg) + if self.arg is BinaryOps.ADD: return 
self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg) + if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0): + # handle at lease one is non-negative + Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg) + Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg) + assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}" + return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax) + if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1) + if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: + if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg) + if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg)) + if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg)) + if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \ + (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None) + return None, None + +@dataclass(frozen=True, repr=False) # reuse repr from UOp +class NOp(UOp): + name:Optional[str] = None + src:Tuple[NOp, ...] = tuple() + allow_any_len:bool = False + @staticmethod + def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name) + @staticmethod + def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name) + def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg) + + def compile(self: NOp, name:Optional[str]=None) -> UPat: + return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative() + else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len) + +class UPat: + def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, + name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False): + self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,)) + self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,)) + self.arg, self.name = arg, name + self.src: Any = None + # try all permutations if it's a list + if isinstance(src, list): self.src = list(itertools.permutations(src)) + # only one if it's a tuple + elif isinstance(src, tuple): self.src = [src] + # repeat if it's a UPat + elif isinstance(src, UPat): self.src = [itertools.repeat(src)] + + self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src) + + def __repr__(self): + def rep(x): + form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)" + return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name), + set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)") + return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0]) 
+
+def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
+  if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
+     (pat.dtype is not None and uop.dtype not in pat.dtype) or \
+     (pat.arg is not None and pat.arg != uop.arg) or \
+     (pat.op is not None and uop.op not in pat.op): return []
+  if pat.src is None: return [store]
+  res: List[Dict[str, UOp]] = []
+  for vp in pat.src:
+    if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
+    new_stores = [store.copy()]
+    for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
+    res.extend(new_stores)
+  return res
+
+class PatternMatcher:
+  def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
+    self.patterns = patterns
+    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
+    # uop is required, arg is optional
+    for p,fxn in self.patterns:
+      if isinstance(p, NOp): p = p.compile()
+      assert p.op is not None
+      for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
+
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
+
+  def rewrite(self, uop:UOp) -> Optional[UOp]:
+    for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
+      if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
+    return None
+
+def type_verify(uops):
+  for u in uops:
+    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
+    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
+      if uop is UOps.CONST:
+        assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
+        assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
+    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
+    if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
+    if uop is UOps.VECTORIZE:
+      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
+      assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
+    if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
+    if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
+    if uop is UOps.STORE:
+      assert dtype is None, f"{uop} dtype must be None, got {dtype}"
+      if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
+    if uop is UOps.ALU:
+      if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
+      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
+        assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
+        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
+      elif arg is BinaryOps.IDIV:
+        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
+        assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
+      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
+        # the distance to shift isn't typechecked
+        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
+      elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
+      elif arg == TernaryOps.WHERE:
+        assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
+          f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
+        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
+
+def uop_alu_resolve(u:UOp) -> sint:
+  if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
+  if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
+  raise RuntimeError(f"ALU resolve fail @ {u.op}")
+
+def print_uops(uops:List[UOp]):
+  for i,u in enumerate(uops):
+    formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
+    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
+
+def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
+  flops: sint = 0
+  mem: sint = 0
+  mults: sint = 1
+  mult_stack: List[sint] = []
+  dont_count: Set[UOp] = set()
+  if ignore_indexing:
+    for u in uops:
+      if u.op is UOps.LOAD:
+        dont_count = dont_count.union(u.src[1].sparents)
+        if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
+      elif u.op is UOps.STORE:
+        dont_count = dont_count.union(u.src[1].sparents)
+        if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
+      elif u.op is UOps.IF:
+        dont_count = dont_count.union(u.src[0].sparents)
+  for u in uops:
+    if u.op is UOps.RANGE:
+      mult_stack.append(mults)
+      mults *= uop_alu_resolve(u.src[1] - u.src[0])
+    elif u.op is UOps.ENDRANGE:
+      mults = mult_stack.pop(-1)
+    elif u.op is UOps.SPECIAL:
+      mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+    elif u.op is UOps.LOAD:
+      assert u.dtype is not None
+      mem += u.dtype.itemsize * mults
+    elif u.op is UOps.STORE:
+      assert u.src[2].dtype is not None
+      mem += u.src[2].dtype.itemsize * mults
+    elif u.op is UOps.ALU and u not in dont_count:
+      assert u.dtype is not None
+      flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
+    elif u.op is UOps.WMMA and u not in dont_count:
+      assert u.arg[1] is not None
+      flops += 2 * prod(u.arg[1]) // 32 * mults
+  return flops, mem
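As a usage sketch for the PatternMatcher defined above (illustrative only; the rule below is made up and not part of this diff): a rewrite rule pairs a pattern with a callable that receives the named captures as keyword arguments and returns the replacement UOp, or None to keep searching.

    # hypothetical rule: fold x * 1 down to x
    fold_mul_one = PatternMatcher([
      (UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(name="x"), UPat(UOps.CONST, arg=1)]), lambda x: x),
    ])
    # fold_mul_one.rewrite(u) returns the rewritten UOp, or None if no rule matched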
+
+# the living definition of UOps.ST_IDX and UOps.ST_VALID
+def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
+  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
+  sts: Dict[UOp, ShapeTracker] = {}
+  def assert_valid(op:UOp, st:ShapeTracker):
+    if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
+    # restore globals from the two stage reduce
+    if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
+      assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
+      return sts.setdefault(op, sts[local_reduce])
+    for x in op.src: assert_valid(x, st)
+    # only reduceop is allowed to change shape, limited to turning n to 1
+    if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
+    else:
+      # movementops are pushed to the edges with ST_IDX, ST_VALID
+      # elementwise inherits shape
+      st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
+    for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
+      if sts[x].shape != st.shape:
+        if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
+        raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
+    sts[op] = st
+  for out in ast.src: assert_valid(out, out.src[-1].arg)
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
+  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
+  return sts
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index d622dc01..87c85698 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
 import functools
 from dataclasses import dataclass, field
 from tinygrad.helpers import to_function_name, dedup
-from tinygrad.codegen.uops import UOps, UOp, flops_mem
-from tinygrad.ops import Op
+from tinygrad.ops import Op, UOps, UOp, flops_mem
 from tinygrad.shape.symbolic import sym_infer, sint, Variable
 from tinygrad.dtype import DType
 
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index a0363f80..2ccac7f0 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -1,9 +1,8 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct, math
 from collections import defaultdict
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
 from tinygrad.renderer import Renderer, TensorCore
 
 def render_val(x, dtype):
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 322ec8fe..7a3746be 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -1,10 +1,9 @@
 from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
 import os, math
 from collections import defaultdict, Counter
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp
 from tinygrad.helpers import strip_parens, getenv, prod, dedup
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.renderer import Renderer, TensorCore
 
 class CStyleLanguage(Renderer):
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 4f95fb68..ecc27349 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,8 +1,7 @@
 from typing import Dict, Callable, Any, List, Optional
 from llvmlite import ir
 from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOps, UOp
+from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
 from tinygrad.renderer import Renderer
 
 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index d0598a67..f20c2e0e 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -7,8 +7,7 @@ import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Compiler, Allocator
-from tinygrad.codegen.uops import UOps, UOp
-from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate, UOps, UOp
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer
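The import hunks above all follow the same shape: downstream modules now pull the UOp machinery from tinygrad.ops in a single import instead of reaching into tinygrad.codegen.uops, e.g. (illustrative sketch of the new style):

    from tinygrad.ops import UOp, UOps, UPat, PatternMatcher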