mirror of https://github.com/commaai/tinygrad.git
merge uops with ops (#6111)
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
379d080e74
commit
28c75bf2a6
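This commit folds tinygrad/codegen/uops.py into tinygrad/ops.py: UOp, UOps, NOp, UPat, PatternMatcher and the related helpers (BUFFER_UOPS, END_FOR_UOP, type_verify, print_uops, flops_mem, verify_ast) now live in tinygrad.ops, and every caller updates its imports. A minimal sketch of the import migration repeated throughout the hunks below (symbol names taken from the diff itself):

# before #6111
from tinygrad.ops import BinaryOps, MetaOps
from tinygrad.codegen.uops import UOp, UOps

# after #6111
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps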
@@ -39,8 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import PtrDType, dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker

 # allocate some buffers + load in values

@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin

@@ -18,7 +18,7 @@ from tinygrad.device import Buffer
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner, memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import _to_np_dtype
 Device.DEFAULT = "GPU"

@@ -3,9 +3,8 @@ from typing import Dict, Union, Tuple, Any, List
 import functools, hashlib
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
 from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker

@@ -1,7 +1,7 @@
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor
-from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import Profiling, Timing, getenv
+from tinygrad.ops import UOps
 from tinygrad.codegen.kernel import Kernel

 if __name__ == "__main__":

@@ -7,12 +7,11 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOp, UOps
 from test.helpers import is_dtype_supported

 def tuplize_uops(uops:List[UOp]) -> Tuple:

@@ -3,7 +3,7 @@ from collections import defaultdict
 import numpy as np
 from dataclasses import replace
 from typing import DefaultDict, Dict, List, Tuple
-from tinygrad.codegen.uops import END_FOR_UOP, UOp
+from tinygrad.ops import END_FOR_UOP, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner

@@ -2,7 +2,7 @@ import sys, unittest
 from typing import Optional, Set, Tuple
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOp
+from tinygrad.ops import UOp
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType

@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.helpers import CI
 import numpy as np

@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.engine.schedule import create_schedule

@@ -4,12 +4,11 @@ from tinygrad import Tensor, dtypes, Device
 import operator
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from tinygrad.codegen.uops import UOps
 from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOps
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported

@@ -2,7 +2,7 @@
 import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.lazy import LazyBuffer, MetaOps
 from tinygrad.engine.schedule import create_schedule

@@ -5,7 +5,7 @@ from dataclasses import replace

 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import UOp, UOps
 from tinygrad.device import Device, Buffer
 from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
 from tinygrad.shape.shapetracker import ShapeTracker

@@ -4,9 +4,9 @@

 import unittest
 from tinygrad import Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import getenv
-from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps # noqa: F401 # pylint: disable=unused-import
+from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel

@@ -3,7 +3,7 @@ import unittest, random
 import numpy as np
 from tinygrad.codegen.kernel import KernelOptError
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI

@@ -1,8 +1,7 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.codegen.uops import UOps
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps
+from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps
 from tinygrad.helpers import CI, getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule

@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch
 from tinygrad import Tensor, Device, TinyJit
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
 from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell

@@ -1,8 +1,7 @@
 import unittest, itertools
 from test.helpers import TestUOps
 from tinygrad.dtype import dtypes
-from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
-from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
 from tinygrad.codegen.uopgraph import constant_folder

 class TestPatternMatcher(TestUOps):

@@ -1,13 +1,12 @@
 import unittest
 from typing import List, cast
 import numpy as np
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.device import Buffer, Device
 from tinygrad.dtype import PtrDType, DType, dtypes
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import dedup, flatten
 from tinygrad.renderer.cstyle import CStyleLanguage
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 from tinygrad.renderer import Program
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.lazy import LazyBuffer

@@ -8,10 +8,9 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps, verify_ast
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
 from test.helpers import is_dtype_supported, Context

@@ -2,7 +2,7 @@ import unittest

 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
 from tinygrad.device import Device, Buffer

@@ -3,8 +3,7 @@ from test.helpers import TestUOps
 from tinygrad import dtypes, Variable
 from tinygrad.dtype import PtrDType
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
-from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
+from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
 from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding

 simple_pm = PatternMatcher([

@@ -5,11 +5,10 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Context
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
+from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
-from tinygrad.codegen.uops import UOps, NOp, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from test.helpers import is_dtype_supported, TestUOps as TestEqUOps

@@ -3,9 +3,8 @@ from tinygrad import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule_item
-from tinygrad.codegen.uops import flops_mem, UOps, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError

@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, GlobalCounters
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule

@@ -9,9 +9,8 @@ from typing import Tuple

 from tinygrad.helpers import DEBUG
 from tinygrad.dtype import dtypes, PtrDType, ConstType
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 import functools

 def render(self) -> Tuple[str, ConstType, ConstType]:

@@ -4,8 +4,7 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict

-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps, verify_ast
-from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo
+from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import DType, ImageDType, PtrDType

@@ -5,8 +5,7 @@ import functools
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps
-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps
+from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten

@@ -1,7 +1,7 @@
 import math, functools
 from typing import Tuple, List
 from tinygrad.dtype import dtypes, DType
-from tinygrad.codegen.uops import UOp
+from tinygrad.ops import UOp

 TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}

@@ -3,9 +3,8 @@ from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE
 import functools, itertools, heapq, math, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
-from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
+from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
-from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
 if TYPE_CHECKING: from tinygrad.renderer import Renderer

@@ -1,338 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
|
||||
import functools, itertools, math, hashlib
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import ConstType, dtypes, DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, reduce_st
|
||||
from tinygrad.helpers import merge_dicts, prod, pretty_print, dedup
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
# ops that aren't rendered
|
||||
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
||||
CONST = auto(); SPECIAL = auto() # noqa: E702
|
||||
NOOP = auto(); GEP = auto() # noqa: E702
|
||||
# math ops
|
||||
CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
|
||||
ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702
|
||||
# memory/assignment ops
|
||||
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
|
||||
# control flow ops
|
||||
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
|
||||
# these two are not graph nodes
|
||||
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
||||
|
||||
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
||||
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class UOp:
|
||||
op: UOps
|
||||
dtype: Optional[DType] = None
|
||||
src: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def commutative(self) -> bool:
|
||||
return (self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
@functools.cached_property
|
||||
def cmp_tuple(self):
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
@functools.cached_property
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
# *** uop syntactic sugar
|
||||
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
||||
def __neg__(self): return self.alu(UnaryOps.NEG)
|
||||
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
||||
def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
||||
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
|
||||
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
|
||||
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
|
||||
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
|
||||
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
|
||||
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
|
||||
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
|
||||
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
|
||||
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
|
||||
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
|
||||
def eq(self, x): return -self.ne(x)
|
||||
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
|
||||
def ge(self, x): return (-self).lt(-x+1)
|
||||
def max(self, x): return self.alu(BinaryOps.MAX, x)
|
||||
def min(self, x): return -(-self).max(-x)
|
||||
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
||||
def recip(self): return self.alu(UnaryOps.RECIP)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
|
||||
def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
|
||||
return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _const(dtype:Optional[DType], b:ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
def alu(self, arg, *src:UOp):
|
||||
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
|
||||
@staticmethod
|
||||
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
@staticmethod
|
||||
def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
|
||||
# TODO: these two should merge
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
||||
return 1
|
||||
def divides(self, v) -> Optional[UOp]:
|
||||
if v==1: return self
|
||||
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.arg is BinaryOps.MUL:
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@functools.cached_property
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
|
||||
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
|
||||
return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
|
||||
if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
|
||||
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
|
||||
# handle when at least one is non-negative
|
||||
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
|
||||
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
||||
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
||||
return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
|
||||
if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
|
||||
(UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
|
||||
return None, None
|
||||
|
||||
@dataclass(frozen=True, repr=False) # reuse repr from UOp
|
||||
class NOp(UOp):
|
||||
name:Optional[str] = None
|
||||
src:Tuple[NOp, ...] = tuple()
|
||||
allow_any_len:bool = False
|
||||
@staticmethod
|
||||
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
|
||||
|
||||
def compile(self: NOp, name:Optional[str]=None) -> UPat:
|
||||
return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
|
||||
else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
|
||||
|
||||
class UPat:
|
||||
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
|
||||
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
|
||||
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
|
||||
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
|
||||
self.arg, self.name = arg, name
|
||||
self.src: Any = None
|
||||
# try all permutations if it's a list
|
||||
if isinstance(src, list): self.src = list(itertools.permutations(src))
|
||||
# only one if it's a tuple
|
||||
elif isinstance(src, tuple): self.src = [src]
|
||||
# repeat if it's a UPat
|
||||
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
|
||||
|
||||
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
|
||||
|
||||
def __repr__(self):
|
||||
def rep(x):
|
||||
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
|
||||
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
|
||||
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
|
||||
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
|
||||
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
|
||||
(pat.arg is not None and pat.arg != uop.arg) or \
|
||||
(pat.op is not None and uop.op not in pat.op): return []
|
||||
if pat.src is None: return [store]
|
||||
res: List[Dict[str, UOp]] = []
|
||||
for vp in pat.src:
|
||||
if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
|
||||
new_stores = [store.copy()]
|
||||
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
|
||||
res.extend(new_stores)
|
||||
return res
|
||||
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
if isinstance(p, NOp): p = p.compile()
|
||||
assert p.op is not None
|
||||
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
||||
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
return None
|
||||
|
||||
def type_verify(uops):
|
||||
for u in uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
||||
if uop is UOps.CONST:
|
||||
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
|
||||
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
||||
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
|
||||
if uop is UOps.VECTORIZE:
|
||||
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
|
||||
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
|
||||
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
||||
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
|
||||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
|
||||
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
|
||||
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
||||
|
||||
def print_uops(uops:List[UOp]):
|
||||
for i,u in enumerate(uops):
|
||||
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
|
||||
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
|
||||
|
||||
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
flops: sint = 0
|
||||
mem: sint = 0
|
||||
mults: sint = 1
|
||||
mult_stack: List[sint] = []
|
||||
dont_count: Set[UOp] = set()
|
||||
if ignore_indexing:
|
||||
for u in uops:
|
||||
if u.op is UOps.LOAD:
|
||||
dont_count = dont_count.union(u.src[1].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
|
||||
elif u.op is UOps.STORE:
|
||||
dont_count = dont_count.union(u.src[1].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
||||
elif u.op is UOps.IF:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
for u in uops:
|
||||
if u.op is UOps.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= uop_alu_resolve(u.src[1] - u.src[0])
|
||||
elif u.op is UOps.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.op is UOps.SPECIAL:
|
||||
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is UOps.LOAD:
|
||||
assert u.dtype is not None
|
||||
mem += u.dtype.itemsize * mults
|
||||
elif u.op is UOps.STORE:
|
||||
assert u.src[2].dtype is not None
|
||||
mem += u.src[2].dtype.itemsize * mults
|
||||
elif u.op is UOps.ALU and u not in dont_count:
|
||||
assert u.dtype is not None
|
||||
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is UOps.WMMA and u not in dont_count:
|
||||
assert u.arg[1] is not None
|
||||
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return flops, mem
|
||||
|
||||
# the living definition of UOps.ST_IDX and UOps.ST_VALID
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
|
||||
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
|
||||
sts: Dict[UOp, ShapeTracker] = {}
|
||||
def assert_valid(op:UOp, st:ShapeTracker):
|
||||
if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
|
||||
# restore globals from the two stage reduce
|
||||
if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
|
||||
assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
|
||||
return sts.setdefault(op, sts[local_reduce])
|
||||
for x in op.src: assert_valid(x, st)
|
||||
# only reduceop is allowed to change shape, limited to turning n to 1
|
||||
if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
|
||||
else:
|
||||
# movementops are pushed to the edges with ST_IDX, ST_VALID
|
||||
# elementwise inherits shape
|
||||
st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
|
||||
for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
|
||||
if sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
|
||||
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
|
||||
sts[op] = st
|
||||
for out in ast.src: assert_valid(out, out.src[-1].arg)
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
|
||||
return sts
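
The UOp dataclass moves verbatim, so its operator sugar and the vmin/vmax bounds analysis behave the same under the new import path. A minimal sketch, assuming only the post-merge tinygrad.ops module shown in this diff:

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

x = UOp.const(dtypes.int, 4)   # UOps.CONST node with arg=4
y = x * 2 + 1                  # __mul__/__add__ build an ALU tree over CONST nodes
print(y.vmin.arg, y.vmax.arg)  # both 9: the bounds fold through constant operands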
@@ -1,10 +1,9 @@
 import os, atexit, functools, contextlib
 from collections import defaultdict
 from typing import List, Any, DefaultDict
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
 from tinygrad.device import Device
 from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
-from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.shape.symbolic import NumNode
 from tinygrad.lazy import LazyBuffer

@@ -3,8 +3,7 @@ import time, pprint
 from collections import defaultdict
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
-from tinygrad.codegen.uops import UOps, UOp
-from tinygrad.ops import MetaOps
+from tinygrad.ops import MetaOps, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint

@@ -2,8 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
-from tinygrad.codegen.uops import UOp, UOps
-from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st
+from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
   GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata

@@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import UOp, UOps
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import DType, ImageDType

336 tinygrad/ops.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
-from typing import Union, Tuple, Dict, Callable
-import math, operator, ctypes, struct
+from collections import defaultdict
+from typing import Any, DefaultDict, List, Optional, Set, Union, Tuple, Dict, Callable, cast
+import math, operator, ctypes, struct, functools, hashlib, itertools
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, DType
-from tinygrad.shape.symbolic import sint
+from tinygrad.dtype import ConstType, dtypes, DType
+from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod
+from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker

 # these are the llops your accelerator must implement, along with toCpu

@@ -73,3 +75,329 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
|||
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|
||||
|
||||
def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
|
||||
|
||||
# the order of these UOps controls the order of the toposort
|
||||
class UOps(Enum):
|
||||
# ops that aren't rendered
|
||||
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
||||
CONST = auto(); SPECIAL = auto() # noqa: E702
|
||||
NOOP = auto(); GEP = auto() # noqa: E702
|
||||
# math ops
|
||||
CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
|
||||
ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702
|
||||
# memory/assignment ops
|
||||
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
|
||||
# control flow ops
|
||||
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
|
||||
# these two are not graph nodes
|
||||
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
||||
|
||||
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
||||
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class UOp:
|
||||
op: UOps
|
||||
dtype: Optional[DType] = None
|
||||
src: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def commutative(self) -> bool:
|
||||
return (self.op is UOps.ALU and \
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
@functools.cached_property
|
||||
def cmp_tuple(self):
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
@functools.cached_property
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
# *** uop syntactic sugar
|
||||
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
||||
def __neg__(self): return self.alu(UnaryOps.NEG)
|
||||
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
||||
def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
||||
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
|
||||
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
|
||||
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
|
||||
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
|
||||
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
|
||||
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
|
||||
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
|
||||
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
|
||||
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
|
||||
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
|
||||
def eq(self, x): return -self.ne(x)
|
||||
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
|
||||
def ge(self, x): return (-self).lt(-x+1)
|
||||
def max(self, x): return self.alu(BinaryOps.MAX, x)
|
||||
def min(self, x): return -(-self).max(-x)
|
||||
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
||||
def recip(self): return self.alu(UnaryOps.RECIP)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
|
||||
def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
|
||||
return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _const(dtype:Optional[DType], b:ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
def alu(self, arg, *src:UOp):
|
||||
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
|
||||
@staticmethod
|
||||
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
|
||||
@staticmethod
|
||||
def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
@staticmethod
|
||||
def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
||||
return 1
|
||||
def divides(self, v) -> Optional[UOp]:
|
||||
if v==1: return self
|
||||
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.arg is BinaryOps.MUL:
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@functools.cached_property
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
|
||||
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
|
||||
return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
|
||||
if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
|
||||
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
|
||||
# handle when at least one is non-negative
|
||||
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
|
||||
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
||||
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
||||
return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
|
||||
if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
|
||||
(UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
|
||||
return None, None
|
||||
|
||||
@dataclass(frozen=True, repr=False) # reuse repr from UOp
|
||||
class NOp(UOp):
|
||||
name:Optional[str] = None
|
||||
src:Tuple[NOp, ...] = tuple()
|
||||
allow_any_len:bool = False
|
||||
@staticmethod
|
||||
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
|
||||
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
|
||||
|
||||
def compile(self: NOp, name:Optional[str]=None) -> UPat:
|
||||
return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
|
||||
else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
|
||||
|
||||
class UPat:
|
||||
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
|
||||
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
|
||||
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
|
||||
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
|
||||
self.arg, self.name = arg, name
|
||||
self.src: Any = None
|
||||
# try all permutations if it's a list
|
||||
if isinstance(src, list): self.src = list(itertools.permutations(src))
|
||||
# only one if it's a tuple
|
||||
elif isinstance(src, tuple): self.src = [src]
|
||||
# repeat if it's a UPat
|
||||
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
|
||||
|
||||
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
|
||||
|
||||
def __repr__(self):
|
||||
def rep(x):
|
||||
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
|
||||
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
|
||||
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
|
||||
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
|
||||
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
|
||||
(pat.arg is not None and pat.arg != uop.arg) or \
|
||||
(pat.op is not None and uop.op not in pat.op): return []
|
||||
if pat.src is None: return [store]
|
||||
res: List[Dict[str, UOp]] = []
|
||||
for vp in pat.src:
|
||||
if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
|
||||
new_stores = [store.copy()]
|
||||
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
|
||||
res.extend(new_stores)
|
||||
return res
|
||||
|
||||
class PatternMatcher:
  def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
    self.patterns = patterns
    self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
    # uop is required, arg is optional
    for p,fxn in self.patterns:
      if isinstance(p, NOp): p = p.compile()
      assert p.op is not None
      for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)

  def rewrite(self, uop:UOp) -> Optional[UOp]:
    for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
      if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
    return None

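A hedged illustration of how UPat and PatternMatcher compose (not part of this commit; simplify_mul_one and some_uop are hypothetical names): a list src tries both operand orders, name="x" captures the matched UOp, and rewrite returns the callback's result or None.

simplify_mul_one = PatternMatcher([
  # fold x * 1 -> x; the list src matches the MUL operands in either order
  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(name="x"), UPat(UOps.CONST, arg=1)]), lambda x: x),
])
rewritten = simplify_mul_one.rewrite(some_uop)  # the captured "x" UOp if some_uop is x*1, else None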
def type_verify(uops):
  for u in uops:
    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
      if uop is UOps.CONST:
        assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
        assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
      if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
    if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
    if uop is UOps.VECTORIZE:
      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
      assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
    if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
    if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
    if uop is UOps.STORE:
      assert dtype is None, f"{uop} dtype must be None, got {dtype}"
      if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
    if uop is UOps.ALU:
      if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
        assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg is BinaryOps.IDIV:
        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
        assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
        # the distance to shift isn't typechecked
        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg == TernaryOps.WHERE:
        assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
          f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"

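To make the checks concrete (an illustrative case, not from this diff): a UOps.CONST whose dtype is dtypes.float32.vec(4) fails the first assert above, since dtype.scalar() is dtypes.float32, while a UOps.VECTORIZE with dtype dtypes.float32.vec(4) and four scalar float32 sources passes the vectorization asserts because dtype.count equals len(src) and each source's dtype.vec(4) equals the output dtype.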
def uop_alu_resolve(u:UOp) -> sint:
  if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
  if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
  raise RuntimeError(f"ALU resolve fail @ {u.op}")

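A hedged usage note (rng is a hypothetical name for a UOps.RANGE uop): flops_mem below uses exactly this helper to turn a loop's bounds into a trip count.

trip_count = uop_alu_resolve(rng.src[1] - rng.src[0])  # e.g. 16 when the bounds are CONSTs 0 and 16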
def print_uops(uops:List[UOp]):
  for i,u in enumerate(uops):
    formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")

def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
  flops: sint = 0
  mem: sint = 0
  mults: sint = 1
  mult_stack: List[sint] = []
  dont_count: Set[UOp] = set()
  if ignore_indexing:
    for u in uops:
      if u.op is UOps.LOAD:
        dont_count = dont_count.union(u.src[1].sparents)
        if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
      elif u.op is UOps.STORE:
        dont_count = dont_count.union(u.src[1].sparents)
        if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
      elif u.op is UOps.IF:
        dont_count = dont_count.union(u.src[0].sparents)
  for u in uops:
    if u.op is UOps.RANGE:
      mult_stack.append(mults)
      mults *= uop_alu_resolve(u.src[1] - u.src[0])
    elif u.op is UOps.ENDRANGE:
      mults = mult_stack.pop(-1)
    elif u.op is UOps.SPECIAL:
      mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
    elif u.op is UOps.LOAD:
      assert u.dtype is not None
      mem += u.dtype.itemsize * mults
    elif u.op is UOps.STORE:
      assert u.src[2].dtype is not None
      mem += u.src[2].dtype.itemsize * mults
    elif u.op is UOps.ALU and u not in dont_count:
      assert u.dtype is not None
      flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
    elif u.op is UOps.WMMA and u not in dont_count:
      assert u.arg[1] is not None
      flops += 2 * prod(u.arg[1]) // 32 * mults
  return flops, mem

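A worked example of the accounting above (illustrative numbers, not from this diff): inside a RANGE whose CONST bounds are 0 and 16, mults is 16, so a single float32 ALU contributes 16 * 1 * 1 = 16 flops and a float32 LOAD contributes 16 * 4 = 64 bytes of mem; with ignore_indexing=True, any ALU uops feeding the load's index (src[1]) are excluded from the flop count.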
# the living definition of UOps.ST_IDX and UOps.ST_VALID
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
  sts: Dict[UOp, ShapeTracker] = {}
  def assert_valid(op:UOp, st:ShapeTracker):
    if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
    # restore globals from the two stage reduce
    if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
      assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
      return sts.setdefault(op, sts[local_reduce])
    for x in op.src: assert_valid(x, st)
    # only reduceop is allowed to change shape, limited to turning n to 1
    if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
    else:
      # movementops are pushed to the edges with ST_IDX, ST_VALID
      # elementwise inherits shape
      st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
    for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
      if sts[x].shape != st.shape:
        if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
        raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
    sts[op] = st
  for out in ast.src: assert_valid(out, out.src[-1].arg)
  shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
  assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
  return sts

@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup
from tinygrad.codegen.uops import UOps, UOp, flops_mem
from tinygrad.ops import Op
from tinygrad.ops import Op, UOps, UOp, flops_mem
from tinygrad.shape.symbolic import sym_infer, sint, Variable
from tinygrad.dtype import DType

@ -1,9 +1,8 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct, math
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
from tinygrad.renderer import Renderer, TensorCore

def render_val(x, dtype):

@ -1,10 +1,9 @@
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.helpers import strip_parens, getenv, prod, dedup
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.renderer import Renderer, TensorCore

class CStyleLanguage(Renderer):

@ -1,8 +1,7 @@
from typing import Dict, Callable, Any, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.renderer import Renderer

MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf

@ -7,8 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate, UOps, UOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer