mirror of https://github.com/commaai/tinygrad.git
Remove as_buffer for TORCH (#2554)
* remove as_buffer for torch * enable torch zerocopy if on cpu * remove as_buffer even on torch:cpu
This commit is contained in:
parent
05a5357dd9
commit
077567f62d
|
@ -14,7 +14,7 @@ def time_tensor_numpy(out:Tensor):
|
|||
|
||||
N = 4096
|
||||
class TestZeroCopy(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "TORCH", "METAL"}, "device isn't zero copy")
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "METAL"}, "device isn't zero copy")
|
||||
def test_zero_copy_from_default_to_cpu(self):
|
||||
demo = Tensor.rand(1).realize()
|
||||
t1 = time_tensor_numpy(demo)
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
from typing import Dict, Callable
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
from tinygrad.helpers import getenv, dtypes, DType, flat_mv
|
||||
from tinygrad.helpers import getenv, dtypes, DType
|
||||
from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
|
@ -43,7 +43,6 @@ torch_fxn_for_op: Dict[Op, Callable] = {
|
|||
|
||||
class TorchAllocator(Allocator):
|
||||
def _alloc(self, size:int, dtype:DType): return torch.empty([size], device=device, dtype=inverse_type_map[dtype])
|
||||
def as_buffer(self, src:torch.Tensor) -> memoryview: return flat_mv(np.require(src.numpy(), requirements='C').data)
|
||||
def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
|
||||
def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
|
||||
|
||||
|
|
Loading…
Reference in New Issue