add cuda on cpu tests (#1020)

Commit: 2407690d82 (parent: e09219df0f)
Author: cloud11665, committed by GitHub
Date: 2023-06-22 23:15:50 +02:00
3 changed files with 76 additions and 11 deletions

File: .github/workflows/test.yml

@@ -226,3 +226,42 @@ jobs:
       run: docker build -t tinygrad -f test/Dockerfile .
     - name: Test Docker
       run: docker run --rm tinygrad /usr/bin/env python3 -c "from tinygrad.tensor import Tensor; print(Tensor.eye(3).numpy())"
+  testcuda:
+    name: (emulated) cuda test
+    runs-on: ubuntu-22.04
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v3
+    - name: Update packages
+      run: |
+        export DEBIAN_FRONTEND=noninteractive
+        sudo apt-get update -y
+    - name: Install packages
+      run: |
+        sudo apt-get install -y git g++ cmake ninja-build llvm-15-dev libz-dev libglew-dev flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc
+    - name: Clone gpuocelot repo
+      uses: actions/checkout@v3
+      with:
+        repository: gpuocelot/gpuocelot
+        ref: 19626fc00b6ee321638c3111074269c69050e091
+        path: ${{ github.workspace }}/gpuocelot
+        submodules: true
+    - name: Compile gpuocelot
+      run: |
+        cd ${{ github.workspace }}/gpuocelot/ocelot
+        mkdir build
+        cd build
+        cmake .. -Wno-dev -G Ninja -DOCELOT_BUILD_TOOLS=OFF
+        ninja
+        sudo ninja install
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v4
+      with:
+        python-version: 3.8
+        cache: 'pip'
+        cache-dependency-path: setup.py
+    - name: Install tinygrad dependencies
+      run: pip install -e '.[testing, cuda]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Run pytest
+      run: FORWARD_ONLY=1 JIT=1 OPT=2 CUDA=1 CUDACPU=1 python -m pytest -s -v -n=auto test --ignore=test/external --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_specific_conv.py --ignore=test/test_net_speed.py --ignore=test/test_nn.py -k "not half"
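
The job builds gpuocelot, a PTX emulator, as a shared library, then runs the test suite with CUDA=1 CUDACPU=1 so kernels are compiled to PTX but executed on the CPU. For local debugging, that setup can be sanity-checked before invoking pytest. A minimal sketch, not part of the workflow; it assumes "sudo ninja install" has placed libgpuocelot on the default linker path, and mirrors the lookup ops_cuda.py performs under CUDACPU=1:

    import ctypes, ctypes.util
    # the runtime resolves the emulator by library name, not by an explicit path
    path = ctypes.util.find_library("gpuocelot")
    assert path is not None, "libgpuocelot not found; re-run `sudo ninja install`"
    lib = ctypes.CDLL(path)
    print(f"gpuocelot loaded from {path}")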

File: test/test_dtype.py

@@ -98,6 +98,7 @@ class TestInt8Dtype(unittest.TestCase):
   def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
   def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
+  @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
   def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
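
These two cast tests pin down wraparound semantics: an int8-to-uint8 cast keeps the bit pattern and only changes its interpretation, which the skip message notes works differently (saturating) on the CUDA backend. A minimal illustration of the expected values, using numpy rather than tinygrad:

    import numpy as np
    a = np.array([-1, -2, -3, -4], dtype=np.int8)
    print(a.astype(np.uint8))  # [255 254 253 252]: -1 is 0xFF, read back as 255
    print(np.array([255, 254, 253, 252], dtype=np.uint8).astype(np.int8))  # [-1 -2 -3 -4]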

File: tinygrad/runtime/ops_cuda.py

@@ -1,18 +1,45 @@
 import subprocess
 from typing import Optional
 import time
 import numpy as np
-import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
-import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
 from tinygrad.helpers import DEBUG, getenv, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
 
-class RawCUDABuffer(RawBufferCopyInOut):
-  def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
-  def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream)
-  def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf)
+if getenv("CUDACPU", 0) == 1:
+  import ctypes, ctypes.util
+  lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
+  lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
+  class cuda:
+    class module:
+      def __init__(self, src): self.src = src
+      def get_function(self, _): return self
+      def __call__(self, *args, block, grid): lib.ptx_run(self.src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), *block, *grid)
+    module_from_buffer = lambda src: cuda.module(src) # pylint: disable=unnecessary-lambda # noqa: E731
+    class Event:
+      def __init__(self): pass
+      def record(self): self.start = time.perf_counter()
+      def time_till(self, other): return self.start - other.start
+      def synchronize(self): pass
+    class Context:
+      synchronize = lambda:0 # noqa: E731
+    CompileError = Exception
+  class context:
+    class device:
+      compute_capability = lambda: (3,5) # pylint: disable=unnecessary-lambda # noqa: E731
+    get_device = lambda: context.device # pylint: disable=unnecessary-lambda # noqa: E731
+  import pycuda.driver # type: ignore
+  pycuda.driver.Context = context
+  RawCUDABuffer = RawMallocBuffer
+else:
+  import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
+  import pycuda.driver as cuda # type: ignore
+  class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
+    def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
+    def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream) # type: ignore
+    def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
 class CUDAProgram:
   def __init__(self, name:str, prg:str, binary=False):
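
The stub mimics exactly the slice of pycuda's API this file touches, so the rest of the driver code runs unchanged on the CPU. A sketch of the resulting call path under CUDACPU=1, where ptx_src and raw_bufs are hypothetical placeholders:

    mod = cuda.module_from_buffer(ptx_src)      # the stub just stores the PTX bytes
    fn = mod.get_function("any_name")           # name is ignored; the module is its own function
    fn(*raw_bufs, block=(1,1,1), grid=(1,1,1))  # __call__ forwards to gpuocelot's ptx_run
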
@@ -22,7 +49,7 @@ class CUDAProgram:
           f.write(cuda_compile(prg, target="cubin", no_extern_c=True))
         sass = subprocess.check_output(['nvdisasm', '/tmp/cubin']).decode('utf-8')
         print(sass)
-      if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
+      if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
     except cuda.CompileError as e:
       if DEBUG >= 3: print("FAILED TO BUILD", prg)
       raise e
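
A plausible reading of the new nvcc flag (an assumption, not stated in the commit): pycuda derives the target architecture from the current context's device, the CUDACPU stub reports compute capability (3,5), and recent CUDA toolkits mark sm_35 as deprecated, so every compile would otherwise print a warning:

    import pycuda.driver
    # under CUDACPU=1, Context has been patched to the stub's context class above
    arch = "sm_%d%d" % pycuda.driver.Context.get_device().compute_capability()
    print(arch)  # sm_35, the target that -Wno-deprecated-gpu-targets quiets
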
@@ -42,7 +69,7 @@ class CUDAProgram:
 class CUDACodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    kernel_prefix = "typedef unsigned char uchar;\ntypedef unsigned int uint;\ntypedef unsigned long ulong;\n__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
     gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
     lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
     half_prekernel = """
@@ -51,8 +78,6 @@ class CUDACodegen(CStyleCodegen):
       half2 x, y;
       __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
     };
-    typedef unsigned char uchar;
-    typedef long long int64;
     """)
   supports_float4_alu = False
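
The codegen change moves the uchar typedef out of half_prekernel into kernel_prefix, adding uint and ulong (and dropping the old int64), so every generated kernel can use those shorthands, not just kernels touching half. A hypothetical rendering of an emitted kernel's opening, with the kernel name and body invented for illustration:

    from tinygrad.runtime.ops_cuda import CUDACodegen  # assumed import path
    src = CUDACodegen.lang.kernel_prefix + """ void E_4(uchar* data0, const uchar* data1) {
      int gidx0 = blockIdx.x;  /* gid[0] is blockIdx.x since chr(120) == 'x' */
      data0[gidx0] = data1[gidx0] + 1;
    }"""
    print(src)  # the typedefs precede __global__, so uchar resolves in any kernel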