mirror of https://github.com/commaai/tinygrad.git
fix jitted dist (#1955)
This commit is contained in:
parent
35ac60775b
commit
e1f2c2cc19
|
@ -2,6 +2,7 @@
|
|||
import unittest
|
||||
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.ops import ASTRunner
|
||||
from tinygrad.jit import CacheCollector
|
||||
from weakref import ref
|
||||
|
||||
|
@ -33,7 +34,7 @@ def anybuf(size, dtype):
|
|||
return FakeBuffer(size, dtype)
|
||||
|
||||
def add_to_cache(bufs):
|
||||
CacheCollector.add(None, bufs, None)
|
||||
CacheCollector.add(ASTRunner("", None), bufs, None)
|
||||
return bufs[0]
|
||||
|
||||
def add_to_cache_refed(bufs):
|
||||
|
|
|
@ -3,7 +3,7 @@ from weakref import ref
|
|||
from collections import defaultdict
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
|
||||
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
|
||||
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor, ASTRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
@ -81,7 +81,7 @@ class _CacheCollector:
|
|||
if self.cache is None: return
|
||||
# Substitute output buffers with placeholders to find the most optimal reusage.
|
||||
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
|
||||
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
|
||||
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(prg, ASTRunner) and isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
|
||||
self.cache.append((prg, cached_rawbufs, var_vals))
|
||||
def finish(self):
|
||||
if self.cache is None: return []
|
||||
|
@ -101,9 +101,7 @@ class _CacheCollector:
|
|||
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
|
||||
for _, start, end, buf in query_list:
|
||||
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
|
||||
if pool_idx == -1:
|
||||
rawbuf_pool.append((buf.alloc_rawbuf(), []))
|
||||
pool_idx = len(rawbuf_pool) - 1
|
||||
if pool_idx == -1: rawbuf_pool.append((buf.alloc_rawbuf(), []))
|
||||
buf_map[buf] = rawbuf_pool[pool_idx][0]
|
||||
rawbuf_pool[pool_idx][1].append((start, end))
|
||||
|
||||
|
|
Loading…
Reference in New Issue