mirror of https://github.com/commaai/tinygrad.git
fix jitted dist (#1955)
This commit is contained in:
parent
35ac60775b
commit
e1f2c2cc19
|
@ -2,6 +2,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes
|
||||||
|
from tinygrad.ops import ASTRunner
|
||||||
from tinygrad.jit import CacheCollector
|
from tinygrad.jit import CacheCollector
|
||||||
from weakref import ref
|
from weakref import ref
|
||||||
|
|
||||||
|
@ -33,7 +34,7 @@ def anybuf(size, dtype):
|
||||||
return FakeBuffer(size, dtype)
|
return FakeBuffer(size, dtype)
|
||||||
|
|
||||||
def add_to_cache(bufs):
|
def add_to_cache(bufs):
|
||||||
CacheCollector.add(None, bufs, None)
|
CacheCollector.add(ASTRunner("", None), bufs, None)
|
||||||
return bufs[0]
|
return bufs[0]
|
||||||
|
|
||||||
def add_to_cache_refed(bufs):
|
def add_to_cache_refed(bufs):
|
||||||
|
|
|
@ -3,7 +3,7 @@ from weakref import ref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import functools, itertools
|
import functools, itertools
|
||||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
|
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
|
||||||
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
|
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor, ASTRunner
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.symbolic import Variable
|
from tinygrad.shape.symbolic import Variable
|
||||||
|
@ -81,7 +81,7 @@ class _CacheCollector:
|
||||||
if self.cache is None: return
|
if self.cache is None: return
|
||||||
# Substitute output buffers with placeholders to find the most optimal reusage.
|
# Substitute output buffers with placeholders to find the most optimal reusage.
|
||||||
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
|
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
|
||||||
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
|
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(prg, ASTRunner) and isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
|
||||||
self.cache.append((prg, cached_rawbufs, var_vals))
|
self.cache.append((prg, cached_rawbufs, var_vals))
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.cache is None: return []
|
if self.cache is None: return []
|
||||||
|
@ -101,9 +101,7 @@ class _CacheCollector:
|
||||||
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
|
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
|
||||||
for _, start, end, buf in query_list:
|
for _, start, end, buf in query_list:
|
||||||
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
|
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
|
||||||
if pool_idx == -1:
|
if pool_idx == -1: rawbuf_pool.append((buf.alloc_rawbuf(), []))
|
||||||
rawbuf_pool.append((buf.alloc_rawbuf(), []))
|
|
||||||
pool_idx = len(rawbuf_pool) - 1
|
|
||||||
buf_map[buf] = rawbuf_pool[pool_idx][0]
|
buf_map[buf] = rawbuf_pool[pool_idx][0]
|
||||||
rawbuf_pool[pool_idx][1].append((start, end))
|
rawbuf_pool[pool_idx][1].append((start, end))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue