fix jitted dist (#1955)

This commit is contained in:
nimlgen 2023-10-02 18:45:13 +03:00 committed by GitHub
parent 35ac60775b
commit e1f2c2cc19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 6 deletions

View File

@ -2,6 +2,7 @@
import unittest import unittest
from tinygrad.runtime.lib import RawBuffer, LRUAllocator from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes from tinygrad.helpers import dtypes
from tinygrad.ops import ASTRunner
from tinygrad.jit import CacheCollector from tinygrad.jit import CacheCollector
from weakref import ref from weakref import ref
@ -33,7 +34,7 @@ def anybuf(size, dtype):
return FakeBuffer(size, dtype) return FakeBuffer(size, dtype)
def add_to_cache(bufs): def add_to_cache(bufs):
CacheCollector.add(None, bufs, None) CacheCollector.add(ASTRunner("", None), bufs, None)
return bufs[0] return bufs[0]
def add_to_cache_refed(bufs): def add_to_cache_refed(bufs):

View File

@ -3,7 +3,7 @@ from weakref import ref
from collections import defaultdict from collections import defaultdict
import functools, itertools import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor, ASTRunner
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable from tinygrad.shape.symbolic import Variable
@ -81,7 +81,7 @@ class _CacheCollector:
if self.cache is None: return if self.cache is None: return
# Substitute output buffers with placeholders to find the most optimal reusage. # Substitute output buffers with placeholders to find the most optimal reusage.
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0]) if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs] cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(prg, ASTRunner) and isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
self.cache.append((prg, cached_rawbufs, var_vals)) self.cache.append((prg, cached_rawbufs, var_vals))
def finish(self): def finish(self):
if self.cache is None: return [] if self.cache is None: return []
@ -101,9 +101,7 @@ class _CacheCollector:
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True) query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
for _, start, end, buf in query_list: for _, start, end, buf in query_list:
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1) pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
if pool_idx == -1: if pool_idx == -1: rawbuf_pool.append((buf.alloc_rawbuf(), []))
rawbuf_pool.append((buf.alloc_rawbuf(), []))
pool_idx = len(rawbuf_pool) - 1
buf_map[buf] = rawbuf_pool[pool_idx][0] buf_map[buf] = rawbuf_pool[pool_idx][0]
rawbuf_pool[pool_idx][1].append((start, end)) rawbuf_pool[pool_idx][1].append((start, end))