This commit is contained in:
uuuvn 2024-04-07 15:21:19 +03:00 committed by GitHub
parent bdbcac67f1
commit bb7567b365
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View File

@ -40,6 +40,7 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
@unittest.skipIf(Device.DEFAULT == "METAL", "metal multi-device is fake")
def test_sharded_memory(self):
mem_base = GlobalCounters.mem_used

View File

@ -71,7 +71,8 @@ class MetalAllocator(LRUAllocator):
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return ret
def transfer(self, dest:Any, src:Any, sz:int, **kwargs):
def transfer(self, dest:Any, src:Any, sz:int, src_dev: MetalDevice, **kwargs):
src_dev.synchronize()
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.blitCommandEncoder()
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)