add openpilot tests to tinygrad

This commit is contained in:
George Hotz 2022-08-21 12:03:37 -07:00
parent 7f15779942
commit a8734df030
6 changed files with 32 additions and 32 deletions

View File

@ -103,7 +103,7 @@ jobs:
run: OPT=1 GPU=1 python -m pytest -s -v
testopencl:
name: OpenCL Tests
name: OpenCL (openpilot) Test
runs-on: ubuntu-20.04
if: ${{ false }}
@ -125,7 +125,10 @@ jobs:
- name: Install Dependencies
run: pip install -e '.[gpu,testing]'
- name: Run Pytest (default)
run: OPENCL=1 python -m pytest -s -v
run: |
UNSAFE_FLOAT4=1 DEBUGCL=1 python3 openpilot/compile.py
FLOAT16=1 UNSAFE_FLOAT4=1 DEBUGCL=1 python3 openpilot/compile.py
python3 openpilot/run_thneed.py /tmp/output.thneed
testmypy:
name: Mypy Tests

View File

@ -1,9 +1,9 @@
//PREFIX
__kernel void image_conv(
write_only image2d_t output,
read_only image2d_t input,
read_only image2d_t weights,
write_only image2d_t output
read_only image2d_t weights
#ifndef NOARGS
,short numPackedInputChannelsForGroup,
short totalNumPackedInputChannels,

View File

@ -1,10 +1,10 @@
//PREFIX
__kernel void matmul(
write_only image2d_t output,
__local float *outputScratch,
read_only image1d_t input,
read_only image2d_t weights,
write_only image2d_t output
read_only image2d_t weights
//ARGS
) {

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import os
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
from tinygrad.ops import ProcessingOps
from tinygrad.ops import ProcessingOps, ReduceOps
from tinygrad.helpers import prod, ConvArgs
from typing import List, Tuple, Optional, Dict, Set
import numpy as np
@ -80,8 +80,8 @@ class OpenCLBuffer(GPUBuffer):
#print(f"converting {self.shape} back to buffer, image shape is {self._image.shape}")
CLProgram("from_image", """
__kernel void from_image(
read_only image2d_t in,
__global float4 *out) {
__global float4 *out,
read_only image2d_t in) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
@ -89,7 +89,7 @@ class OpenCLBuffer(GPUBuffer):
int W = get_image_width(in);
out[l.y*W + l.x] = read_imagef(in, smp, l);
}
""")(self._image.shape, None, self._image, self._buf.cl)
""")(self._image.shape, None, self._buf.cl, self._image)
self._image = None
return self._buf.cl
@ -105,15 +105,15 @@ class OpenCLBuffer(GPUBuffer):
#print(f"converting {self.shape} to image with shape {self._image.shape}")
CLProgram("to_image", """
__kernel void to_image(
__global const float4 *in,
write_only image2d_t out) {
write_only image2d_t out,
__global const float4 *in) {
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
int W = get_image_width(out);
write_imagef(out, l, in[l.y*W + l.x]);
}
""")(self._image.shape, None, self._buf.cl, self._image)
""")(self._image.shape, None, self._image, self._buf.cl)
self._buf = None
return self._image
@ -123,12 +123,11 @@ class OpenCLBuffer(GPUBuffer):
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
seen = set()
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc"):
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc"):
if C is None or earlycode != "acc":
# TODO: handle an opencl conv without the conv part
return super()._processing_op(bufs, code, C, start, reduce_shape, earlybufs, earlycode)
return super()._processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode)
assert earlycode == "acc"
assert start == "0.0"
x = [x for x in bufs if x[0] == "input"][0][1]
w = [x for x in bufs if x[0] == "weight"][0][1]
@ -228,7 +227,7 @@ class OpenCLBuffer(GPUBuffer):
local_work_size = [4, global_work_size[1], lw]
#print(global_work_size, local_work_size)
conv_prg(global_work_size, local_work_size, cl.LocalMemory(4 * local_work_size[0] * local_work_size[1] * lw), x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
conv_prg(global_work_size, local_work_size, ret.image, cl.LocalMemory(4 * local_work_size[0] * local_work_size[1] * lw), x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
return ret
# this option is unused
@ -259,7 +258,7 @@ class OpenCLBuffer(GPUBuffer):
argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs))
)
global_work_size = [C.cout//4, (C.ox+NUM_OUTPUTS-1)//NUM_OUTPUTS, C.bs*C.oy]
conv_prg(global_work_size, None, x.image, w.image, ret.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
conv_prg(global_work_size, None, ret.image, x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
return ret
GPUBuffer = OpenCLBuffer

View File

@ -178,18 +178,13 @@ def compile(input, output_fn):
saved_binaries = set()
kernels_to_save = set()
kernels_to_not_save = set()
kernels_to_not_save = set(inputs)
import pyopencl as cl
for self, args in local_cl_cache:
for i,a in enumerate(args[2:]):
access_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.ACCESS_QUALIFIER)
type_qualifer = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_QUALIFIER)
type_name = self.clprg.get_arg_info(i, cl.kernel_arg_info.TYPE_NAME)
if cl.kernel_arg_access_qualifier.READ_ONLY == access_qualifer or cl.kernel_arg_type_qualifier.CONST == type_qualifer:
kernels_to_save.add(a)
else:
# this is written to at some point, we don't have to save it
kernels_to_not_save.add(a)
# output is always the first parameter
kernels_to_not_save.add(args[2])
for a in args[3:]:
kernels_to_save.add(a)
kernels_to_save -= kernels_to_not_save
gobj = 0
@ -259,7 +254,7 @@ def compile(input, output_fn):
})
if needs_load:
data = np.empty(size//2, dtype=np.float32)
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
CL.enqueue_copy(data, buf.cl, is_blocking=True)
if FLOAT16: data = data.astype(np.float16)
weights.append(data.tobytes())

View File

@ -11,8 +11,8 @@ THNEED_KERNELS = "../../selfdrive/modeld/thneed/kernels/"
def load_thneed_model(fn="model.thneed", float32=False, replace=None):
import pyopencl as cl
platform = [x for x in cl.get_platforms()]
assert len(platform) == 1
ctx = cl.Context(devices=platform[0].get_devices(device_type=cl.device_type.GPU))
assert len(platform) >= 1
ctx = cl.Context(devices=platform[0].get_devices(device_type=cl.device_type.GPU)[0:1])
q = cl.CommandQueue(ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)
mf = cl.mem_flags
image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
@ -110,7 +110,10 @@ def load_thneed_model(fn="model.thneed", float32=False, replace=None):
k['args_name'] = []
prg = prgs[k['name']]
for i,arg in enumerate(k['args']):
k['args_name'].append(prg.get_arg_info(i, cl.kernel_arg_info.NAME))
try:
k['args_name'].append(prg.get_arg_info(i, cl.kernel_arg_info.NAME))
except cl.RuntimeError:
k['args_name'].append("<UNKNOWN>")
vision = vision[0:1]
vnum = vnum[0] if len(vnum) >= 1 else None