merge uops with ops (#6111)

Co-authored-by: chenyu <chenyu@fastmail.com>
qazal 2024-08-17 06:17:57 +08:00 committed by GitHub
parent 379d080e74
commit 28c75bf2a6
41 changed files with 372 additions and 405 deletions
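The net effect across those files: UOp, UOps, NOp, UPat, PatternMatcher and the rest of the uop machinery move out of tinygrad.codegen.uops into tinygrad.ops, so callers only change their imports. A minimal sketch of the migration, assuming a tinygrad checkout at this revision:

# before this commit
# from tinygrad.ops import BinaryOps, MetaOps
# from tinygrad.codegen.uops import UOp, UOps

# after this commit (tinygrad/codegen/uops.py no longer exists)
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps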


@@ -39,8 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import PtrDType, dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker
 # allocate some buffers + load in values


@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin


@@ -18,7 +18,7 @@ from tinygrad.device import Buffer
 from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner, memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import _to_np_dtype
 Device.DEFAULT = "GPU"


@@ -3,9 +3,8 @@ from typing import Dict, Union, Tuple, Any, List
 import functools, hashlib
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, reduce_st, UOp, UOps
 from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker


@@ -1,7 +1,7 @@
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor
-from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import Profiling, Timing, getenv
+from tinygrad.ops import UOps
 from tinygrad.codegen.kernel import Kernel
 if __name__ == "__main__":


@@ -7,12 +7,11 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOp, UOps
 from test.helpers import is_dtype_supported
 def tuplize_uops(uops:List[UOp]) -> Tuple:


@@ -3,7 +3,7 @@ from collections import defaultdict
 import numpy as np
 from dataclasses import replace
 from typing import DefaultDict, Dict, List, Tuple
-from tinygrad.codegen.uops import END_FOR_UOP, UOp
+from tinygrad.ops import END_FOR_UOP, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner


@@ -2,7 +2,7 @@ import sys, unittest
 from typing import Optional, Set, Tuple
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOp
+from tinygrad.ops import UOp
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
 from tinygrad.dtype import DType


@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.helpers import CI
 import numpy as np


@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.engine.schedule import create_schedule


@@ -4,12 +4,11 @@ from tinygrad import Tensor, dtypes, Device
 import operator
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from tinygrad.codegen.uops import UOps
 from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps
+from tinygrad.ops import UnaryOps, UOps
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported


@@ -2,7 +2,7 @@
 import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.lazy import LazyBuffer, MetaOps
 from tinygrad.engine.schedule import create_schedule


@@ -5,7 +5,7 @@ from dataclasses import replace
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import UOp, UOps
 from tinygrad.device import Device, Buffer
 from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
 from tinygrad.shape.shapetracker import ShapeTracker


@@ -4,9 +4,9 @@
 import unittest
 from tinygrad import Device, dtypes
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import getenv
-from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps # noqa: F401 # pylint: disable=unused-import
+from extra.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel


@@ -3,7 +3,7 @@ import unittest, random
 import numpy as np
 from tinygrad.codegen.kernel import KernelOptError
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI


@@ -1,8 +1,7 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
-from tinygrad.codegen.uops import UOps
-from tinygrad.ops import MetaOps, ReduceOps, BinaryOps
+from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps
 from tinygrad.helpers import CI, getenv, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule


@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch
 from tinygrad import Tensor, Device, TinyJit
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
 from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell


@@ -1,8 +1,7 @@
 import unittest, itertools
 from test.helpers import TestUOps
 from tinygrad.dtype import dtypes
-from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
-from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
 from tinygrad.codegen.uopgraph import constant_folder
 class TestPatternMatcher(TestUOps):


@@ -1,13 +1,12 @@
 import unittest
 from typing import List, cast
 import numpy as np
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.device import Buffer, Device
 from tinygrad.dtype import PtrDType, DType, dtypes
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import dedup, flatten
 from tinygrad.renderer.cstyle import CStyleLanguage
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 from tinygrad.renderer import Program
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.lazy import LazyBuffer


@@ -8,10 +8,9 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, UnaryOps, UOps, verify_ast
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps, verify_ast
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
 from test.helpers import is_dtype_supported, Context


@@ -2,7 +2,7 @@ import unittest
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
 from tinygrad.device import Device, Buffer


@@ -3,8 +3,7 @@ from test.helpers import TestUOps
 from tinygrad import dtypes, Variable
 from tinygrad.dtype import PtrDType
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
-from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
+from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
 from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
 simple_pm = PatternMatcher([


@@ -5,11 +5,10 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Context
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
+from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu # noqa F401
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
-from tinygrad.codegen.uops import UOps, NOp, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
 from test.helpers import is_dtype_supported, TestUOps as TestEqUOps


@@ -3,9 +3,8 @@ from tinygrad import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule_item
-from tinygrad.codegen.uops import flops_mem, UOps, UOp
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError


@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, GlobalCounters
-from tinygrad.codegen.uops import UOps
+from tinygrad.ops import UOps
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule


@@ -9,9 +9,8 @@ from typing import Tuple
 from tinygrad.helpers import DEBUG
 from tinygrad.dtype import dtypes, PtrDType, ConstType
-from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.codegen.uopgraph import UOpGraph
-from tinygrad.ops import BinaryOps
+from tinygrad.ops import BinaryOps, UOp, UOps
 import functools
 def render(self) -> Tuple[str, ConstType, ConstType]:


@@ -4,8 +4,7 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps, verify_ast
-from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo
+from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, verify_ast
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import DType, ImageDType, PtrDType


@@ -5,8 +5,7 @@ import functools
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps
-from tinygrad.codegen.uops import BUFFER_UOPS, UOp, UOps
+from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten


@@ -1,7 +1,7 @@
 import math, functools
 from typing import Tuple, List
 from tinygrad.dtype import dtypes, DType
-from tinygrad.codegen.uops import UOp
+from tinygrad.ops import UOp
 TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}


@@ -3,9 +3,8 @@ from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE
 import functools, itertools, heapq, math, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
-from tinygrad.ops import UnaryOps, BinaryOps, exec_alu
+from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same, partition
-from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify, print_uops
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
 if TYPE_CHECKING: from tinygrad.renderer import Renderer


@@ -1,338 +0,0 @@
from __future__ import annotations
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
import functools, itertools, math, hashlib
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.dtype import ConstType, dtypes, DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, reduce_st
from tinygrad.helpers import merge_dicts, prod, pretty_print, dedup
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); GEP = auto() # noqa: E702
# math ops
CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702
# memory/assignment ops
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
# control flow ops
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
# these two are not graph nodes
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
@dataclass(frozen=True, eq=False)
class UOp:
op: UOps
dtype: Optional[DType] = None
src: Tuple[UOp, ...] = tuple()
arg: Any = None
def commutative(self) -> bool:
return (self.op is UOps.ALU and \
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# *** uop syntactic sugar
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
def __neg__(self): return self.alu(UnaryOps.NEG)
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def eq(self, x): return -self.ne(x)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def ge(self, x): return (-self).lt(-x+1)
def max(self, x): return self.alu(BinaryOps.MAX, x)
def min(self, x): return -(-self).max(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
def recip(self): return self.alu(UnaryOps.RECIP)
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
@staticmethod
@functools.lru_cache(maxsize=None)
def _const(dtype:Optional[DType], b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
def alu(self, arg, *src:UOp):
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
@staticmethod
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
@staticmethod
def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@staticmethod
def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
# TODO: these two should merge
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.arg is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@functools.cached_property
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
@functools.cached_property
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
@functools.cached_property
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
# handle at least one is non-negative
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
(UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
return None, None
@dataclass(frozen=True, repr=False) # reuse repr from UOp
class NOp(UOp):
name:Optional[str] = None
src:Tuple[NOp, ...] = tuple()
allow_any_len:bool = False
@staticmethod
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
@staticmethod
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
def compile(self: NOp, name:Optional[str]=None) -> UPat:
return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
self.arg, self.name = arg, name
self.src: Any = None
# try all permutations if it's a list
if isinstance(src, list): self.src = list(itertools.permutations(src))
# only one if it's a tuple
elif isinstance(src, tuple): self.src = [src]
# repeat if it's a UPat
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
def __repr__(self):
def rep(x):
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
(pat.arg is not None and pat.arg != uop.arg) or \
(pat.op is not None and uop.op not in pat.op): return []
if pat.src is None: return [store]
res: List[Dict[str, UOp]] = []
for vp in pat.src:
if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
new_stores = [store.copy()]
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
res.extend(new_stores)
return res
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
# uop is required, arg is optional
for p,fxn in self.patterns:
if isinstance(p, NOp): p = p.compile()
assert p.op is not None
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
def type_verify(uops):
for u in uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.CONST:
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
if uop is UOps.VECTORIZE:
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if uop is UOps.ALU:
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
def uop_alu_resolve(u:UOp) -> sint:
if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")
def print_uops(uops:List[UOp]):
for i,u in enumerate(uops):
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
flops: sint = 0
mem: sint = 0
mults: sint = 1
mult_stack: List[sint] = []
dont_count: Set[UOp] = set()
if ignore_indexing:
for u in uops:
if u.op is UOps.LOAD:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
elif u.op is UOps.IF:
dont_count = dont_count.union(u.src[0].sparents)
for u in uops:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.src[1] - u.src[0])
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.SPECIAL:
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
elif u.op is UOps.STORE:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
assert u.dtype is not None
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem
# the living definition of UOps.ST_IDX and UOps.ST_VALID
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
sts: Dict[UOp, ShapeTracker] = {}
def assert_valid(op:UOp, st:ShapeTracker):
if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
# restore globals from the two stage reduce
if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
return sts.setdefault(op, sts[local_reduce])
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
else:
# movementops are pushed to the edges with ST_IDX, ST_VALID
# elementwise inherits shape
st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[op] = st
for out in ast.src: assert_valid(out, out.src[-1].arg)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
return sts
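Everything this deleted tinygrad/codegen/uops.py defined (UOps, UOp, NOp, UPat, PatternMatcher, BUFFER_UOPS, END_FOR_UOP, type_verify, print_uops, flops_mem, verify_ast) now lives in tinygrad.ops, as the hunks below show. A small usage sketch of the relocated UOp helpers, assuming a checkout at this commit; it only uses constructors and operators defined in the file above:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps

a = UOp.const(dtypes.int, 2)   # a UOps.CONST node with arg=2
b = UOp.const(dtypes.int, 3)   # a UOps.CONST node with arg=3
expr = a + b                   # __add__ builds a UOps.ALU node with BinaryOps.ADD
assert expr.op is UOps.ALU and expr.arg is BinaryOps.ADD
assert expr.vmin.arg == 5 and expr.vmax.arg == 5  # _min_max folds ADD of two consts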


@@ -1,10 +1,9 @@
 import os, atexit, functools, contextlib
 from collections import defaultdict
 from typing import List, Any, DefaultDict
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
 from tinygrad.device import Device
 from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
-from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.shape.symbolic import NumNode
 from tinygrad.lazy import LazyBuffer


@@ -3,8 +3,7 @@ import time, pprint
 from collections import defaultdict
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
-from tinygrad.codegen.uops import UOps, UOp
-from tinygrad.ops import MetaOps
+from tinygrad.ops import MetaOps, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint


@@ -2,8 +2,7 @@ import sys, pickle, atexit, importlib, contextlib
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, get_args
-from tinygrad.codegen.uops import UOp, UOps
-from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st
+from tinygrad.ops import MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, reduce_st, UOp, UOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
 GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata


@@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.codegen.uops import UOp, UOps
+from tinygrad.ops import UOp, UOps
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import DType, ImageDType


@@ -1,10 +1,12 @@
 from __future__ import annotations
-from typing import Union, Tuple, Dict, Callable
-import math, operator, ctypes, struct
+from collections import defaultdict
+from typing import Any, DefaultDict, List, Optional, Set, Union, Tuple, Dict, Callable, cast
+import math, operator, ctypes, struct, functools, hashlib, itertools
 from enum import Enum, auto
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, DType
-from tinygrad.shape.symbolic import sint
+from tinygrad.dtype import ConstType, dtypes, DType
+from tinygrad.helpers import dedup, merge_dicts, pretty_print, prod
+from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 # these are the llops your accelerator must implement, along with toCpu
@@ -73,3 +75,329 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
 def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
 def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); GEP = auto() # noqa: E702
# math ops
CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
ALU = auto(); REDUCE = auto(); REDUCE_AXIS = auto(); WMMA = auto() # noqa: E702
# memory/assignment ops
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
# control flow ops
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
# these two are not graph nodes
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
@dataclass(frozen=True, eq=False)
class UOp:
op: UOps
dtype: Optional[DType] = None
src: Tuple[UOp, ...] = tuple()
arg: Any = None
def commutative(self) -> bool:
return (self.op is UOps.ALU and \
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.dtype, self.arg)).encode())).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# *** uop syntactic sugar
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
def __neg__(self): return self.alu(UnaryOps.NEG)
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
def eq(self, x): return -self.ne(x)
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
def ge(self, x): return (-self).lt(-x+1)
def max(self, x): return self.alu(BinaryOps.MAX, x)
def min(self, x): return -(-self).max(-x)
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
def recip(self): return self.alu(UnaryOps.RECIP)
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
@staticmethod
@functools.lru_cache(maxsize=None)
def _const(dtype:Optional[DType], b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
def alu(self, arg, *src:UOp):
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
@staticmethod
def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
@staticmethod
def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@staticmethod
def from_st(st:ShapeTracker) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), st), UOp(UOps.ST_VALID, dtypes.bool, (), st)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.arg is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@functools.cached_property
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
@functools.cached_property
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
@functools.cached_property
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
# handle at least one is non-negative
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
(UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
return None, None
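# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# Rough illustration of the symbolic helpers above: const_factor/divides operate on index
# expressions, vmin/vmax bound them. It assumes the arithmetic dunders defined earlier on
# UOp; the RANGE arg (0, False) is made up here and only the src bounds matter.
_rng = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 8)), (0, False))
_expr = _rng*4 + 8
assert _expr.const_factor() == 4                    # gcd of the factors of both addends
assert _expr.divides(4) is not None                 # the whole expression divided through by 4, as a new UOp
assert (_expr.vmin.arg, _expr.vmax.arg) == (8, 36)  # _rng covers [0, 7], so 4*0+8 .. 4*7+8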
@dataclass(frozen=True, repr=False) # reuse repr from UOp
class NOp(UOp):
name:Optional[str] = None
src:Tuple[NOp, ...] = tuple()
allow_any_len:bool = False
@staticmethod
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
@staticmethod
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
def compile(self: NOp, name:Optional[str]=None) -> UPat:
return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
self.arg, self.name = arg, name
self.src: Any = None
# a list of srcs tries all permutations
if isinstance(src, list): self.src = list(itertools.permutations(src))
# a tuple matches only the given ordering
elif isinstance(src, tuple): self.src = [src]
# a single UPat is repeated to match any number of srcs
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
def __repr__(self):
def rep(x):
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
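# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# A list src tries every ordering of the children, a tuple pins the order, and a bare UPat
# src matches any number of children. The names below are invented for illustration.
_commutative_add = UPat(UOps.ALU, BinaryOps.ADD, [UPat(UOps.CONST, name="c"), UPat(name="x")])
_any_sink = UPat(UOps.SINK, src=UPat(UOps.STORE))   # a SINK of any number of STOREs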
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
(pat.dtype is not None and uop.dtype not in pat.dtype) or \
(pat.arg is not None and pat.arg != uop.arg) or \
(pat.op is not None and uop.op not in pat.op): return []
if pat.src is None: return [store]
res: List[Dict[str, UOp]] = []
for vp in pat.src:
if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
new_stores = [store.copy()]
for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
res.extend(new_stores)
return res
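# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# _match returns one store dict per way the pattern fits; an empty list means no match.
_c = UOp.const(dtypes.float32, 2.0)
_res = _match(_c, UPat(UOps.CONST, name="c", dtype=dtypes.float32), {})
assert len(_res) == 1 and _res[0]["c"] is _c
assert _match(_c, UPat(UOps.ALU), {}) == []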
class PatternMatcher:
def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
# uop is required, arg is optional
for p,fxn in self.patterns:
if isinstance(p, NOp): p = p.compile()
assert p.op is not None
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
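# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# A one-rule matcher that folds x*1 -> x. The NOp template compiles to a UPat; MUL is
# commutative, so both child orderings are tried. Names here are invented for illustration.
_fold_mul1 = PatternMatcher([
  (NOp(UOps.ALU, src=(NOp.var("x"), NOp.cvar("c")), arg=BinaryOps.MUL),
   lambda x,c: x if c.arg == 1 else None),
])
_x = UOp.const(dtypes.int, 7)   # stands in for an arbitrary uop
assert _fold_mul1.rewrite(_x * UOp.const(dtypes.int, 1)) is _x
assert _fold_mul1.rewrite(_x * UOp.const(dtypes.int, 2)) is None   # fxn returned None, no rewrite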
def type_verify(uops):
for u in uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.CONST:
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
if uop is UOps.VECTORIZE:
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if uop is UOps.ALU:
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
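# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# type_verify walks a linearized uop list and asserts basic dtype invariants, e.g. a CONST's
# python value must round-trip through dtypes.as_const for its dtype.
type_verify([UOp.const(dtypes.float32, 1.0)])            # passes silently
try: type_verify([UOp(UOps.CONST, dtypes.int32, (), 1.5)])
except AssertionError: pass                               # 1.5 is not a valid int32 const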
def uop_alu_resolve(u:UOp) -> sint:
if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")
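# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# uop_alu_resolve folds a tree of CONST/DEFINE_VAR/ALU uops back into a python value.
assert uop_alu_resolve(UOp.const(dtypes.int, 3) * UOp.const(dtypes.int, 4)) == 12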
def print_uops(uops:List[UOp]):
for i,u in enumerate(uops):
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
flops: sint = 0
mem: sint = 0
mults: sint = 1
mult_stack: List[sint] = []
dont_count: Set[UOp] = set()
if ignore_indexing:
for u in uops:
if u.op is UOps.LOAD:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
elif u.op is UOps.IF:
dont_count = dont_count.union(u.src[0].sparents)
for u in uops:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.src[1] - u.src[0])
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.SPECIAL:
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
elif u.op is UOps.STORE:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
assert u.dtype is not None
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
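# --- hypothetical usage sketch (not part of this commit) ---------------------------------
# flops_mem scales each op by the trip counts of the loops enclosing it: one ADD inside a
# 16-iteration RANGE counts as 16 flops and no memory traffic. The RANGE arg is made up and
# unused here; only the src bounds feed the trip count.
_r = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 16)), (0, False))
_alu = _r + _r
_end = UOp(UOps.ENDRANGE, None, (_r,))
assert flops_mem([_r, _alu, _end]) == (16, 0)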
# the living definition of UOps.ST_IDX and UOps.ST_VALID
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
sts: Dict[UOp, ShapeTracker] = {}
def assert_valid(op:UOp, st:ShapeTracker):
if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
# restore globals from the two stage reduce
if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
assert_valid(local_reduce:=op.src[1].src[2], op.src[-1].arg)
return sts.setdefault(op, sts[local_reduce])
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
else:
# movementops are pushed to the edges with ST_IDX, ST_VALID
# elementwise inherits shape
st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[op] = st
for out in ast.src: assert_valid(out, out.src[-1].arg)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
return sts

View File

@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools import functools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup from tinygrad.helpers import to_function_name, dedup
from tinygrad.codegen.uops import UOps, UOp, flops_mem from tinygrad.ops import Op, UOps, UOp, flops_mem
from tinygrad.ops import Op
from tinygrad.shape.symbolic import sym_infer, sint, Variable from tinygrad.shape.symbolic import sym_infer, sint, Variable
from tinygrad.dtype import DType from tinygrad.dtype import DType

View File

@ -1,9 +1,8 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct, math import struct, math
from collections import defaultdict from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
from tinygrad.renderer import Renderer, TensorCore from tinygrad.renderer import Renderer, TensorCore
def render_val(x, dtype): def render_val(x, dtype):

View File

@ -1,10 +1,9 @@
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
import os, math import os, math
from collections import defaultdict, Counter from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.helpers import strip_parens, getenv, prod, dedup from tinygrad.helpers import strip_parens, getenv, prod, dedup
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.renderer import Renderer, TensorCore from tinygrad.renderer import Renderer, TensorCore
class CStyleLanguage(Renderer): class CStyleLanguage(Renderer):

View File

@ -1,8 +1,7 @@
from typing import Dict, Callable, Any, List, Optional from typing import Dict, Callable, Any, List, Optional
from llvmlite import ir from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf

View File

@ -7,8 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOps, UOp from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer