mirror of https://github.com/commaai/tinygrad.git
Fix cl import in the copy_speed test and cifar example (#2586)
* fix CL import * update test to only run on GPU * update hlb_cifar too
This commit is contained in:
parent
3226b3d96b
commit
ab2d4d8d29
|
@ -430,8 +430,8 @@ if __name__ == "__main__":
|
||||||
from tinygrad.runtime.ops_hip import HIP
|
from tinygrad.runtime.ops_hip import HIP
|
||||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||||
else:
|
else:
|
||||||
from tinygrad.runtime.ops_gpu import CL
|
from tinygrad.runtime.ops_gpu import CLDevice
|
||||||
devices = [f"gpu:{i}" for i in range(len(CL.devices))]
|
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||||
world_size = len(devices)
|
world_size = len(devices)
|
||||||
|
|
||||||
# ensure that the batch size is divisible by the number of devices
|
# ensure that the batch size is divisible by the number of devices
|
||||||
|
|
|
@ -51,9 +51,10 @@ class TestCopySpeed(unittest.TestCase):
|
||||||
t.to('cpu').realize()
|
t.to('cpu').realize()
|
||||||
|
|
||||||
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
|
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
|
||||||
|
@unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU")
|
||||||
def testCopyCPUto6GPUs(self):
|
def testCopyCPUto6GPUs(self):
|
||||||
from tinygrad.runtime.ops_gpu import CL
|
from tinygrad.runtime.ops_gpu import CLDevice
|
||||||
if len(CL.devices) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
|
if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
|
||||||
t = Tensor.rand(N, N, device="cpu").realize()
|
t = Tensor.rand(N, N, device="cpu").realize()
|
||||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
|
@ -64,4 +65,4 @@ class TestCopySpeed(unittest.TestCase):
|
||||||
Device["gpu"].synchronize()
|
Device["gpu"].synchronize()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue