add cuda on cpu tests (#1020)

cloud11665 2023-06-22 23:15:50 +02:00 committed by GitHub
parent e09219df0f
commit 2407690d82
3 changed files with 76 additions and 11 deletions

View File

@@ -226,3 +226,42 @@ jobs:
       run: docker build -t tinygrad -f test/Dockerfile .
     - name: Test Docker
       run: docker run --rm tinygrad /usr/bin/env python3 -c "from tinygrad.tensor import Tensor; print(Tensor.eye(3).numpy())"
+  testcuda:
+    name: (emulated) cuda test
+    runs-on: ubuntu-22.04
+    steps:
+    - name: Checkout Code
+      uses: actions/checkout@v3
+    - name: Update packages
+      run: |
+        export DEBIAN_FRONTEND=noninteractive
+        sudo apt-get update -y
+    - name: Install packages
+      run: |
+        sudo apt-get install -y git g++ cmake ninja-build llvm-15-dev libz-dev libglew-dev flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc
+    - name: Clone gpuocelot repo
+      uses: actions/checkout@v3
+      with:
+        repository: gpuocelot/gpuocelot
+        ref: 19626fc00b6ee321638c3111074269c69050e091
+        path: ${{ github.workspace }}/gpuocelot
+        submodules: true
+    - name: Compile gpuocelot
+      run: |
+        cd ${{ github.workspace }}/gpuocelot/ocelot
+        mkdir build
+        cd build
+        cmake .. -Wno-dev -G Ninja -DOCELOT_BUILD_TOOLS=OFF
+        ninja
+        sudo ninja install
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v4
+      with:
+        python-version: 3.8
+        cache: 'pip'
+        cache-dependency-path: setup.py
+    - name: Install tinygrad dependencies
+      run: pip install -e '.[testing, cuda]' --extra-index-url https://download.pytorch.org/whl/cpu
+    - name: Run pytest
+      run: FORWARD_ONLY=1 JIT=1 OPT=2 CUDA=1 CUDACPU=1 python -m pytest -s -v -n=auto test --ignore=test/external --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_specific_conv.py --ignore=test/test_net_speed.py --ignore=test/test_nn.py -k "not half"
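
The pytest step above assumes `sudo ninja install` left `libgpuocelot` somewhere the dynamic linker can find it, since the CUDACPU runtime loads it by name. A minimal local sanity check (a sketch, not part of this commit; it only assumes the library ended up on the default search path):

```python
import ctypes, ctypes.util

# locate gpuocelot the same way ops_cuda.py does when CUDACPU=1
libname = ctypes.util.find_library("gpuocelot")
assert libname is not None, "libgpuocelot not found; check the install prefix / ldconfig"

# loading the library and touching ptx_run confirms the build exports the entry point
lib = ctypes.CDLL(libname)
assert hasattr(lib, "ptx_run"), "gpuocelot build is missing the ptx_run entry point"
print(f"gpuocelot OK: {libname}")
```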

View File

@@ -98,6 +98,7 @@ class TestInt8Dtype(unittest.TestCase):
   def test_int8_mul_upcast_int64(self): _test_mul_upcast(Tensor([1,2,3,4], dtype=dtypes.int8), Tensor([1,2,3,4], dtype=dtypes.int64), dtypes.int64, [1,4,9,16])
   def test_int8_matmul_upcast_int64(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.int64), dtypes.int64, [[1,2],[3,4]])
+  @unittest.skipIf(getenv("CUDA",0)==1, "cuda saturation works differently")
   def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
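
The new `skipIf` guards `test_int8_to_uint8_negative` when CUDA=1: the test asserts two's-complement wraparound on the int8→uint8 cast, and per the skip reason the CUDA path saturates instead. An illustrative numpy sketch of the two behaviours (not how either backend implements the cast):

```python
import numpy as np

x = np.array([-1, -2, -3, -4], dtype=np.int8)

# wraparound (what the test asserts): reinterpret the two's-complement bits as unsigned
wrap = x.astype(np.uint8)                                        # [255 254 253 252]

# saturating cast (illustration only): clamp to the uint8 range before converting
saturate = np.clip(x.astype(np.int16), 0, 255).astype(np.uint8)  # [0 0 0 0]

print(wrap, saturate)
```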

View File

@@ -1,18 +1,45 @@
 import subprocess
 from typing import Optional
+import time
 import numpy as np
-import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
-import pycuda.driver as cuda # type: ignore
 from pycuda.compiler import compile as cuda_compile # type: ignore
 from tinygrad.helpers import DEBUG, getenv, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
 from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
-class RawCUDABuffer(RawBufferCopyInOut):
-  def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize))
-  def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream)
-  def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf)
+if getenv("CUDACPU", 0) == 1:
+  import ctypes, ctypes.util
+  lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
+  lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
+  class cuda:
+    class module:
+      def __init__(self, src): self.src = src
+      def get_function(self, _): return self
+      def __call__(self, *args, block, grid): lib.ptx_run(self.src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), *block, *grid)
+    module_from_buffer = lambda src: cuda.module(src) # pylint: disable=unnecessary-lambda # noqa: E731
+    class Event:
+      def __init__(self): pass
+      def record(self): self.start = time.perf_counter()
+      def time_till(self, other): return self.start - other.start
+      def synchronize(self): pass
+    class Context:
+      synchronize = lambda:0 # noqa: E731
+    CompileError = Exception
+  class context:
+    class device:
+      compute_capability = lambda: (3,5) # pylint: disable=unnecessary-lambda # noqa: E731
+    get_device = lambda: context.device # pylint: disable=unnecessary-lambda # noqa: E731
+  import pycuda.driver # type: ignore
+  pycuda.driver.Context = context
+  RawCUDABuffer = RawMallocBuffer
+else:
+  import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
+  import pycuda.driver as cuda # type: ignore
+  class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
+    def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
+    def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x, stream) # type: ignore
+    def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
 class CUDAProgram:
   def __init__(self, name:str, prg:str, binary=False):
@@ -22,7 +49,7 @@ class CUDAProgram:
           f.write(cuda_compile(prg, target="cubin", no_extern_c=True))
         sass = subprocess.check_output(['nvdisasm', '/tmp/cubin']).decode('utf-8')
         print(sass)
-      if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True).decode('utf-8')
+      if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
     except cuda.CompileError as e:
       if DEBUG >= 3: print("FAILED TO BUILD", prg)
       raise e
@@ -42,7 +69,7 @@ class CUDAProgram:
 class CUDACodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    kernel_prefix = "typedef unsigned char uchar;\ntypedef unsigned int uint;\ntypedef unsigned long ulong;\n__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
     gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
     lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
     half_prekernel = """
@@ -51,8 +78,6 @@ class CUDACodegen(CStyleCodegen):
       half2 x, y;
      __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
     };
-    typedef unsigned char uchar;
-    typedef long long int64;
     """)
   supports_float4_alu = False
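
For reference, the CUDACPU branch above reduces a kernel launch to a single call into gpuocelot: the generated PTX plus raw buffer pointers are handed to `ptx_run`, which interprets the kernel on the CPU. A condensed sketch of that call path (the `ptx_src` and `buffers` arguments are placeholders; in the runtime, `CUDAProgram` builds the module via `cuda.module_from_buffer` and supplies the block/grid dimensions):

```python
import ctypes, ctypes.util

# load the gpuocelot PTX emulator, exactly as the CUDACPU branch does
lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p),
                        ctypes.c_int, ctypes.c_int, ctypes.c_int,   # block x, y, z
                        ctypes.c_int, ctypes.c_int, ctypes.c_int]   # grid x, y, z

def launch_ptx(ptx_src: bytes, buffers, block=(1, 1, 1), grid=(1, 1, 1)):
  """Mirror of cuda.module.__call__: pack the buffer pointers and run the PTX on the CPU."""
  args = (ctypes.c_void_p * len(buffers))(*[ctypes.cast(b, ctypes.c_void_p) for b in buffers])
  lib.ptx_run(ptx_src, len(buffers), args, *block, *grid)
```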