From a71bb09ec38d2dd70b87deb807b6dcc3026024e7 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 12 Oct 2024 18:44:44 +0800
Subject: [PATCH] remove symbolic file [pr] (#7012)

---
 examples/handcode_opt.py                |  3 +--
 extra/assembly/assembly.py              |  2 +-
 extra/backends/hsa_graph.py             |  2 +-
 extra/ops.py                            |  2 +-
 extra/optimization/extract_policynet.py |  2 +-
 extra/optimization/extract_sa_pairs.py  |  2 +-
 extra/optimization/helpers.py           |  2 +-
 extra/optimization/pretrain_valuenet.py |  2 +-
 extra/to_movement_ops.py                |  2 +-
 test/external/fuzz_symbolic.py          |  2 +-
 test/external/fuzz_uops.py              |  2 +-
 test/helpers.py                         |  2 +-
 test/test_linearizer.py                 |  2 +-
 test/test_symbolic_shapetracker.py      |  2 +-
 test/unit/test_helpers.py               |  2 +-
 test/unit/test_shapetracker.py          |  2 +-
 test/unit/test_uop_symbolic.py          |  2 +-
 tinygrad/codegen/kernel.py              |  3 +--
 tinygrad/codegen/lowerer.py             |  3 +--
 tinygrad/engine/jit.py                  |  3 +--
 tinygrad/engine/lazy.py                 |  3 +--
 tinygrad/engine/realize.py              |  8 ++++----
 tinygrad/engine/schedule.py             |  3 +--
 tinygrad/engine/search.py               |  6 ++----
 tinygrad/function.py                    |  3 +--
 tinygrad/ops.py                         | 20 ++++++++++++++------
 tinygrad/renderer/__init__.py           |  3 +--
 tinygrad/runtime/graph/clang.py         |  2 +-
 tinygrad/runtime/graph/cuda.py          |  2 +-
 tinygrad/runtime/graph/hcq.py           |  2 +-
 tinygrad/runtime/graph/metal.py         |  2 +-
 tinygrad/shape/shapetracker.py          |  3 +--
 tinygrad/shape/symbolic.py              | 19 -------------------
 tinygrad/shape/view.py                  |  3 +--
 tinygrad/tensor.py                      |  3 +--
 35 files changed, 51 insertions(+), 75 deletions(-)
 delete mode 100644 tinygrad/shape/symbolic.py

diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index ec30733c..42df2b21 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -4,12 +4,11 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import UOps
+from tinygrad.ops import UOps, sym_infer
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
 from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
-from tinygrad.shape.symbolic import sym_infer
 
 def get_sched_resnet():
   mdl = ResNet50()
diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py
index f0349a8e..1133675d 100644
--- a/extra/assembly/assembly.py
+++ b/extra/assembly/assembly.py
@@ -3,7 +3,7 @@ from tinygrad.codegen.kernel import UOps, MemOp, UOp
 from tinygrad.ops import BinaryOps, UnaryOps
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
+from tinygrad.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 import functools
 import math
 from collections import defaultdict
diff --git a/extra/backends/hsa_graph.py b/extra/backends/hsa_graph.py
index 210b7a84..22476e9a 100644
--- a/extra/backends/hsa_graph.py
+++ b/extra/backends/hsa_graph.py
@@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple
 from tinygrad.helpers import init_c_var, round_up
 from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Compiled, Device
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
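
Note: the pattern repeated across almost every hunk below is mechanical. tinygrad/shape/symbolic.py is deleted, and the four names that survived it (sint, Variable, NumNode, sym_infer) now live in tinygrad.ops. A minimal before/after sketch for downstream code:

    # before this patch (module deleted at the bottom of this diff):
    #   from tinygrad.shape.symbolic import Variable, NumNode, sint, sym_infer
    # after it, the same names come straight from ops:
    from tinygrad.ops import Variable, NumNode, sint, sym_infer
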
diff --git a/extra/ops.py b/extra/ops.py
index f612c7c4..76367a12 100644
--- a/extra/ops.py
+++ b/extra/ops.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass
 from tinygrad.helpers import dedup, prod
 from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
 from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
-from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.ops import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # these ops are deleted after AST is UOp
diff --git a/extra/optimization/extract_policynet.py b/extra/optimization/extract_policynet.py
index 6e4f4c20..129208d4 100644
--- a/extra/optimization/extract_policynet.py
+++ b/extra/optimization/extract_policynet.py
@@ -15,7 +15,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
 
diff --git a/extra/optimization/extract_sa_pairs.py b/extra/optimization/extract_sa_pairs.py
index 6b90a5a4..82f6eb00 100644
--- a/extra/optimization/extract_sa_pairs.py
+++ b/extra/optimization/extract_sa_pairs.py
@@ -8,7 +8,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
 
diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index 66a64094..6c5f8d99 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -7,7 +7,7 @@ from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import NumNode
+from tinygrad.ops import NumNode
 inf, nan = float('inf'), float('nan')
 
 # kernel unpacker
diff --git a/extra/optimization/pretrain_valuenet.py b/extra/optimization/pretrain_valuenet.py
index 6b49a30a..2216b24d 100644
--- a/extra/optimization/pretrain_valuenet.py
+++ b/extra/optimization/pretrain_valuenet.py
@@ -12,7 +12,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
 
diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py
index 35457997..af9da9b8 100644
--- a/extra/to_movement_ops.py
+++ b/extra/to_movement_ops.py
@@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast
 from tinygrad.helpers import prod, tqdm
 from tinygrad.ops import UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import sym_infer, Node
+from tinygrad.ops import sym_infer, Node
 
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
 
diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py
index a3c505f9..834853b3 100644
--- a/test/external/fuzz_symbolic.py
+++ b/test/external/fuzz_symbolic.py
@@ -2,7 +2,7 @@ import itertools
 import random
 from tinygrad import Variable
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import NumNode
+from tinygrad.ops import NumNode
 random.seed(42)
 
 def add_v(expr, rng=None):
diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py
index af9f8d7f..476b707b 100644
--- a/test/external/fuzz_uops.py
+++ b/test/external/fuzz_uops.py
@@ -7,7 +7,7 @@ from tinygrad.ops import END_FOR_UOP, UOp, print_uops
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import DEBUG, colored
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.tensor import _to_np_dtype
 from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
 
diff --git a/test/helpers.py b/test/helpers.py
index ab313d42..9fbed0c1 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -10,7 +10,7 @@ from tinygrad.engine.realize import Runner
 from tinygrad.dtype import ConstType, DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import Context, CI, OSX, getenv
-from tinygrad.shape.symbolic import sint
+from tinygrad.ops import sint
 
 def derandomize_model(model):
   with Context(GRAPH=0):
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 8d32d934..c1b9f179 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -10,7 +10,7 @@ from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps, UnaryOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-# from tinygrad.shape.symbolic import Variable
+# from tinygrad.ops import Variable
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py
index f89583b4..859b0e0e 100644
--- a/test/test_symbolic_shapetracker.py
+++ b/test/test_symbolic_shapetracker.py
@@ -1,7 +1,7 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
-from tinygrad.shape.symbolic import NumNode
+from tinygrad.ops import NumNode
 from tinygrad.tensor import Tensor
 
 class TestSymbolic(unittest.TestCase):
diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py
index c8f4fa2b..4a8cb291 100644
--- a/test/unit/test_helpers.py
+++ b/test/unit/test_helpers.py
@@ -5,7 +5,7 @@ from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import NumNode
+from tinygrad.ops import NumNode
 import numpy as np
 
 VARIABLE = ContextVar("VARIABLE", 0)
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index d7082543..5bd6927b 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
-from tinygrad.shape.symbolic import NumNode
+from tinygrad.ops import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index f280b040..5cf0043c 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import unittest, pickle
 from typing import Tuple
-#from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
+#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
 
 # TODO: fix all the @unittest.expectedFailure
 
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 2eb5f473..5425f609 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -6,14 +6,13 @@ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
 from enum import Enum, auto
 
 from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
-  graph_rewrite, track_rewrites
+  graph_rewrite, track_rewrites, Variable, sint
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
 from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index da6915dc..96a1dbb3 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -5,9 +5,8 @@ from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import variable_to_uop
-from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes
-from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
+from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve, sint
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten
 
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index b9b928e6..d2c458d0 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -6,9 +6,8 @@ from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
-from tinygrad.ops import UOp, ssimplify
+from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint, sym_infer
 from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py
index 3868dccf..6b632c64 100644
--- a/tinygrad/engine/lazy.py
+++ b/tinygrad/engine/lazy.py
@@ -3,8 +3,7 @@ from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
-from tinygrad.ops import identity_element, MathTrait, resolve, UOp
-from tinygrad.shape.symbolic import sint
+from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 156c5a47..9c1808e5 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -4,10 +4,9 @@ from collections import defaultdict
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
 from tinygrad.helpers import NO_MEMORY_PLANNER
-from tinygrad.ops import UOps, UOp
+from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
-from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.renderer import Renderer, Program
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
@@ -164,9 +163,10 @@ class ExecItem:
   prg: Runner
   bufs: List[Optional[Buffer]]
   metadata: Optional[Tuple[Metadata, ...]] = None
-  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+  def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+    var_vals = {} if _var_vals is None else _var_vals
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
-    et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+    et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
       GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 552c1abc..b4113bae 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -3,10 +3,9 @@ from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
 from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
-  graph_rewrite, track_rewrites
+  graph_rewrite, track_rewrites, Variable, sint
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
   colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
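
Note: the realize.py hunk above is the one behavioral refactor outside ops.py: run() now normalizes its Optional var_vals argument once at the top instead of at each use site, and the GlobalCounters updates then resolve possibly-symbolic op/memory estimates through the relocated sym_infer. A small sketch of that resolution step, with a made-up estimate value (op_estimate normally comes from the compiled Program, not from user code):

    from tinygrad.ops import Variable, sym_infer

    i = Variable("i", 1, 128)                # constructor form from the resolve() comment in ops.py
    op_estimate = i * 1000                   # hypothetical symbolic FLOP estimate for one kernel
    print(sym_infer(op_estimate, {i: 64}))   # -> 64000 once i is bound at run time
    print(sym_infer(7, {}))                  # non-UOp estimates pass through unchanged
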
diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py
index d7acd272..e727da0f 100644
--- a/tinygrad/engine/search.py
+++ b/tinygrad/engine/search.py
@@ -2,14 +2,12 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.ops import UOp, UOps
+from tinygrad.ops import UOp, UOps, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
-from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable, sym_infer
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import Program
 
diff --git a/tinygrad/function.py b/tinygrad/function.py
index c31ba0a1..b27a2e6a 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -3,10 +3,9 @@ import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
-from tinygrad.ops import ReduceOps, resolve
+from tinygrad.ops import ReduceOps, resolve, sint
 from tinygrad.tensor import Function
 from tinygrad.engine.lazy import LazyBuffer
-from tinygrad.shape.symbolic import sint
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index d780b304..37fd1622 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -7,7 +7,6 @@ from weakref import WeakValueDictionary
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context
 if TYPE_CHECKING:
-  from tinygrad.shape.symbolic import Variable, sint
   from tinygrad.shape.shapetracker import ShapeTracker
 
 # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
@@ -154,7 +153,7 @@ COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, Bin
 END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
 
 # With True as the default, this matches the old symbolic behavior
-# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
+# python3 -c 'from tinygrad.ops import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
 def resolve(x, default:bool=True):
   if not isinstance(x, UOp): return bool(x)
   assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
@@ -306,10 +305,8 @@ class UOp(MathTrait):
     st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
     return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
 
-  # TODO: this is context rewrite
-  def substitute(self, dvars:Dict[UOp, UOp]):
-    if self in dvars: return dvars[self]
-    return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
+  def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, substitute, dvars)
+
   @staticmethod
   def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
     return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
@@ -967,6 +964,8 @@ symbolic_flat = symbolic+PatternMatcher([
   ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
 ])
 
+substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
+
 # for debug
 renderer = PatternMatcher([
   (UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
@@ -979,3 +978,12 @@ renderer = PatternMatcher([
   (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
   (UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
 ])
+
+# *** what was symbolic.py ***
+
+sint = Union[int, UOp]
+Variable = UOp
+
+def NumNode(val:int): return UOp.const(dtypes.int, val)
+def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
+  return int(uop.substitute({k:k.const_like(v) for k,v in var_vals.items()})) if isinstance(uop, UOp) else uop
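
Note: the ops.py hunks above are the heart of the patch. UOp.substitute stops recursing by hand and becomes a single graph_rewrite over the new substitute PatternMatcher, which looks every node up in the dict passed as rewrite context; and the four survivors of symbolic.py move into a short section at the bottom of the file. A minimal sketch of both, using the Variable constructor form from the resolve() comment above:

    from tinygrad.ops import Variable, NumNode, sym_infer

    a, b = Variable("a", 1, 10), Variable("b", 1, 10)   # Variable is just UOp now
    expr = a * 2 + NumNode(3)                           # NumNode(3) is UOp.const(dtypes.int, 3)

    swapped = expr.substitute({a: b})                   # one rewrite pass: every a becomes b
    print(sym_infer(expr, {a: 4}))                      # -> 11: substitute consts, collapse to int
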
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index 985b1604..5c1f976f 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
 import functools
 from dataclasses import dataclass, field
 from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.ops import Op, UOps, UOp, flops_mem
-from tinygrad.shape.symbolic import sym_infer, sint, Variable
+from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable
 from tinygrad.dtype import DType
 
 @dataclass(frozen=True)
diff --git a/tinygrad/runtime/graph/clang.py b/tinygrad/runtime/graph/clang.py
index 6c516f38..038ad034 100644
--- a/tinygrad/runtime/graph/clang.py
+++ b/tinygrad/runtime/graph/clang.py
@@ -4,7 +4,7 @@ from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
 from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_clang import ClangProgram
 from tinygrad.renderer.cstyle import ClangRenderer
 render_dtype = ClangRenderer().render_dtype
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index 658041b8..03d14325 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import init_c_var, dedup
 from tinygrad.device import Buffer, Device
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index f3e7448b..61f27072 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set
 from tinygrad.helpers import round_up, PROFILE, memsize_to_str
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
 from tinygrad.device import Buffer, BufferOptions, Compiled, Device
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner
 
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index ff8df6d1..9b0ad0ae 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -5,7 +5,7 @@ from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import ExecItem, CompiledRunner
 from tinygrad.engine.jit import GraphRunner, GraphException
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
   MTLResourceOptions, elapsed_time, objc_id
 
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 3284858f..bea81181 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -3,10 +3,9 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat
+from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint
 
 @dataclass(frozen=True)
 class ShapeTracker:
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
deleted file mode 100644
index c24dade8..00000000
--- a/tinygrad/shape/symbolic.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from __future__ import annotations
-from typing import Union, Optional, Dict
-from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, exec_alu
-
-sint = Union[int, UOp]
-
-def NumNode(val:int): return UOp.const(dtypes.int, val)
-Variable = UOp
-
-def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
-  if isinstance(uop, (int, float)): return uop  # TODO: ugh, the float is a hack for qcom
-  if uop.op == UOps.CONST: return uop.arg
-  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
-  if uop.op == UOps.BIND: return uop.src[1].arg  # bound variable returns bound value
-  if uop.op == UOps.ALU:
-    src_values = [sym_infer(src, var_vals) for src in uop.src]
-    return exec_alu(uop.arg, uop.dtype, src_values)
-  raise NotImplementedError(f"Unsupported UOp {uop.op}")
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 99fbf7ff..eea20ea4 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -3,9 +3,8 @@ import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast, Union
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp
+from tinygrad.ops import resolve, UOp, NumNode, Variable, sint, sym_infer
 from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
 
 @functools.lru_cache(maxsize=None)
 def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
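
Note: two details of the deleted implementation above are worth keeping in mind. It special-cased non-UOp inputs ("the float is a hack for qcom") and BIND nodes (a bound variable evaluates to its bound value). The new one-liner in ops.py keeps the non-UOp passthrough implicitly, because anything that is not a UOp is returned as-is:

    from tinygrad.ops import sym_infer

    # non-UOp values still pass straight through the new implementation
    print(sym_infer(2.5, {}))   # -> 2.5, the qcom float passthrough keeps working
    print(sym_infer(12, {}))    # -> 12
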
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 27b4718c..357866e4 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -9,9 +9,8 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
+from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
 from tinygrad.device import Device, Buffer, BufferOptions
-from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
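
Note: nothing user-facing changes in tensor.py beyond the import path, so symbolic shapes should keep working as before. A small sketch, assuming the bind() idiom used in the symbolic shapetracker tests (not shown in this diff):

    from tinygrad import Tensor, Variable

    vi = Variable("i", 1, 10).bind(3)      # a Variable carrying a runtime value
    t = Tensor.rand(3, 4).reshape(vi, 4)   # first dimension is symbolic
    print(t.shape)                         # (i bound to 3, 4); still an ordinary Tensor
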