From 077567f62d1e7b9a54484dfb28a022777f3ce9ed Mon Sep 17 00:00:00 2001 From: Christopher Mauri Milan Date: Fri, 1 Dec 2023 18:51:38 -0800 Subject: [PATCH] Remove as_buffer for TORCH (#2554) * remove as_buffer for torch * enable torch zerocopy if on cpu * remove as_buffer even on torch:cpu --- test/test_zero_copy.py | 2 +- tinygrad/runtime/ops_torch.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_zero_copy.py b/test/test_zero_copy.py index 7b7cd2f6..265e9d81 100644 --- a/test/test_zero_copy.py +++ b/test/test_zero_copy.py @@ -14,7 +14,7 @@ def time_tensor_numpy(out:Tensor): N = 4096 class TestZeroCopy(unittest.TestCase): - @unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "TORCH", "METAL"}, "device isn't zero copy") + @unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "METAL"}, "device isn't zero copy") def test_zero_copy_from_default_to_cpu(self): demo = Tensor.rand(1).realize() t1 = time_tensor_numpy(demo) diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 21f43612..75b64aa8 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -3,7 +3,7 @@ import numpy as np from typing import Dict, Callable from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op from tinygrad.device import Interpreted, Allocator -from tinygrad.helpers import getenv, dtypes, DType, flat_mv +from tinygrad.helpers import getenv, dtypes, DType from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")) @@ -43,7 +43,6 @@ torch_fxn_for_op: Dict[Op, Callable] = { class TorchAllocator(Allocator): def _alloc(self, size:int, dtype:DType): return torch.empty([size], device=device, dtype=inverse_type_map[dtype]) - def as_buffer(self, src:torch.Tensor) -> memoryview: return flat_mv(np.require(src.numpy(), requirements='C').data) def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype)) def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())