fix jitted dist (#1955)

This commit is contained in:
nimlgen 2023-10-02 18:45:13 +03:00 committed by GitHub
parent 35ac60775b
commit e1f2c2cc19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 6 deletions

View File

@ -2,6 +2,7 @@
import unittest
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes
from tinygrad.ops import ASTRunner
from tinygrad.jit import CacheCollector
from weakref import ref
@ -33,7 +34,7 @@ def anybuf(size, dtype):
return FakeBuffer(size, dtype)
def add_to_cache(bufs):
CacheCollector.add(None, bufs, None)
CacheCollector.add(ASTRunner("", None), bufs, None)
return bufs[0]
def add_to_cache_refed(bufs):

View File

@ -3,7 +3,7 @@ from weakref import ref
from collections import defaultdict
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor, ASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
@ -81,7 +81,7 @@ class _CacheCollector:
if self.cache is None: return
# Substitute output buffers with placeholders to find the most optimal reusage.
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(prg, ASTRunner) and isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
self.cache.append((prg, cached_rawbufs, var_vals))
def finish(self):
if self.cache is None: return []
@ -101,9 +101,7 @@ class _CacheCollector:
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
for _, start, end, buf in query_list:
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
if pool_idx == -1:
rawbuf_pool.append((buf.alloc_rawbuf(), []))
pool_idx = len(rawbuf_pool) - 1
if pool_idx == -1: rawbuf_pool.append((buf.alloc_rawbuf(), []))
buf_map[buf] = rawbuf_pool[pool_idx][0]
rawbuf_pool[pool_idx][1].append((start, end))