mirror of https://github.com/commaai/tinygrad.git
fix bug in gpu copy out
This commit is contained in:
parent
e87410c531
commit
c1a769b68b
|
@@ -4,7 +4,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
BASEDIR = "/home/batman/imagenet"
|
||||
BASEDIR = "/Users/kafka/fun/imagenet"
|
||||
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
|
||||
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
|
||||
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
|
||||
|
|
|
@@ -349,9 +349,9 @@ class GPUBuffer(ExplicitExecAST):
|
|||
def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
|
||||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
cl_buf = self.contiguous()
|
||||
cl_buf = self.unary_op(UnaryOps.NOOP) if not self.st.contiguous or prod(self._base_shape) != prod(self.shape) else self
|
||||
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
|
||||
assert prod(cl_buf._base_shape) == prod(self.shape)
|
||||
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
cl_buf._buf.copyout(data)
|
||||
return data
|
||||
|
|
Loading…
Reference in New Issue