mirror of https://github.com/commaai/tinygrad.git
delete opencl <celebration>
commit 6d7658db12
parent e313c8af20
@@ -174,9 +174,7 @@ PROPROTIP: Set "DEBUG=1" environment variable if you want to see why it's slow.

You might need to download the [weight](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt) of Stable Diffusion and put it into weights/

Run `TORCH=1 python3 examples/stable_diffusion.py`

(or without torch: `OPT=2 OPENCL=1 python3 examples/stable_diffusion.py`)
Run `GPU=1 python3 examples/stable_diffusion.py`

<p align="center">
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/stable_diffusion_by_tinygrad.jpg">
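As an aside on the README hunk above, the `GPU=1` prefix is just an environment variable read before tinygrad picks a device; a rough Python equivalent is sketched below. This snippet is illustrative only and not part of the commit, and the Tensor calls assume the Tensor API of this era (dot() and numpy()).

# hedged sketch, not part of this diff: select the GPU (OpenCL) backend from Python
import os
os.environ["GPU"] = "1"               # same effect as the GPU=1 shell prefix above
from tinygrad.tensor import Tensor    # import after the env var is set
a = Tensor([[1.0, 2.0], [3.0, 4.0]])
b = Tensor([[5.0, 6.0], [7.0, 8.0]])
print(a.dot(b).numpy())               # assumed API: dot() and numpy()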
accel/opencl/conv.cl (deleted file)
@@ -1,154 +0,0 @@
//PREFIX

__kernel void image_conv(
  write_only image2d_t output,
  read_only image2d_t input,
  read_only image2d_t weights
#ifndef NOARGS
  ,short numPackedInputChannelsForGroup,
  short totalNumPackedInputChannels,
  short numPackedOutputChannelsForGroup,
  short totalNumPackedOutputChannels,
  short numOutputColumns,
  short numOutputRows, short numInputRows
#endif
  /*short filterSizeX, short filterSizeY,
  short paddingX, short paddingY,
  short strideX, short strideY,
  short dilationX, short dilationY*/
  //ARGS
) {

  //SHORTS

  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  float4 outputValues[NUM_OUTPUTS];
  for (short i = 0; i < NUM_OUTPUTS; ++i) {
    outputValues[i] = (float4)(0, 0, 0, 0);
  }

  short packedOutputChannel = get_global_id(0);
  int2 weightLocation;
  weightLocation.x = 0;
  weightLocation.y = packedOutputChannel;

  short groupNum = (packedOutputChannel / numPackedOutputChannelsForGroup);
  short startPackedInputChannel = mul24(groupNum, numPackedInputChannelsForGroup);
  short startOutputColumn = mul24((short)get_global_id(1), NUM_OUTPUTS);
  short startX = mad24(mad24(startOutputColumn, strideX, -paddingX), totalNumPackedInputChannels, startPackedInputChannel);
  short strideWithChannels = mul24(strideX, totalNumPackedInputChannels);

  int outputRow = get_global_id(2);
  int2 inputLocation;

#ifdef BATCH
  // TODO: this doesn't work with y padding
  inputLocation.y = mad24(outputRow % numOutputRows, strideY, -paddingY);
  int batchOffset = (outputRow / numOutputRows) * numInputRows;
  inputLocation.y += batchOffset;
#else
  inputLocation.y = mad24(outputRow, strideY, -paddingY);
#endif

#ifdef DEPTHWISE_UNSTRIDED
  for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
    float4 inputValues[4];
    inputLocation.x = startX;
    for (short i = 1; i < 4; ++i) {
      inputValues[i] = read_imagef(input, smp, INPUT_LOCATION);
      inputLocation.x += totalNumPackedOutputChannels;
    }
    for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
      inputValues[0] = inputValues[1];
      inputValues[1] = inputValues[2];
      inputValues[2] = inputValues[3];
      inputValues[3] = read_imagef(input, smp, INPUT_LOCATION);
      inputLocation.x += totalNumPackedInputChannels;
      float4 weightValues = read_imagef(weights, smp, WEIGHT_LOCATION);
      ++weightLocation.x;
      outputValues[0] += inputValues[0] * weightValues;
      outputValues[1] += inputValues[1] * weightValues;
      outputValues[2] += inputValues[2] * weightValues;
      outputValues[3] += inputValues[3] * weightValues;
    }
    ++inputLocation.y;
  }
#else

  for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
    // numPackedInputChannelsForGroup is 1 in depthwise
    for (short packedInputChannel = 0; packedInputChannel < numPackedInputChannelsForGroup; ++packedInputChannel) {
      short startXForChannel = startX + packedInputChannel;
      for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {

        short dilatedStepX = mul24(totalNumPackedInputChannels, dilationX);
        inputLocation.x = mad24(rfColumn, dilatedStepX, startXForChannel);
        float4 inputValues[NUM_OUTPUTS];
        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          inputValues[i] = read_imagef(input, smp, INPUT_LOCATION);
          inputLocation.x += strideWithChannels;
        }

#ifdef DEPTHWISE
        float4 weightValues = read_imagef(weights, smp, WEIGHT_LOCATION);
        ++weightLocation.x;
        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          outputValues[i] += inputValues[i] * weightValues;
        }
#else
        float4 weightValues[4];
        for (short outChIdx = 0; outChIdx < 4; ++outChIdx) {
          weightValues[outChIdx] = read_imagef(weights, smp, WEIGHT_LOCATION);
          ++weightLocation.x;
        }

        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          // this is marginally faster than using dot
          float4 curOutputValues = outputValues[i];
          curOutputValues.x += inputValues[i].x * weightValues[0].x;
          curOutputValues.x += inputValues[i].y * weightValues[0].y;
          curOutputValues.x += inputValues[i].z * weightValues[0].z;
          curOutputValues.x += inputValues[i].w * weightValues[0].w;
          curOutputValues.y += inputValues[i].x * weightValues[1].x;
          curOutputValues.y += inputValues[i].y * weightValues[1].y;
          curOutputValues.y += inputValues[i].z * weightValues[1].z;
          curOutputValues.y += inputValues[i].w * weightValues[1].w;
          curOutputValues.z += inputValues[i].x * weightValues[2].x;
          curOutputValues.z += inputValues[i].y * weightValues[2].y;
          curOutputValues.z += inputValues[i].z * weightValues[2].z;
          curOutputValues.z += inputValues[i].w * weightValues[2].w;
          curOutputValues.w += inputValues[i].x * weightValues[3].x;
          curOutputValues.w += inputValues[i].y * weightValues[3].y;
          curOutputValues.w += inputValues[i].z * weightValues[3].z;
          curOutputValues.w += inputValues[i].w * weightValues[3].w;
          outputValues[i] = curOutputValues;
        }
#endif
      }
    }
    inputLocation.y += dilationY;
  }
#endif

  int2 outputLocation;
  outputLocation.y = outputRow;

  // do binops
  short outputColumn = startOutputColumn;
  for (short i = 0; i < NUM_OUTPUTS; ++i) {
    outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
    //BINOP
    ++outputColumn;
  }

  // output to memory
  outputColumn = startOutputColumn;
  for (short i = 0; i < NUM_OUTPUTS; ++i) {
    outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
    if (outputColumn < numOutputColumns) {
      write_imagef(output, OUTPUT_LOCATION, outputValues[i]);
    }
    ++outputColumn;
  }
}
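A note on the image_conv kernel above: output pixels are addressed with 4-wide channel groups packed along x, so a pixel's x coordinate interleaves output columns with packed channel groups. A small Python sketch of that index math (hypothetical names, mirroring mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel)):

def output_x(output_column, packed_output_channel, total_packed_output_channels):
  # same arithmetic as mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel)
  return output_column * total_packed_output_channels + packed_output_channel

assert output_x(0, 0, 8) == 0
assert output_x(1, 3, 8) == 11  # column 1, channel group 3, with 8 packed channel groups per row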
accel/opencl/matmul.cl (deleted file)
@@ -1,49 +0,0 @@
//PREFIX

__kernel void matmul(
  write_only image2d_t output,
  __local float *outputScratch,
  read_only image2d_t input,
  read_only image2d_t weights
  //ARGS
) {

  //SHORTS

  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
  short packedOutputChannel = get_global_id(2);
  short scratchOffset = mad24((short)get_local_id(1), 4, (short)get_local_id(0));
  short weightIndex = (short)get_global_id(0);

  // fast path precompute (32x speedup)
  float outputValue = 0.0f;
  for (short inputSet = (short)get_global_id(1); inputSet < numPackedInputChannelsForGroup; inputSet += get_global_size(1)) {
    int2 inputLocation = (int2)(inputSet, 0);
    float4 inputValues = read_imagef(input, smp, INPUT_LOCATION);
    int2 weightLocation = (int2)(mad24(inputSet, 4, weightIndex), packedOutputChannel);
    float4 weightValues = read_imagef(weights, smp, WEIGHT_LOCATION);
    outputValue += dot(inputValues, weightValues);
  }

  short scratchIndex = mad24((short)get_local_id(2), mul24((short)get_local_size(1), 4), scratchOffset);
  outputScratch[scratchIndex] = outputValue;

  barrier(CLK_LOCAL_MEM_FENCE);

  if (scratchOffset == 0) {
    float4 outputValues = (float4)(0, 0, 0, 0);

    // fast path
    for (short i = 0; i < (short)get_global_size(1); ++i) {
      outputValues += vload4(0, &outputScratch[scratchIndex]);
      scratchIndex += 4;
    }

    // insert unary and binary ops here
    int2 outputLocation = (int2)(packedOutputChannel, 0);
    //BINOP

    // output to memory
    write_imagef(output, OUTPUT_LOCATION, outputValues);
  }
}
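The matmul kernel above splits each output's dot product across get_global_size(1) work-items, parks the partial sums in local scratch, and lets the work-item with scratchOffset == 0 finish the reduction after the barrier. A hedged numpy sketch of that two-stage scheme (G stands in for the number of cooperating work-items):

import numpy as np

def staged_dot(a, b, G=4):
  # stage 1: each of G "work-items" accumulates a strided slice of the dot product
  partials = [np.dot(a[g::G], b[g::G]) for g in range(G)]
  # stage 2: one "work-item" sums the partials (the scratchOffset == 0 path above)
  return sum(partials)

x, w = np.arange(16, dtype=np.float32), np.ones(16, dtype=np.float32)
assert staged_dot(x, w) == np.dot(x, w)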
accel/opencl/ops_opencl.py (deleted file)
@@ -1,431 +0,0 @@
# type: ignore

from __future__ import annotations
import os
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps, MovementOps, get_buffers, get_lazyops, get_lazyop_info, LazyOp, Op
from tinygrad.helpers import prod, ConvArgs, dedup
from typing import List, Tuple, Optional, Dict, Set, Union
import numpy as np
import pyopencl as cl

UNSAFE_FLOAT4 = int(os.getenv("UNSAFE_FLOAT4", 0))
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0))  # this is needed as a switch for the tests to pass
FLOAT16 = int(os.getenv("FLOAT16", 0))

import pathlib
def load(x):
  with open(x) as f:
    ret = f.read()
  return ret
CONV_SRC = load(pathlib.Path(__file__).resolve().parent.parent.parent / 'accel/opencl/conv.cl')
MATMUL_SRC = load(pathlib.Path(__file__).resolve().parent.parent.parent / 'accel/opencl/matmul.cl')

class CLImage:
  fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)

  def __init__(self, shape):
    self.max_hw = min(CL().cl_ctx.devices[0].image2d_max_width, CL.cl_ctx.devices[0].image2d_max_height)
    self.shape = shape
    self.n_tile = int(np.ceil(max(shape) / self.max_hw).item())
    # if n_tile > 1, we can't fit the image into a CL image at native size,
    # and need to internally store it as a set of disjoint tiles
    if self.n_tile * min(shape) > self.max_hw:
      raise Exception(f"shape {shape} exceeds OpenCL image limits, even after tiling")
    if shape[0] >= shape[1]:
      # wider than it is tall; extra tiles overflow on y
      self.tile_axis, tiled_width, tiled_height = 1, min(shape[0], self.max_hw), self.n_tile * shape[1]
    else:
      # taller than it is wide; extra tiles overflow on x
      self.tile_axis, tiled_width, tiled_height = 0, self.n_tile * shape[0], min(shape[1], self.max_hw)
    self.cl = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(tiled_width, tiled_height))
    CL.mem_used += self.cl.row_pitch * self.cl.height

  def pos_to_sample_pos(self, l="l", check_bounds=True):
    if self.n_tile == 1:
      # happy path where no indexing ops are needed
      return l
    # sad tiled path; need to adjust indices, and manually check bounds for the tiled axis
    if self.tile_axis == 1:
      sample_pos = f"((int2)({l}.x % {self.max_hw}, ({l}.x / {self.max_hw}) * {self.shape[1]} + {l}.y))"
      in_bounds = f"((0 <= {l}.y) && ({l}.y < {self.shape[1]}))"
    else:
      sample_pos = f"((int2)(({l}.y / {self.max_hw}) * {self.shape[0]} + {l}.x, {l}.y % {self.max_hw}))"
      in_bounds = f"((0 <= {l}.x) && ({l}.x < {self.shape[0]}))"
    if check_bounds:
      return f"({in_bounds} ? {sample_pos} : (int2)(-1, -1))"
    return sample_pos

  def __del__(self):
    if hasattr(self, "cl"):
      CL.mem_used -= self.cl.row_pitch * self.cl.height
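# --- editor's note, not part of the deleted file: the tile_axis == 1 remap in
# pos_to_sample_pos above, written out as a standalone sketch with hypothetical names.
# Columns past max_hw wrap into the next tile, which is stacked below the previous one,
# so the y coordinate is offset by the logical image height:
#   def tiled_pos(x, y, max_hw, logical_h):
#     return (x % max_hw, (x // max_hw) * logical_h + y)
#   assert tiled_pos(5, 2, max_hw=16384, logical_h=3) == (5, 2)        # fits in tile 0
#   assert tiled_pos(20000, 2, max_hw=16384, logical_h=3) == (3616, 5) # wraps into tile 1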
def get_replacements(prg_src:str, opencl_type:List[str]) -> Dict[str, str]:
  middle_code = []

  """
  vv = "xyzw"
  for i in range(4):
    acc = f"outputValues[i].{vv[i%4]}"
    args = [x.split(" ")[-1].replace("*", "") for x in opencl_type]
    args = [f"(outputRow * get_image_width(output) + outputLocation.x)*4+{i}", acc]+args
    middle_code.append(f"{acc} = _ewop("+', '.join(args)+");\n")
  """
  acc = "outputValues[i]"
  args = [x.split(" ")[-1].replace("*", "") for x in opencl_type]
  args = ["smp", "outputLocation", "(outputLocation.y * get_image_width(output) + outputLocation.x)*4", acc]+args
  middle_code.append(f"{acc} = _ewop("+', '.join(args)+");\n")

  replacements = {}
  replacements["//PREFIX"] = prg_src
  replacements["//BINOP"] = ''.join(middle_code)
  if len(opencl_type) != 0:
    replacements["//ARGS"] = ","+','.join(opencl_type)
  return replacements

def get_getters(ewbufs, ret):
  fakebufs = []
  ewtypes = []
  getters = []
  for name, buf in ewbufs:
    view, unfolded, _ = buf.contiguous_view_constant_fold(name)
    if not unfolded:
      getters.append(view)
      fakebufs.append(name)
      getters.append(f"inline float4 get4_{name}(int gid) {{"+
        f"return (float4)(get_{name}(gid+0), get_{name}(gid+1), get_{name}(gid+2), get_{name}(gid+3)); }}")
    elif buf.is_image() and buf.shape == ret.shape and buf.st.contiguous:
      # use an image here
      ewtypes.append(f"read_only image2d_t {name}_g")
      getters.append(f"inline float4 get4_{name}(read_only image2d_t x, const sampler_t smp, int2 loc, int gid) {{ return read_imagef(x, smp, {buf._image.pos_to_sample_pos('loc')}); }}")
    elif buf.st.contiguous:
      # use float4
      ewtypes.append(f"__global const float4 *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
    elif UNSAFE_FLOAT4:
      # aggressive constant folding
      fakebufs.append(name)
      prt = buf._backing.reshape((-1, 4))
      cc = []
      for ii in range(prt.shape[0]):
        cc.append("(float4)(%ff, %ff, %ff, %ff)" % (prt[ii][0], prt[ii][1], prt[ii][2], prt[ii][3]))
      getters.append(f"const __constant float4 const_{name}[] = {{"+', '.join(cc)+"};")
      getters.append(f"inline float4 get4_{name}(int gid) {{"+
        "int idx = gid;"+buf.st.expr()+";"+
        f"return const_{name}[idx/4]; }}")
      """
      # use float4 indexed (HACK!)
      # TODO: work out when this is okay
      ewtypes.append(f"__global const float4 *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{"+
        "int valid = 1; int idx = gid;"+buf.st.expr()+";"+
        f"return x[idx/4]; }}")
      """
    else:
      # fallback to float
      getters.append(view)
      ewtypes.append(f"__global const float *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float *x, const sampler_t smp, int2 loc, int gid) {{"+
        f"return (float4)(get_{name}(x,gid+0), get_{name}(x,gid+1), get_{name}(x,gid+2), get_{name}(x,gid+3)); }}")
  return fakebufs, ewtypes, getters

def roundup(x, n=4): return (x+(n-1))//n * n
class OpenCLBuffer(GPUBuffer):
  code_for_op = {
    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
    BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
    ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)", MovementOps.RESHAPE: "(A)"
  }
  start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
  def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
    self._image = hostbuf._image if hostbuf is not None else None
    self.copied_backing = False
    super().__init__(shape, hostbuf, backing)
    assert not (self._image and self._buf)

  @staticmethod
  def fromCPU(x): return OpenCLBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())

  def __repr__(self): return f"<OpenCLBuffer with shape {self.shape!r}>"

  @property
  def cl(self):
    if self._buf is None:
      if self._backing is not None and not self.copied_backing:
        self._buf = CLBuffer(4*roundup(prod(self._backing.shape)))
        CL.enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
        self.copied_backing = True
      elif self.st.contiguous:
        self._buf = CLBuffer(4*roundup(prod(self.shape)))

    if self._image is not None:
      self._buf = CLBuffer(4*roundup(prod(self._image.shape)*4))
      if self._backing is not None and not self.copied_backing:
        CL.enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
        self.copied_backing = True
      #print(f"converting {self.shape} back to buffer, image shape is {self._image.shape}")
      CLProgram("from_image", f"""
      __kernel void from_image(
        __global float4 *out,
        read_only image2d_t in) {{
        const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
        int2 l;
        l.y = get_global_id(1);
        l.x = get_global_id(0);
        int2 l_smp = {self._image.pos_to_sample_pos('l')};
        int W = {str(self._image.shape[0])};
        out[l.y*W + l.x] = read_imagef(in, smp, l_smp);
      }}
      """)(self._image.shape, None, self._buf.cl, self._image.cl)
      self._image = None
    return self._buf.cl

  def is_image(self): return self._image is not None

  @property
  def image(self):
    if self._image is None:
      assert len(self.shape) == 3 and self.shape[2] == 4, f"bad shape for image {self.shape}"
      assert self.st.contiguous, f"{self} is not contiguous"
      self._image = CLImage(shape=(self.shape[1], self.shape[0]))
      if self._buf is not None:
        assert prod(self.shape) <= prod(self._image.cl.shape)*4
        #print(f"converting {self.shape} to image with shape {self._image.shape}")
        CLProgram("to_image", f"""
        __kernel void to_image(
          write_only image2d_t out,
          __global const float4 *in) {{
          int2 l;
          l.y = get_global_id(1);
          l.x = get_global_id(0);
          int2 l_out = {self._image.pos_to_sample_pos('l', check_bounds=False)};
          int W = {str(self._image.shape[0])};
          write_imagef(out, l_out, in[l.y*W + l.x]);
        }}
        """)(self._image.shape, None, self._image.cl, self._buf.cl)
      self._buf = None
    return self._image.cl

  SUPPORTS_PADDING = True
  def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs):
    assert op == ProcessingOps.CONV, f"{op} isn't supported"
    return type(x)(C.out_shape)._processing_op([("input", x.contiguous()), ("weight", w.contiguous())], "acc", C)

  def contiguous_view_constant_fold(x, name:str, reduce:Optional[int]=None) -> Tuple[str, Optional[str], str]:
    # this will only be for convs, for reduce we have to fall back to cl
    if x.is_image() and reduce is None:
      #print("is image")
      return f"""inline float get_{name}(const sampler_t smp, read_only image2d_t x, int gid) {{
        int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')};
        int2 l;
        int W = {str(x._image.shape[0])};
        l.y = idx / (W*4);
        l.x = (idx/4) % W;
        int idx4 = idx % 4;
        int2 l_smp = {x._image.pos_to_sample_pos('l')};
        float4 dat = read_imagef(x, smp, l_smp);
        return valid ? (idx4 == 0 ? dat.x : (idx4 == 1 ? dat.y : (idx4 == 2 ? dat.z : dat.w))) : 0.0;
      }}""", f"read_only image2d_t {name}_g", f"get_{name}(smp, {name}_g, gid);"
    else:
      idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
      constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
      args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
      return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
        f"__global const float *{name}_g" if constant is None else None, \
        f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"

  @classmethod
  def exec_ast(cls, ast:LazyOp):
    # copied from llvm
    bufs = dedup(get_buffers(ast))
    reduceops = dedup([x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps) or isinstance(x.op, ProcessingOps)])
    assert len(reduceops) <= 1, f"max one reduce op in an ast, {reduceops}"
    earlybufs = dedup(get_buffers(reduceops[0])) if len(reduceops) > 0 else []
    reduce_shape = (earlybufs[0].shape, reduceops[0].arg) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else None
    info = get_lazyop_info(ast)
    ret = cls(info.shape)

    buf_names : Dict[GPUBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(bufs)}

    # special names for input and weight
    if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps):
      buf_names[reduceops[0].src[0]] = "input"
      buf_names[reduceops[0].src[1]] = "weight"

    def _ast(x: Union[GPUBuffer, LazyOp], buf_names: Dict[GPUBuffer, str], code_for_op: Dict[Op, str], allow_reduce=False) -> str:
      if isinstance(x, GPUBuffer):
        return buf_names[x]
      if not allow_reduce and type(x.op) in [ProcessingOps, ReduceOps]:
        return "acc"
      srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
      code = code_for_op[x.op]
      if len(srcs_code) >= 1:
        code = code.replace("A", srcs_code[0])
      if len(srcs_code) >= 2:
        code = code.replace("B", srcs_code[1])
      return code

    earlycode = _ast(reduceops[0], buf_names, cls.code_for_op, allow_reduce=True) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else "acc"
    code = _ast(ast, buf_names, cls.code_for_op)

    C = reduceops[0].arg if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps) else None
    reduce_op = reduceops[0].op if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else ReduceOps.SUM
    return ret._processing_op([(buf_names[x], x) for x in bufs], code, C, reduce_op, reduce_shape, set(buf_names[x] for x in earlybufs), earlycode, info.flops)
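  # --- editor's note, not part of the deleted file: the code_for_op entries used by
  # exec_ast/_ast above are plain strings whose A/B placeholders are replaced by the
  # generated source expressions, e.g. with made-up buffer names:
  #   code = "(A+B)"                                       # BinaryOps.ADD template
  #   code = code.replace("A", "arg_0")                    # first source
  #   code = code.replace("B", "max(arg_1, (float)0.)")    # second source, an expanded RELU
  #   assert code == "(arg_0+max(arg_1, (float)0.))"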
  def _simple_processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0) -> GPUBuffer:
    assert C is None, f"conv isn't handled by GPU anymore {C}"

    # get the input/output shape and the reduce amount
    reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
    red = prod([s for s,n in zip(*reduce_shape) if n == 1])
    assert red < 2**31, f"reduce must be under 2**31, {red} isn't"

    # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
    if red > 1 and prod(ret.shape) != 1:
      assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])

    kernel_name = "reduce" if red > 1 else "elementwise"
    early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
    late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
    views = {**early_views, **late_views}

    buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None]  # type: ignore
    buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None]  # type: ignore

    # use local memory if it's a multistage reduce
    inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
    if inter_red > 1:
      buf_cl.append(cl.LocalMemory(inter_red*4))

    reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
    conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
      const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
      float acc = {GPUBuffer.start_for_op[op]};
      int gid = get_global_id(0);
      {reduce_loop} {{
        {chr(10).join([f' float {name} = ' + early_views[name][2] for name in early_views])}
        acc = {earlycode};
      }}"""+(f"""
      temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
      if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
        for (int rdx = 0; rdx < {inter_red}; rdx++) {{
          acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
        }}""" if inter_red != 1 else "{")+f"""
        {chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
        output[gid] = {code};
      }}
    }}""", op_estimate=op_estimate)

    conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl)
    return ret

  def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0):
    if C is None or earlycode != "acc":
      # TODO: handle an opencl conv without the conv part
      return ret._simple_processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode, op_estimate)
    assert earlycode == "acc"

    x = [x for x in bufs if x[0] == "input"][0][1]
    w = [x for x in bufs if x[0] == "weight"][0][1]
    ewbufs = [x for x in bufs if x[0] not in ["input", "weight"]]

    # remove fakebufs
    fakebufs, ewtypes, getters = get_getters(ewbufs, ret)
    ewbufs = [x for x in ewbufs if x[0] not in fakebufs]

    elementwise_prefix = '\n'.join(getters)+ \
      "\n\ninline float4 _ewop("+','.join(["const sampler_t smp", "int2 loc", "int gid", "float4 acc"]+ewtypes)+") {\n"+ \
      ''.join([f"float4 {name} = get4_{name}(gid);\n" for name in fakebufs])+ \
      ''.join([f"float4 {name} = get4_{name}({name}_g, smp, loc, gid);\n" for name, _ in ewbufs])+ \
      f"return {code}; }}"

    replacements = get_replacements(elementwise_prefix, ewtypes)

    (x.image, w.image, ret.image)
    # fix sampling
    replacements["INPUT_LOCATION"] = x._image.pos_to_sample_pos("inputLocation")
    replacements["WEIGHT_LOCATION"] = w._image.pos_to_sample_pos("weightLocation")
    replacements["OUTPUT_LOCATION"] = ret._image.pos_to_sample_pos("outputLocation", check_bounds=False)
    # fix widths
    replacements["get_image_width(output)"] = f"({ret._image.shape[0]})"

    x, w = x.contiguous(), w.contiguous()
    options = []
    if C.bs > 1:
      options.append("-DBATCH")
      assert C.py == 0, "batched conv doesn't work with y-padding"
    if C.sx == 1 and C.sy == 1 and C.dx == 1 and C.dy == 1 and C.cin == 1:
      options.append("-DDEPTHWISE_UNSTRIDED")
    elif C.cin == 1:
      options.append("-DDEPTHWISE")
    if C.groups == 1 and C.H == 1 and C.W == 1 and C.iy == 1 and C.ix == 1 and C.oy == 1 and C.ox == 1 and C.sx == 1 and C.sy == 1 and C.dx == 1 and C.dy == 1 and C.bs == 1:
      options.append("-DMATMUL")
      # NOTE: this is not actually a matmul, it's a vector * matrix

      conv_args = []
      conv_short_names = ["numPackedInputChannelsForGroup", "totalNumPackedInputChannels", "numPackedOutputChannelsForGroup", "totalNumPackedOutputChannels", "numOutputColumns", "numOutputRows", "numInputRows"]
      conv_shorts = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]

      conv_src = MATMUL_SRC
      replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
      if "//BINOP" in replacements:
        replacements["//BINOP"] = replacements["//BINOP"].replace("outputValues[i]", "outputValues")
      for k,v in replacements.items():
        conv_src = conv_src.replace(k, v)

      #print(conv_src)
      conv_prg = CLProgram("matmul", conv_src,
        options=tuple(options),
        argdtypes=tuple([None, None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs)),
        op_estimate=op_estimate
      )
      global_work_size = [4, 16, C.cout//4]

      # must be even
      lw = CL.cl_ctx.devices[0].max_work_group_size // (global_work_size[0] * global_work_size[1])
      while global_work_size[2] % lw != 0:
        lw -= 1
      local_work_size = [4, global_work_size[1], lw]

      #print(global_work_size, local_work_size)
      conv_prg(global_work_size, local_work_size, ret.image, cl.LocalMemory(4 * local_work_size[0] * local_work_size[1] * lw), x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
      return ret

    # this option is unused
    if C.H == 1 and C.W == 1:
      options.append("-DONLY_1X1_CONV")

    assert C.cout%4 == 0
    conv_src = CONV_SRC
    conv_short_names = ["filterSizeX", "filterSizeY", "paddingX", "paddingY", "strideX", "strideY", "dilationX", "dilationY"]
    conv_shorts = [C.W, C.H, C.px, C.py, C.sx, C.sy, C.dx, C.dy]
    conv_arg_names = ["numPackedInputChannelsForGroup", "totalNumPackedInputChannels", "numPackedOutputChannelsForGroup", "totalNumPackedOutputChannels", "numOutputColumns", "numOutputRows", "numInputRows"]
    conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy]

    NUM_OUTPUTS = 4
    options.append(f"-DNUM_OUTPUTS={NUM_OUTPUTS}")

    # comment out for args
    conv_short_names += conv_arg_names
    conv_shorts += conv_args
    conv_args = []
    options.append("-DNOARGS")

    replacements["//SHORTS"] = ''.join([f"short {name} = {val};" for name,val in zip(conv_short_names, conv_shorts)])
    for k,v in replacements.items():
      conv_src = conv_src.replace(k, v)
    #print(conv_src)
    conv_prg = CLProgram("image_conv", conv_src,
      options=tuple(options),
      argdtypes=tuple([None, None, None] + [np.int16]*len(conv_args) + [None]*len(ewbufs)),
      op_estimate=op_estimate
    )
    global_work_size = [C.cout//4, (C.ox+NUM_OUTPUTS-1)//NUM_OUTPUTS, C.bs*C.oy]
    conv_prg(global_work_size, None, ret.image, x.image, w.image, *conv_args, *[buf.image if 'image2d_t' in typ else buf.cl for typ, (_, buf) in zip(ewtypes, ewbufs)])
    return ret

GPUBuffer = OpenCLBuffer
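One detail of _processing_op worth calling out: the matmul path picks a local work size lw by shrinking from the device's work-group limit until it evenly divides the third global dimension. A hedged sketch of that search with hypothetical names:

def pick_local_size(global_dim, max_items):
  lw = max_items
  while global_dim % lw != 0:  # shrink until lw evenly divides the global size
    lw -= 1
  return lw

assert pick_local_size(24, 16) == 12
assert pick_local_size(7, 4) == 1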
accel/opencl/preprocessing.py (deleted file)
@@ -1,88 +0,0 @@
from tinygrad.ops import MovementOps, ProcessingOps

# input format is N, H x W, C//4 x 4
# dweight format is oc//4 x ch, cw x 4(oc)
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
def preprocessing_op(x,w,C,make_image=True):
  w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
  #print(x.shape, w.shape)

  if C.bs > 1 and C.py > 0:
    # explicitly add y-padding for batched inputs
    # N C H W
    xs = [(0, 0) for _ in x.shape]
    xs[2] = (C.py, C.py)
    x = x.movement_op(MovementOps.PAD, xs)
    C = C._replace(iy=C.iy + C.py*2, py=0)

  # hack for non multiples of 4 on C.cin
  if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
    to_add = 4 - (C.cin % 4)
    ws = [(0, 0) for _ in w.shape]
    ws[2] = (0, to_add)
    w = w.movement_op(MovementOps.PAD, ws)

    x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
    xs = [(0, 0) for _ in x.shape]
    xs[2] = (0, to_add)
    x = x.movement_op(MovementOps.PAD, xs)
    C = C._replace(cin = C.cin + to_add)
    x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))

  # hack for non multiples of 4 on C.rcout
  if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
    added_output_channels = 4 - (C.rcout % 4)
    ws = [(0, 0) for _ in w.shape]
    ws[1] = (0, added_output_channels)
    w = w.movement_op(MovementOps.PAD, ws)
    C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))

  # packed
  assert (C.groups*C.cin) % 4 == 0
  #print(x.shape)
  x = x.movement_op(MovementOps.PERMUTE, (0,2,3,1))
  x = x.movement_op(MovementOps.RESHAPE, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))

  assert C.cout % 4 == 0
  if C.cin == 1:
    # depthwise
    w = w.movement_op(MovementOps.RESHAPE, (C.cout//4,4,C.H*C.W))
    w = w.movement_op(MovementOps.PERMUTE, (0,2,1))
  else:
    w = w.movement_op(MovementOps.RESHAPE, (C.cout//4,4,C.cin//4,4,C.H,C.W))
    w = w.movement_op(MovementOps.PERMUTE, (0,4,2,5,1,3))
    w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H * C.cin//4 * C.W * 4, 4))

  C = C._replace(out_shape = (C.bs*C.oy, C.ox*C.cout//4, 4))
  #x = contiguous(ctx, x, x.shapetracker) if not x.shapetracker.contiguous else x
  #w = contiguous(ctx, w, w.shapetracker) if not w.shapetracker.contiguous else w

  # contiguous before image, always
  x = x.contiguous()
  w = w.contiguous()

  # early realize on the weights
  bw = w
  while getattr(bw, 'op', None) and len(bw.op.src) == 1:
    bw = bw.op.src[0]
  if bw.realized:
    # weights are static
    wr = w.realize() #.image
    if make_image:
      wr.image
  return x,w,C

def postprocessing_op(ret, C, C_initial):
  added_output_channels = C.rcout - C_initial.rcout

  # undo hack for non multiples of 4 on C.rcout
  if added_output_channels != 0:
    ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
    xs = [(0, s) for s in ret.shape]
    xs[4] = (0, ret.shape[4]-added_output_channels)
    ret = ret.movement_op(MovementOps.SHRINK, xs)
    C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))

  ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
  ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
  return ret
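The packed input layout noted at the top of preprocessing.py (N, H x W, C//4 x 4) is an NCHW -> NHWC permute followed by folding groups of 4 channels into the last axis. A hedged numpy sketch with example sizes (not part of the deleted file):

import numpy as np

N, C, H, W = 1, 8, 4, 6  # example sizes; C is a multiple of 4 here
x = np.arange(N*C*H*W, dtype=np.float32).reshape(N, C, H, W)
# mirrors MovementOps.PERMUTE (0,2,3,1) then RESHAPE (bs*iy, ix*groups*cin//4, 4) above
packed = x.transpose(0, 2, 3, 1).reshape(N*H, W*C//4, 4)
print(packed.shape)  # (4, 12, 4)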
@@ -633,7 +633,6 @@ if __name__ == "__main__":
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  args = parser.parse_args()

  # WTF!! no_grad breaks it (only with OPENCL, now fixed)
  Tensor.no_grad = True
  model = StableDiffusion()
@@ -1,2 +0,0 @@
#!/bin/bash -e
OPENCL=1 DEBUGCL=1 python3 openpilot/compile.py ../selfdrive/modeld/models/supercombo.onnx ../selfdrive/modeld/models/supercombo.thneed
@@ -1,39 +0,0 @@
#!/usr/bin/env python
import unittest
import torch
import numpy as np
from tinygrad.tensor import Device

def helper_test_tiler(x_shape, pts):
  from tinygrad.llops.ops_gpu import CLProgram
  from tinygrad.llops.ops_opencl import OpenCLBuffer
  torch.manual_seed(0)
  x = torch.randn(*x_shape).numpy()
  targets = np.array([(x[i, j] if (0 <= i < x.shape[0] and 0 <= j < x.shape[1]) else [0]*4) for (i, j) in pts])
  x_buffer, pts_buffer, out_buffer = OpenCLBuffer.fromCPU(x), OpenCLBuffer.fromCPU(np.flip(pts, -1)), OpenCLBuffer(targets.shape)
  x_buffer.image
  CLProgram("test_tiler", f"""
  __kernel void test_tiler(__global float4 *out, read_only image2d_t in, __global const float2 *pts) {{
    const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
    int2 l_smp = convert_int2(pts[get_global_id(0)]);
    out[get_global_id(0)] = read_imagef(in, smp, {x_buffer._image.pos_to_sample_pos('l_smp')});
  }}
  """)((len(targets), 1), None, out_buffer.cl, x_buffer.image, pts_buffer.cl)
  out = out_buffer.toCPU()
  np.testing.assert_allclose(out, targets)

def get_pts(*boundary_coords):
  from tinygrad.llops.ops_gpu import CL
  c = sum(([i - 1, i, i + 1] for i in boundary_coords), start=[CL().cl_ctx.devices[0].image2d_max_width])
  return [[i, j] for i in c for j in c]

@unittest.skipUnless(hasattr(Device, "OPENCL"), "Test requires OpenCL")
class TestCLTiler(unittest.TestCase):
  """Test for CLImage tiling logic, which allows large tensors to fit within limited-size OpenCL images."""

  def test_small(self):
    helper_test_tiler((5, 6, 4), get_pts(0, 5, 6))
  def test_wide(self):
    helper_test_tiler((3, 40_000, 4), get_pts(0, 3, 40_000))
  def test_tall(self):
    helper_test_tiler((40_000, 3, 4), get_pts(0, 3, 40_000))
@@ -41,7 +41,7 @@ def helper_test_speed(f1, *args):
  GlobalCounters.global_mem = 0
  st = time.monotonic()
  ret = f1(*args)
  if CL is not None and ret.device in ["GPU", "OPENCL"]:
  if CL is not None and ret.device in ["GPU"]:
    CL.cl_queue.finish()
  if torch_device != "cpu":
    # TODO: better way to sync?
@@ -88,7 +88,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
  psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
  intermediate_shape = self.shape
  if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
  if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
    if psrcs[0][1].optype == ProcessingOps:
      real_srcs[psrcs[0][0]] = psrcs[0][1].op
      for x in psrcs[0][1].op.src:
@@ -250,9 +250,67 @@ class LazyBuffer:
    x = self

    if IMAGE >= 1:
      from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
      Cold = C
      x,w,C = preprocessing_op(x, w, Cold, False)
      w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))

      if C.bs > 1 and C.py > 0:
        # explicitly add y-padding for batched inputs
        # N C H W
        xs = [(0, 0) for _ in x.shape]
        xs[2] = (C.py, C.py)
        x = x.movement_op(MovementOps.PAD, xs)
        C = C._replace(iy=C.iy + C.py*2, py=0)

      # hack for non multiples of 4 on C.cin
      if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
        to_add = 4 - (C.cin % 4)
        ws = [(0, 0) for _ in w.shape]
        ws[2] = (0, to_add)
        w = w.movement_op(MovementOps.PAD, ws)

        x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
        xs = [(0, 0) for _ in x.shape]
        xs[2] = (0, to_add)
        x = x.movement_op(MovementOps.PAD, xs)
        C = C._replace(cin = C.cin + to_add)
        x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))

      # hack for non multiples of 4 on C.rcout
      if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
        added_output_channels = 4 - (C.rcout % 4)
        ws = [(0, 0) for _ in w.shape]
        ws[1] = (0, added_output_channels)
        w = w.movement_op(MovementOps.PAD, ws)
        C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
      else:
        added_output_channels = 0

      # packed
      assert (C.groups*C.cin) % 4 == 0
      x = x.movement_op(MovementOps.PERMUTE, (0,2,3,1))
      x = x.movement_op(MovementOps.RESHAPE, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))

      assert C.cout % 4 == 0
      if C.cin == 1:
        # depthwise
        w = w.movement_op(MovementOps.RESHAPE, (C.cout//4,4,C.H*C.W))
        w = w.movement_op(MovementOps.PERMUTE, (0,2,1))
      else:
        w = w.movement_op(MovementOps.RESHAPE, (C.cout//4,4,C.cin//4,4,C.H,C.W))
        w = w.movement_op(MovementOps.PERMUTE, (0,4,2,5,1,3))
        w = w.movement_op(MovementOps.RESHAPE, (C.cout//4, C.H * C.cin//4 * C.W * 4, 4))

      C = C._replace(out_shape = (C.bs*C.oy, C.ox*C.cout//4, 4))

      # contiguous before image, always
      x = x.contiguous()
      w = w.contiguous()

      # early realize on the weights
      bw = w
      while getattr(bw, 'op', None) and len(bw.op.src) == 1:
        bw = bw.op.src[0]
      if bw.realized:
        w.realize()

      # set up the conv
      # (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4)
@@ -284,8 +342,19 @@ class LazyBuffer:
      # now do the conv in this space
      ret = x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.oy, C.ox, C.cout//4, 4, 1, 1, 1, 1))
      ret = ret.movement_op(MovementOps.RESHAPE, (C.bs*C.oy, C.ox*C.cout//4, 4)).contiguous() #True)
      return postprocessing_op(ret, C, Cold)
      ret = ret.movement_op(MovementOps.RESHAPE, (C.bs*C.oy, C.ox*C.cout//4, 4)).contiguous()

      # undo hack for non multiples of 4 on C.rcout
      if added_output_channels != 0:
        ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
        xs = [(0, s) for s in ret.shape]
        xs[4] = (0, ret.shape[4]-added_output_channels)
        ret = ret.movement_op(MovementOps.SHRINK, xs)
        C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))

      ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
      ret = ret.movement_op(MovementOps.PERMUTE, (0,3,1,2))
      return ret

    # TODO: fixup C?
    if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False):
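The binary_op(BinaryOps.MUL, ...).reduce_op(ReduceOps.SUM, ...) pattern in the hunk above expresses the convolution as a broadcasted multiply followed by a sum over the reduced axes. A hedged numpy sketch of the same idea for a 1x1 convolution (shapes chosen only for illustration):

import numpy as np

bs, cin, cout, oy, ox = 2, 8, 16, 4, 4
x = np.random.randn(bs, oy, ox, 1, cin).astype(np.float32)  # input, broadcast against cout
w = np.random.randn(1, 1, 1, cout, cin).astype(np.float32)  # 1x1 kernel
out = (x * w).sum(axis=-1)                                  # multiply, then reduce -> (bs, oy, ox, cout)
ref = x[..., 0, :] @ w[0, 0, 0].T                           # reference via plain matmul
assert np.allclose(out, ref, atol=1e-4)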
@@ -311,12 +380,6 @@ class LazyBuffer:
        .movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
      return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
        .movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
    elif x.device == "OPENCL":
      # TODO: these can be properties on the device buffer
      from accel.opencl.preprocessing import preprocessing_op, postprocessing_op # type: ignore
      x,w,Cn = preprocessing_op(x, w, C)
      ret = LazyBuffer(x.device, Cn.out_shape, ProcessingOps, LazyOp(op, (x, w), Cn))
      return postprocessing_op(ret, Cn, C)
    else:
      return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
tinygrad/llops/ops_opencl.py (deleted symlink)
@@ -1 +0,0 @@
../../accel/opencl/ops_opencl.py
@@ -69,7 +69,7 @@ class GenericShape(GenericExecAST):
def get_lazyop_info(ast:LazyOp): return GenericShape.exec_ast(ast, lambda x: GenericShape(x.shape))

# assumes you are using ShapeTracker
# used in GPUBuffer, OpenCLBuffer, and LLVMBuffer
# used in GPUBuffer and LLVMBuffer
class ExplicitExecAST(DeviceBuffer):
  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf=None):
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))