mirror of https://github.com/commaai/tinygrad.git
fix bug in gpu copy out
This commit is contained in:
parent
e87410c531
commit
c1a769b68b
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
BASEDIR = "/home/batman/imagenet"
|
BASEDIR = "/Users/kafka/fun/imagenet"
|
||||||
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
|
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
|
||||||
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
|
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
|
||||||
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
|
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
|
||||||
|
|
|
@ -349,9 +349,9 @@ class GPUBuffer(ExplicitExecAST):
|
||||||
def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
|
def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
|
||||||
|
|
||||||
def toCPU(self) -> np.ndarray:
|
def toCPU(self) -> np.ndarray:
|
||||||
cl_buf = self.contiguous()
|
cl_buf = self.unary_op(UnaryOps.NOOP) if not self.st.contiguous or prod(self._base_shape) != prod(self.shape) else self
|
||||||
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
|
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
|
||||||
assert prod(cl_buf._base_shape) == prod(self.shape)
|
assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
|
||||||
data = np.empty(self.shape, dtype=np.float32)
|
data = np.empty(self.shape, dtype=np.float32)
|
||||||
cl_buf._buf.copyout(data)
|
cl_buf._buf.copyout(data)
|
||||||
return data
|
return data
|
||||||
|
|
Loading…
Reference in New Issue