mirror of https://github.com/commaai/tinygrad.git
Fix cl import in the copy_speed test and cifar example (#2586)
* fix CL import * update test to only run on GPU * update hlb_cifar too
This commit is contained in:
parent
3226b3d96b
commit
ab2d4d8d29
|
@ -430,8 +430,8 @@ if __name__ == "__main__":
|
|||
from tinygrad.runtime.ops_hip import HIP
|
||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||
else:
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
devices = [f"gpu:{i}" for i in range(len(CL.devices))]
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||
world_size = len(devices)
|
||||
|
||||
# ensure that the batch size is divisible by the number of devices
|
||||
|
|
|
@ -51,9 +51,10 @@ class TestCopySpeed(unittest.TestCase):
|
|||
t.to('cpu').realize()
|
||||
|
||||
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU")
|
||||
def testCopyCPUto6GPUs(self):
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
if len(CL.devices) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
|
||||
t = Tensor.rand(N, N, device="cpu").realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
|
@ -64,4 +65,4 @@ class TestCopySpeed(unittest.TestCase):
|
|||
Device["gpu"].synchronize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue