mirror of https://github.com/commaai/tinygrad.git
remove symbolic file [pr] (#7012)
This commit is contained in:
parent
16271189ea
commit
a71bb09ec3
|
@ -4,12 +4,11 @@ from extra.mcts_search import mcts_search
|
|||
from examples.mlperf.helpers import get_mlperf_bert_model
|
||||
from tinygrad import Tensor, Device, dtypes, nn
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.ops import UOps
|
||||
from tinygrad.ops import UOps, sym_infer
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
|
||||
def get_sched_resnet():
|
||||
mdl = ResNet50()
|
||||
|
|
|
@ -3,7 +3,7 @@ from tinygrad.codegen.kernel import UOps, MemOp, UOp
|
|||
from tinygrad.ops import BinaryOps, UnaryOps
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
from tinygrad.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
import functools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple
|
|||
from tinygrad.helpers import init_c_var, round_up
|
||||
from tinygrad.device import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
||||
|
|
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|||
from tinygrad.helpers import dedup, prod
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
|
||||
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.ops import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
# these ops are deleted after AST is UOp
|
||||
|
|
|
@ -15,7 +15,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
|
|||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
|
|||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.ops import UOp, UOps, KernelInfo
|
|||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.ops import NumNode
|
||||
inf, nan = float('inf'), float('nan')
|
||||
|
||||
# kernel unpacker
|
||||
|
|
|
@ -12,7 +12,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
|
|||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
|||
from tinygrad.helpers import prod, tqdm
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import sym_infer, Node
|
||||
from tinygrad.ops import sym_infer, Node
|
||||
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
import random
|
||||
from tinygrad import Variable
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.ops import NumNode
|
||||
random.seed(42)
|
||||
|
||||
def add_v(expr, rng=None):
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.ops import END_FOR_UOP, UOp, print_uops
|
|||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from tinygrad.engine.realize import Runner
|
|||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.helpers import Context, CI, OSX, getenv
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.ops import sint
|
||||
|
||||
def derandomize_model(model):
|
||||
with Context(GRAPH=0):
|
||||
|
|
|
@ -10,7 +10,7 @@ from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps, UnaryOps
|
|||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
# from tinygrad.shape.symbolic import Variable
|
||||
# from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad import Variable
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.ops import NumNode
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import Context, ContextVar
|
|||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
|
||||
from tinygrad.tensor import get_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.ops import NumNode
|
||||
import numpy as np
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes
|
|||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad import Variable
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.ops import NumNode
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import sym
|
||||
from itertools import product
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest, pickle
|
||||
from typing import Tuple
|
||||
#from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
|
||||
#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
|
||||
|
||||
# TODO: fix all the @unittest.expectedFailure
|
||||
|
||||
|
|
|
@ -6,14 +6,13 @@ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
|
|||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
|
||||
graph_rewrite, track_rewrites
|
||||
graph_rewrite, track_rewrites, Variable, sint
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
||||
|
|
|
@ -5,9 +5,8 @@ from dataclasses import dataclass
|
|||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve, sint
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
|
||||
|
|
|
@ -6,9 +6,8 @@ from tinygrad.engine.lazy import LazyBuffer
|
|||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
|
||||
from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import UOp, ssimplify
|
||||
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint, sym_infer
|
||||
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from dataclasses import dataclass
|
||||
|
|
|
@ -3,8 +3,7 @@ from typing import Union, Optional, Any, Tuple, List, get_args
|
|||
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve, UOp
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
from weakref import ref, ReferenceType, WeakValueDictionary
|
||||
|
|
|
@ -4,10 +4,9 @@ from collections import defaultdict
|
|||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
|
||||
from tinygrad.helpers import NO_MEMORY_PLANNER
|
||||
from tinygrad.ops import UOps, UOp
|
||||
from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
from tinygrad.renderer import Renderer, Program
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import ScheduleItem
|
||||
|
@ -164,9 +163,10 @@ class ExecItem:
|
|||
prg: Runner
|
||||
bufs: List[Optional[Buffer]]
|
||||
metadata: Optional[Tuple[Metadata, ...]] = None
|
||||
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
||||
def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
||||
var_vals = {} if _var_vals is None else _var_vals
|
||||
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
|
||||
et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
|
||||
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
|
||||
|
|
|
@ -3,10 +3,9 @@ from collections import defaultdict, deque
|
|||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
|
||||
graph_rewrite, track_rewrites
|
||||
graph_rewrite, track_rewrites, Variable, sint
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
|
||||
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
|
|
@ -2,14 +2,12 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
|||
import itertools, functools, random, math, time, multiprocessing, traceback, signal
|
||||
from collections import defaultdict
|
||||
from dataclasses import replace
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.ops import UOp, UOps, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer, Compiler
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer import Program
|
||||
|
||||
|
|
|
@ -3,10 +3,9 @@ import math
|
|||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
||||
from tinygrad.ops import ReduceOps, resolve
|
||||
from tinygrad.ops import ReduceOps, resolve, sint
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
|
||||
|
|
|
@ -7,7 +7,6 @@ from weakref import WeakValueDictionary
|
|||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
|
@ -154,7 +153,7 @@ COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, Bin
|
|||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
|
||||
# python3 -c 'from tinygrad.ops import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
|
||||
def resolve(x, default:bool=True):
|
||||
if not isinstance(x, UOp): return bool(x)
|
||||
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
|
||||
|
@ -306,10 +305,8 @@ class UOp(MathTrait):
|
|||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
|
||||
# TODO: this is context rewrite
|
||||
def substitute(self, dvars:Dict[UOp, UOp]):
|
||||
if self in dvars: return dvars[self]
|
||||
return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
|
||||
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, substitute, dvars)
|
||||
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
|
@ -967,6 +964,8 @@ symbolic_flat = symbolic+PatternMatcher([
|
|||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
])
|
||||
|
||||
substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
# for debug
|
||||
renderer = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
|
||||
|
@ -979,3 +978,12 @@ renderer = PatternMatcher([
|
|||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
|
||||
])
|
||||
|
||||
# *** what was symbolic.py ***
|
||||
|
||||
sint = Union[int, UOp]
|
||||
Variable = UOp
|
||||
|
||||
def NumNode(val:int): return UOp.const(dtypes.int, val)
|
||||
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
|
||||
return int(uop.substitute({k:k.const_like(v) for k,v in var_vals.items()})) if isinstance(uop, UOp) else uop
|
||||
|
|
|
@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
|
|||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod
|
||||
from tinygrad.ops import Op, UOps, UOp, flops_mem
|
||||
from tinygrad.shape.symbolic import sym_infer, sint, Variable
|
||||
from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable
|
||||
from tinygrad.dtype import DType
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
|
@ -4,7 +4,7 @@ from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
|
|||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.runtime.ops_clang import ClangProgram
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
render_dtype = ClangRenderer().render_dtype
|
||||
|
|
|
@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.cuda as cuda
|
|||
from tinygrad.helpers import init_c_var, dedup
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
|||
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
|
||||
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import dedup, getenv
|
|||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
||||
MTLResourceOptions, elapsed_time, objc_id
|
||||
|
||||
|
|
|
@ -3,10 +3,9 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from typing import Union, Optional, Dict
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, exec_alu
|
||||
|
||||
sint = Union[int, UOp]
|
||||
|
||||
def NumNode(val:int): return UOp.const(dtypes.int, val)
|
||||
Variable = UOp
|
||||
|
||||
def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
|
||||
if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
|
||||
if uop.op == UOps.CONST: return uop.arg
|
||||
if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
|
||||
if uop.op == UOps.BIND: return uop.src[1].arg # bound variable returns bound value
|
||||
if uop.op == UOps.ALU:
|
||||
src_values = [sym_infer(src, var_vals) for src in uop.src]
|
||||
return exec_alu(uop.arg, uop.dtype, src_values)
|
||||
raise NotImplementedError(f"Unsupported UOp {uop.op}")
|
|
@ -3,9 +3,8 @@ import functools, operator, itertools, math
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Union
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import resolve, UOp
|
||||
from tinygrad.ops import resolve, UOp, NumNode, Variable, sint, sym_infer
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
||||
|
|
|
@ -9,9 +9,8 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
|
|||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
|
||||
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.engine.realize import run_schedule, memory_planner
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
|
|
Loading…
Reference in New Issue