mirror of https://github.com/commaai/tinygrad.git
fix opencl bug, no training on opencl
This commit is contained in:
parent
f93e297804
commit
5e96ed523a
|
@ -66,11 +66,12 @@ class OpenCLBuffer(GPUBuffer):
|
|||
@property
|
||||
def cl(self):
|
||||
if self._buf is None:
|
||||
if self.st.contiguous:
|
||||
if self._backing is not None:
|
||||
self._buf = CLBuffer(4*roundup(prod(self._backing.shape)))
|
||||
CL.enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
|
||||
elif self.st.contiguous:
|
||||
self._buf = CLBuffer(4*roundup(prod(self.shape)))
|
||||
if self._backing is not None:
|
||||
CL.enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
|
||||
#self._backing = None
|
||||
|
||||
if self._image is not None:
|
||||
self._buf = CLBuffer(4*roundup(prod(self._image.shape)*4))
|
||||
if self._backing is not None:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import os
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
import tinygrad.optim as optim
|
||||
from extra.training import train, evaluate
|
||||
from extra.utils import get_parameters
|
||||
|
@ -44,6 +44,7 @@ class TinyConvNet:
|
|||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1).logsoftmax()
|
||||
|
||||
@unittest.skipUnless(getattr(Device, "OPENCL", None) is None or Device.DEFAULT != Device.OPENCL, "OOM on OpenCL")
|
||||
class TestMNIST(unittest.TestCase):
|
||||
def test_sgd_onestep(self):
|
||||
np.random.seed(1337)
|
||||
|
|
Loading…
Reference in New Issue