mirror of https://github.com/commaai/tinygrad.git
delete SAVE_SCHEDULE=1 [pr] (#7087)
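This commit deletes the SAVE_SCHEDULE ContextVar and the schedule-pickling machinery behind it; benchmark and test call sites that wrapped setup work in Context(SAVE_SCHEDULE=0) now wrap it in Context(TRACK_MATCH_STATS=0) instead. A minimal sketch of that pattern, assuming a tinygrad checkout from around this commit where TRACK_MATCH_STATS is a registered ContextVar:

# Context temporarily overrides named ContextVars and restores them on exit.
from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(TRACK_MATCH_STATS=0):  # setup work excluded from match-stat tracking
  x = Tensor.randn(4, 4).realize()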
This commit is contained in:
parent 3169cb386d
commit 390171d686
@@ -60,7 +60,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
   def _test_layer(self, name, layer, cin, xy):
     optim = SGD(get_parameters(layer), bs / 128 * 1.0)  # need sgd for some params but not consequential for benchmarking
-    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t.detach().contiguous()) for t in get_parameters(optim)])
+    with Context(TRACK_MATCH_STATS=0): Tensor.realize(*[t.assign(t.detach().contiguous()) for t in get_parameters(optim)])
 
     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -81,7 +81,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     best_tm = None
     flops, mem_used, mem, kernels = None, None, None, None
     for i in range(CNT):
-      with Context(SAVE_SCHEDULE=0): x = Tensor.randn(bs, cin, xy, xy, requires_grad=True).realize()
+      with Context(TRACK_MATCH_STATS=0): x = Tensor.randn(bs, cin, xy, xy, requires_grad=True).realize()
       GlobalCounters.reset()
 
       st = time.perf_counter()
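The surrounding benchmark loop (unchanged by this commit) follows the usual tinygrad timing pattern: reset the global counters, run the step, read the wall clock. A minimal sketch of that pattern; the sum() step here stands in for the benchmarked layer:

import time
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

x = Tensor.randn(16, 64, 56, 56).realize()
GlobalCounters.reset()                 # zero the op/memory counters
st = time.perf_counter()
x.sum().realize()                      # stand-in for the layer under test
tm = time.perf_counter() - st
print(f"{tm*1e3:.2f} ms, {GlobalCounters.global_ops/1e9:.3f} GFLOP")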
@@ -38,7 +38,7 @@ class BenchmarkBertTrain(unittest.TestCase):
 
   def _test_layer(self, name, layer, input_shapes):
     optim = LAMB(get_parameters(layer))
-    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t.detach().contiguous()) for t in get_parameters(optim)])
+    with Context(TRACK_MATCH_STATS=0): Tensor.realize(*[t.assign(t.detach().contiguous()) for t in get_parameters(optim)])
 
     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -59,7 +59,7 @@ class BenchmarkBertTrain(unittest.TestCase):
     best_tm = None
     flops, mem_used, mem, kernels = None, None, None, None
     for _ in range(CNT):
-      with Context(SAVE_SCHEDULE=0): inputs = [Tensor.randn(*shape, requires_grad=False).realize() for shape in input_shapes]
+      with Context(TRACK_MATCH_STATS=0): inputs = [Tensor.randn(*shape, requires_grad=False).realize() for shape in input_shapes]
       GlobalCounters.reset()
 
       st = time.perf_counter()
@@ -1587,7 +1587,7 @@ class TestIndexing(unittest.TestCase):
     X = Tensor.randn(2,3,4,4).numpy()
     with Context(FUSE_ARANGE=1):
       compare = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
-    with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1):
+    with Context(FUSE_ARANGE=0, TRACK_MATCH_STATS=0):
       ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
     np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
 
@@ -3,7 +3,7 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, Iterator
 import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib
-from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
 
@@ -94,7 +94,7 @@ class Buffer:
     if self._base is not None:
       return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
     if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
-    if self.is_allocated() and not SAVE_SCHEDULE:
+    if self.is_allocated():
      buf = bytearray(self.nbytes)
      self.copyout(memoryview(buf))
     return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
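With the `and not SAVE_SCHEDULE` guard gone, Buffer.__reduce__ unconditionally copies an allocated buffer's bytes out, so pickling always snapshots device memory. A small sketch of the observable behavior, pickling a realized Tensor (whose Buffers reduce this way); the round-trip is an extrapolation from this hunk, not shown in the diff:

import pickle
from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0]).realize()
t2 = pickle.loads(pickle.dumps(t))   # allocated Buffer contents are copied out and restored
assert t2.tolist() == [1.0, 2.0, 3.0]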
@@ -1,11 +1,11 @@
-import sys, pickle, atexit
+import sys, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
 from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
   graph_rewrite, track_rewrites, sint
-from tinygrad.helpers import DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Metadata, all_same, \
-  colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
+from tinygrad.helpers import DEBUG, MULTIOUTPUT, FUSE_CONV_BW, FUSE_ARANGE, Metadata, all_same, colored, diskcache_put, prod, dedup, all_int, \
+  merge_dicts, getenv, unwrap
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
@@ -274,7 +274,6 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
   for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
   return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
 
-SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
   Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], # this is the graph
     DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph
@@ -396,13 +395,6 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
     for assign in parents_assigns:
       graph[lsi].append(assign)
       in_degree[assign] += 1
-
-  if SAVE_SCHEDULE:
-    def _save():
-      print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
-      with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
-    if len(SCHEDULES) == 0: atexit.register(_save)
-    SCHEDULES.append((graph, in_degree))
   return graph, in_degree, var_vals
 
 # *** DAG ordering: breadth first search ***
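The block deleted above was the whole SAVE_SCHEDULE feature: every scheduled graph was appended to the module-level SCHEDULES list, and the first append registered an atexit hook that pickled the list to SAVE_SCHEDULE_PATH (default schedule.pkl) on interpreter exit. For reference, a generic sketch of that accumulate-then-dump-at-exit pattern; the names here are illustrative, not tinygrad API:

import atexit, pickle
from typing import Any, List

SNAPSHOTS: List[Any] = []
def record(item: Any) -> None:
  if len(SNAPSHOTS) == 0:  # first record installs the one-time exit hook
    def _save():
      with open("snapshots.pkl", "wb") as f: pickle.dump(SNAPSHOTS, f)
    atexit.register(_save)
  SNAPSHOTS.append(item)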
@@ -100,11 +100,10 @@ class ContextVar:
 
 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
-SAVE_SCHEDULE, RING = ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
 MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
 USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
-SPLIT_REDUCEOP, NO_MEMORY_PLANNER = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0)
+SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 
 @dataclass(frozen=True)
 class Metadata:
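For context on the lines above: a ContextVar takes its default unless an environment variable of the same name overrides it, and Context(...) rebinds it for the duration of a with-block. A minimal sketch with a hypothetical flag name (MYFLAG is not a real tinygrad variable):

from tinygrad.helpers import Context, ContextVar

MYFLAG = ContextVar("MYFLAG", 1)  # hypothetical; reads env var MYFLAG if set, else defaults to 1
print(MYFLAG.value)               # 1 unless MYFLAG is set in the environment
with Context(MYFLAG=0):
  print(MYFLAG.value)             # 0 inside the block; restored on exit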