Revert "Revert "Fix ShapeTracker mismatch in LazyBuffer.fromCPU (#1156)" (#1181)" + add test

This reverts commit a374b62bfe.
This commit is contained in:
George Hotz 2023-07-07 18:35:05 -07:00
parent 2952b8e7a8
commit 0ad99038ef
2 changed files with 28 additions and 2 deletions

26
test/test_lazybuffer.py Normal file
View File

@ -0,0 +1,26 @@
#!/usr/bin/env python
import numpy as np
import unittest
from tinygrad.lazy import LazyBuffer
class TestLazyBuffer(unittest.TestCase):
def test_fromcpu_buffer_sharing(self):
a = np.arange(8)
assert LazyBuffer.fromCPU(a).realized._buf is a
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
b = LazyBuffer.fromCPU(a).realize()
assert b.st.contiguous == a.flags.c_contiguous
assert b.st.shape == a.shape
np.testing.assert_equal(a, b.toCPU())
for ndims in range(1, 4):
a = np.random.randn(*(4,)*ndims).astype(np.float32)
for stride in [-2, 1, 2]:
for start in [0, 1]:
helper(a[(slice(start, None, stride),)*ndims])
if __name__ == "__main__":
unittest.main()

View File

@ -8,7 +8,7 @@ import numpy as np
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
from tinygrad.runtime.ops_cpu import RawNumpyBuffer from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, LoadOps, OpType, LazyOp from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
@ -176,7 +176,7 @@ class LazyBuffer:
@staticmethod @staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer: def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x)) return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
# create a constant with the shape and dtype of self # create a constant with the shape and dtype of self
def const_like(self, val) -> LazyBuffer: def const_like(self, val) -> LazyBuffer: