Simple version of the new GPU backend (#458)

* newgpu

* more to delete

* hmm, tests pass with constant folding

* fix lint/type

* fix constant folding

* comment and rerun tests

* lazy touchups

* fix graph_batchnorm test

* smaller transformer to fix OOM

* Revert "smaller transformer to fix OOM"

This reverts commit a44ef8edc275a4b3c78ee711ba188e220b7a879f.

* no func cache

* introspect

* touchups

* CLASTKernel

* ugh, it was lru_cache

* codegen

* spacing

* old gpu still in opencl

* typing fix
George Hotz, 2023-01-10 19:16:02 -08:00 (committed by GitHub)
parent 66123c99b9
commit fff1f046b0
7 changed files with 220 additions and 105 deletions

View File

@@ -3,9 +3,9 @@
from __future__ import annotations
import os
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
-from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps, MovementOps
+from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps, MovementOps, get_buffers, get_lazyops, get_lazyop_info, LazyOp, Op
-from tinygrad.helpers import prod, ConvArgs
+from tinygrad.helpers import prod, ConvArgs, dedup
-from typing import List, Tuple, Optional, Dict, Set
+from typing import List, Tuple, Optional, Dict, Set, Union
import numpy as np
import pyopencl as cl
@@ -228,13 +228,104 @@ class OpenCLBuffer(GPUBuffer):
float4 dat = read_imagef(x, smp, l_smp);
return valid ? (idx4 == 0 ? dat.x : (idx4 == 1 ? dat.y : (idx4 == 2 ? dat.z : dat.w))) : 0.0;
}}""", f"read_only image2d_t {name}_g", f"get_{name}(smp, {name}_g, gid);"
-#ewtypes.append(f"read_only image2d_t {name}_g")
-return super().contiguous_view_constant_fold(name, reduce)
+else:
+idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
f"__global const float *{name}_g" if constant is None else None, \
f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"
@classmethod
def exec_ast(cls, ast:LazyOp):
# copied from llvm
bufs = dedup(get_buffers(ast))
reduceops = dedup([x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps) or isinstance(x.op, ProcessingOps)])
assert len(reduceops) <= 1, f"max one reduce op in an ast, {reduceops}"
earlybufs = dedup(get_buffers(reduceops[0])) if len(reduceops) > 0 else []
reduce_shape = (earlybufs[0].shape, reduceops[0].arg) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else None
info = get_lazyop_info(ast)
ret = cls(info.shape)
buf_names : Dict[GPUBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(bufs)}
# special names for input and weight
if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps):
buf_names[reduceops[0].src[0]] = "input"
buf_names[reduceops[0].src[1]] = "weight"
def _ast(x: Union[GPUBuffer, LazyOp], buf_names: Dict[GPUBuffer, str], code_for_op: Dict[Op, str], allow_reduce=False) -> str:
if isinstance(x, GPUBuffer):
return buf_names[x]
if not allow_reduce and type(x.op) in [ProcessingOps, ReduceOps]:
return "acc"
srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
code = code_for_op[x.op]
if len(srcs_code) >= 1:
code = code.replace("A", srcs_code[0])
if len(srcs_code) >= 2:
code = code.replace("B", srcs_code[1])
return code
earlycode = _ast(reduceops[0], buf_names, cls.code_for_op, allow_reduce=True) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else "acc"
code = _ast(ast, buf_names, cls.code_for_op)
C = reduceops[0].arg if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps) else None
reduce_op = reduceops[0].op if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else ReduceOps.SUM
return ret._processing_op([(buf_names[x], x) for x in bufs], code, C, reduce_op, reduce_shape, set(buf_names[x] for x in earlybufs), earlycode, info.flops)
def _simple_processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0) -> GPUBuffer:
assert C is None, f"conv isn't handled by GPU anymore {C}"
# get the input/output shape and the reduce amount
reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
red = prod([s for s,n in zip(*reduce_shape) if n == 1])
assert red < 2**31, f"reduce must be under 2**31, {red} isn't"
# if it's a partial reduce, assert last non reduced axis is before the first reduced axis
if red > 1 and prod(ret.shape) != 1:
assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
kernel_name = "reduce" if red > 1 else "elementwise"
early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
views = {**early_views, **late_views}
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] # type: ignore
# use local memory if it's a multistage reduce
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
if inter_red > 1:
buf_cl.append(cl.LocalMemory(inter_red*4))
reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float acc = {GPUBuffer.start_for_op[op]};
int gid = get_global_id(0);
{reduce_loop} {{
{chr(10).join([f' float {name} = ' + early_views[name][2] for name in early_views])}
acc = {earlycode};
}}"""+(f"""
temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
for (int rdx = 0; rdx < {inter_red}; rdx++) {{
acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
}}""" if inter_red != 1 else "{")+f"""
{chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
output[gid] = {code};
}}
}}""")
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
return ret
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0):
if C is None or earlycode != "acc":
# TODO: handle an opencl conv without the conv part
-return super()._processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode, op_estimate)
+return ret._simple_processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode, op_estimate)
assert earlycode == "acc"
x = [x for x in bufs if x[0] == "input"][0][1]
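
The `_ast` helper in the copied `exec_ast` above lowers the op tree by plain string substitution on the `code_for_op` templates, replacing the placeholders `A` and `B` with the rendered source expressions. A minimal, self-contained sketch of that mechanism; the op table and source strings here are illustrative stand-ins, not the real tinygrad dictionary:

code_for_op = {"ADD": "(A+B)", "MUL": "(A*B)", "RELU": "max(A, (float)0.)"}  # illustrative subset

def render(op: str, srcs: list) -> str:
  # mirrors the replace("A", ...) / replace("B", ...) pattern in _ast and ast_parse
  code = code_for_op[op]
  if len(srcs) >= 1: code = code.replace("A", srcs[0])
  if len(srcs) >= 2: code = code.replace("B", srcs[1])
  return code

print(render("MUL", ["get_input(input_g, gid)", "get_weight(weight_g, gid)"]))
# -> (get_input(input_g, gid)*get_weight(weight_g, gid))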

View File

@@ -7,7 +7,7 @@ from tinygrad.llops.ops_gpu import CL, GPUBuffer
#from tinygrad.llops.ops_opencl import CLImage, OpenCLBuffer
def print_objects():
-gc.collect()
+#gc.collect()
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
@@ -24,7 +24,7 @@ def print_objects():
bb = gc.get_referrers(tb)
for b in bb:
if b is not gpubuffers and b is not gpubuffers_orphaned:
-print(tb, "reference", type(b),len(b), str(b)[0:150])
+print(tb, "\nreference", type(b), len(b), str(b)[0:150], "\n\n")
for x in gc.get_referrers(b):
print(str(x)[0:100])
if cnt == 10:
@@ -35,6 +35,8 @@ def print_objects():
if getattr(x, '_buf', None): del x._buf
if getattr(x, '_image', None): del x._image
return len(gpubuffers_orphaned)
"""
import gc

View File

@@ -6,8 +6,8 @@ import unittest
def model_step(lm):
Tensor.training = True
x = Tensor.ones(8,12,128,256, requires_grad=False)
-loss = lm.forward(x).sum()
optimizer = optim.SGD(get_parameters(lm), lr=0.001)
+loss = lm.forward(x).sum()
optimizer.zero_grad()
loss.backward()
del x,loss

View File

@@ -46,6 +46,10 @@ class TestTrain(unittest.TestCase):
Y = np.zeros((BS,6), dtype=np.int32)
train_one_step(model,X,Y)
if Device.DEFAULT == "GPU":
from extra.introspection import print_objects
assert print_objects() == 0
def test_resnet(self):
X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
Y = np.zeros((BS), dtype=np.int32)

View File

@@ -4,11 +4,12 @@ from copy import copy
import os, sys, weakref
from tinygrad.helpers import ConvArgs, get_available_llops, prod
from tinygrad.shapetracker import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
from tinygrad.graph import log_op
# lazy can recurse a lot
sys.setrecursionlimit(10000)
sys.tracebacklimit = 20
OPT = int(os.getenv("OPT", "1"))
NOCONV = int(os.getenv("NOCONV", "0"))
@@ -70,7 +71,14 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
op_type : OpType = BinaryOps
-psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
+if DEBUG >= 3:
for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
if x.optype in [ProcessingOps,ReduceOps] and x.realized is None:
print("\nHIT", k,x)
for tk in k.children: print("k", tk)
for tx in x.children: print("x", tx)
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergable but requires more thought how
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape = self.shape
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
if psrcs[0][1].optype == ProcessingOps:
@@ -236,8 +244,9 @@ class LazyBuffer:
if NOCONV or not getattr(x.dbuffer, "processing_op", False):
# universal conv, just mul and reduce
# TODO: is there any way to replace strided with other movement ops? answer: not really
-if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1:
+if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1 and False:
# TODO: this doesn't belong here, ShapeTracker or lazy should be able to infer this from STRIDED
# TODO: this is disabled. it breaks fusion of ops without pushing PERMUTES. this is also a depthwise conv
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.oy, C.ox, 1, C.H, C.W))
x = x.movement_op(MovementOps.PERMUTE, (0,1,5,3,4,2,6,7))
else:
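
The reworked `psrcs` filter above adds `prod(k.shape) == prod(x.shape)` as a merge guard: a contiguous movement root that went through SHRINK can have a different element count than the buffer consuming it, and such a pair can't be fused into one elementwise kernel yet (see the NOTE in the diff). A hedged illustration with hypothetical shapes:

from math import prod

# k is the consumer-side view, x its contiguous movement root (made-up shapes).
k_shape, x_shape = (4, 16), (4, 32)        # x was SHRINK'd down to k, so element counts differ
print(prod(k_shape) == prod(x_shape))      # False -> this pair is skipped by the new guard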

View File

@@ -4,8 +4,8 @@ import numpy as np
import pyopencl as cl # type: ignore
from collections import defaultdict
from typing import List, Tuple, Optional, Dict, Union, Set
-from tinygrad.helpers import prod, ConvArgs, dedup
+from tinygrad.helpers import prod
-from tinygrad.ops import DEBUG, ProcessingOps, UnaryOps, BinaryOps, ReduceOps, LazyOp, get_buffers, get_lazyops, Op, get_lazyop_info, ExplicitExecAST, GlobalCounters
+from tinygrad.ops import DEBUG, ASTKernel, UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.shapetracker import ShapeTracker
CLCACHE = int(os.getenv("CLCACHE", "1"))
@@ -79,6 +79,102 @@ class CLProgram:
# **** end CL wrappers ****
class CLASTKernel(ASTKernel):
def __init__(self, ast:LazyOp):
super().__init__(ast)
self.ast = ast
def compute_buf_index(self, st, buf_index, offset=0):
key = f"{buf_index}_{offset}"
# add the index if we don't have it
if key not in self.seen_idx:
idx_pieces = [str(st.offset + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;")
self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n')
if len(st.views) > 1:
extra_idx = ';'.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n")
self.seen_idx.add(key)
return key
def store(self, buf_index, value, offset=0):
st = self.bufs[buf_index].st
if offset > 0: assert len(st.views) == 1
key = self.compute_buf_index(st, buf_index, offset)
self.kernel.append(f"data{buf_index}[bufi{key}] = {value};\n")
def load(self, buf_index, offset=0):
if buf_index not in self.loaded_keys:
st = self.bufs[buf_index].st
if offset > 0: assert len(st.views) == 1
key = self.compute_buf_index(st, buf_index, offset)
# constant folding
constant_fold = None
if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing:
self.bufs_to_delete.add(buf_index)
constant_fold = f"({self.bufs[buf_index]._backing[0]})"
ldr = f"data{buf_index}[bufi{key}]" if not constant_fold else constant_fold
ldr = f"(bufvalid{key} ? {ldr} : 0.0)" if st.needs_valid() else ldr
self.kernel.append(f"float val{key} = {ldr};\n")
self.loaded_keys[buf_index] = f"val{key}"
return self.loaded_keys[buf_index]
def ast_parse(self, x, reduce=False) -> str:
if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
if isinstance(x.op, ReduceOps) and not reduce: return "acc"
values = [self.ast_parse(v, reduce) for v in x.src]
code = GPUBuffer.code_for_op[x.op] # TODO: replace this with a function
if isinstance(x.op, ReduceOps): return code.replace("A", values[0])
if len(values) >= 1: code = code.replace("A", values[0])
if len(values) >= 2: code = code.replace("B", values[1])
return code
def codegen(self):
# TODO: fetch from quick cache before processing
self.process()
self.bufs_to_delete : Set[int] = set()
self.seen_idx : Set[str] = set()
self.loaded_keys : Dict[int, str] = {}
self.output_shape = self.shapes[0][:self.first_reduce]
self.kernel : List[str] = [f"int idx{i} = get_global_id({min(3, len(self.output_shape))-1-i});\n" for i in range(min(3, len(self.output_shape)))]
if len(self.output_shape) > 3:
# compact all the dimensions into the final one
for i in range(len(self.output_shape)-1, 2, -1):
self.kernel += [f"int idx{i} = idx2 % {self.output_shape[i]};", f"idx2 = idx2 / {self.output_shape[i]};\n"]
self.output_shape = list(self.output_shape[0:2]) + [prod(self.output_shape[2:])]
# early ast
if self.reduceop:
full_shape = [x for x in self.shapes if x != self.shapes[0]]
full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
self.kernel.append(f"float acc = {GPUBuffer.start_for_op[self.reduceop.op]};\n")
for i in range(self.first_reduce, self.last_reduce):
self.kernel.append(f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n")
self.kernel.append(" acc = " + self.ast_parse(self.reduceop, reduce=True) + ";\n")
self.kernel += ["}\n"] * (self.last_reduce - self.first_reduce)
# late ast
process_ast = self.ast_parse(self.ast)
self.store(0, process_ast)
self.kernel.append("}")
# kernel function definition
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
self.kernel = [f"__kernel void {function_name}(",] + [', '.join(f'__global float *data{i}' for i in range(len(self.bufs)) if i not in self.bufs_to_delete)] + [") {\n"] + self.kernel
# compile kernel
fxn = CLProgram(function_name, ' '.join(self.kernel))
def runner(*bufs):
clbufs = [x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete]
return fxn(self.output_shape[::-1] if len(self.output_shape) > 0 else [1], None, *clbufs, op_estimate=self.info.flops)
return runner
class GPUBuffer(ExplicitExecAST):
code_for_op : Dict[Op, str] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
@@ -117,96 +213,8 @@ class GPUBuffer(ExplicitExecAST):
CL.enqueue_copy(data, self.contiguous().cl, is_blocking=True)
return data
def contiguous_view_constant_fold(x, name:str, reduce:Optional[int]=None) -> Tuple[str, Optional[str], str]:
idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
f"__global const float *{name}_g" if constant is None else None, \
f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"
@classmethod
def exec_ast(cls, ast:LazyOp):
-# copied from llvm
-bufs = dedup(get_buffers(ast))
-reduceops = dedup([x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps) or isinstance(x.op, ProcessingOps)])
+k = CLASTKernel(ast)
+k.codegen()(*k.bufs)
+return k.ret
assert len(reduceops) <= 1, f"max one reduce op in an ast, {reduceops}"
earlybufs = dedup(get_buffers(reduceops[0])) if len(reduceops) > 0 else []
reduce_shape = (earlybufs[0].shape, reduceops[0].arg) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else None
info = get_lazyop_info(ast)
ret = cls(info.shape)
buf_names : Dict[GPUBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(bufs)}
# special names for input and weight
if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps):
buf_names[reduceops[0].src[0]] = "input"
buf_names[reduceops[0].src[1]] = "weight"
def _ast(x: Union[GPUBuffer, LazyOp], buf_names: Dict[GPUBuffer, str], code_for_op: Dict[Op, str], allow_reduce=False) -> str:
if isinstance(x, GPUBuffer):
return buf_names[x]
if not allow_reduce and type(x.op) in [ProcessingOps, ReduceOps]:
return "acc"
srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
code = code_for_op[x.op]
if len(srcs_code) >= 1:
code = code.replace("A", srcs_code[0])
if len(srcs_code) >= 2:
code = code.replace("B", srcs_code[1])
return code
earlycode = _ast(reduceops[0], buf_names, cls.code_for_op, allow_reduce=True) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else "acc"
code = _ast(ast, buf_names, cls.code_for_op)
C = reduceops[0].arg if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps) else None
reduce_op = reduceops[0].op if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else ReduceOps.SUM
return ret._processing_op([(buf_names[x], x) for x in bufs], code, C, reduce_op, reduce_shape, set(buf_names[x] for x in earlybufs), earlycode, info.flops)
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0) -> GPUBuffer:
assert C is None, f"conv isn't handled by GPU anymore {C}"
# get the input/output shape and the reduce amount
reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
red = prod([s for s,n in zip(*reduce_shape) if n == 1])
assert red < 2**31, f"reduce must be under 2**31, {red} isn't"
# if it's a partial reduce, assert last non reduced axis is before the first reduced axis
if red > 1 and prod(ret.shape) != 1:
assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
kernel_name = "reduce" if red > 1 else "elementwise"
early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
views = {**early_views, **late_views}
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] # type: ignore
# use local memory if it's a multistage reduce
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
if inter_red > 1:
buf_cl.append(cl.LocalMemory(inter_red*4))
reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float acc = {GPUBuffer.start_for_op[op]};
int gid = get_global_id(0);
{reduce_loop} {{
{chr(10).join([f' float {name} = ' + early_views[name][2] for name in early_views])}
acc = {earlycode};
}}"""+(f"""
temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
for (int rdx = 0; rdx < {inter_red}; rdx++) {{
acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
}}""" if inter_red != 1 else "{")+f"""
{chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
output[gid] = {code};
}}
}}""")
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
return ret
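
One detail of `CLASTKernel.codegen` worth calling out: an OpenCL launch gets at most three global dimensions, so when the output has more axes the trailing ones are folded into `idx2` at launch time and unpacked inside the kernel with modulo and divide. A standalone sketch of that compaction step, using a hypothetical output shape and mirroring the loop in `codegen` above:

from math import prod

output_shape = (8, 16, 4, 3)   # hypothetical output with more than 3 axes
kernel = [f"int idx{i} = get_global_id({min(3, len(output_shape))-1-i});" for i in range(min(3, len(output_shape)))]
for i in range(len(output_shape)-1, 2, -1):
  # recover the folded axes from idx2
  kernel += [f"int idx{i} = idx2 % {output_shape[i]};", f"idx2 = idx2 / {output_shape[i]};"]
launch_shape = list(output_shape[0:2]) + [prod(output_shape[2:])]   # shape handed to the CL launch
print("\n".join(kernel))
print(launch_shape)   # [8, 16, 12]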

View File

@@ -153,6 +153,7 @@ class ASTKernel:
rets[j].append((shapes[j][i], strides[j][i]))
self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
self.first_reduce = get_first_reduce(self.shapes) # update this if axis merged
self.last_reduce = len(self.shapes[0])
# include the offsets (as is)
self.offsets = [x.st.views[-1].offset for x in self.bufs]
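
For orientation: after the shapes are merged, `first_reduce` is the first axis on which the buffers' shapes disagree (the start of the reduce axes), and the `last_reduce` added here is simply the total number of axes, so the reduce loop in `codegen` runs over `[first_reduce, last_reduce)`. A hedged, standalone sketch of that bookkeeping; `get_first_reduce` below reimplements the idea with made-up shapes and is not the tinygrad helper itself:

# output buffer shape first, then an input that still carries the reduce axis (hypothetical)
shapes = [(32, 1), (32, 64)]

def get_first_reduce(shapes):
  # first axis where the buffers disagree; if they all match there is no reduce
  for i in range(len(shapes[0])):
    if len(set(s[i] for s in shapes)) != 1: return i
  return len(shapes[0])

first_reduce, last_reduce = get_first_reduce(shapes), len(shapes[0])
print(first_reduce, last_reduce)   # 1 2 -> the reduce loop covers axis 1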