mirror of https://github.com/commaai/tinygrad.git
Simple version of the new GPU backend (#458)
* newgpu
* more to delete
* hmm, tests pass with constant folding
* fix lint/type
* fix constant folding
* comment and rerun tests
* lazy touchups
* fix graph_batchnorm test
* smaller transformer to fix OOM
* Revert "smaller transformer to fix OOM"
  This reverts commit a44ef8edc275a4b3c78ee711ba188e220b7a879f.
* no func cache
* introspect
* touchups
* CLASTKernel
* ugh, it was lru_cache
* codegen
* spacing
* old gpu still in opencl
* typing fix
parent 66123c99b9
commit fff1f046b0
@@ -3,9 +3,9 @@
 from __future__ import annotations
 import os
 from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
-from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps, MovementOps
-from tinygrad.helpers import prod, ConvArgs
-from typing import List, Tuple, Optional, Dict, Set
+from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps, MovementOps, get_buffers, get_lazyops, get_lazyop_info, LazyOp, Op
+from tinygrad.helpers import prod, ConvArgs, dedup
+from typing import List, Tuple, Optional, Dict, Set, Union
 import numpy as np
 import pyopencl as cl
 
@@ -228,13 +228,104 @@ class OpenCLBuffer(GPUBuffer):
         float4 dat = read_imagef(x, smp, l_smp);
         return valid ? (idx4 == 0 ? dat.x : (idx4 == 1 ? dat.y : (idx4 == 2 ? dat.z : dat.w))) : 0.0;
       }}""", f"read_only image2d_t {name}_g", f"get_{name}(smp, {name}_g, gid);"
       #ewtypes.append(f"read_only image2d_t {name}_g")
-    else:
-      idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
-      constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
-      args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
-      return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
-        f"__global const float *{name}_g" if constant is None else None, \
-        f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"
+    return super().contiguous_view_constant_fold(name, reduce)
+
+  @classmethod
+  def exec_ast(cls, ast:LazyOp):
+    # copied from llvm
+    bufs = dedup(get_buffers(ast))
+    reduceops = dedup([x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps) or isinstance(x.op, ProcessingOps)])
+    assert len(reduceops) <= 1, f"max one reduce op in an ast, {reduceops}"
+    earlybufs = dedup(get_buffers(reduceops[0])) if len(reduceops) > 0 else []
+    reduce_shape = (earlybufs[0].shape, reduceops[0].arg) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else None
+    info = get_lazyop_info(ast)
+    ret = cls(info.shape)
+
+    buf_names : Dict[GPUBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(bufs)}
+
+    # special names for input and weight
+    if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps):
+      buf_names[reduceops[0].src[0]] = "input"
+      buf_names[reduceops[0].src[1]] = "weight"
+
+    def _ast(x: Union[GPUBuffer, LazyOp], buf_names: Dict[GPUBuffer, str], code_for_op: Dict[Op, str], allow_reduce=False) -> str:
+      if isinstance(x, GPUBuffer):
+        return buf_names[x]
+      if not allow_reduce and type(x.op) in [ProcessingOps, ReduceOps]:
+        return "acc"
+      srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
+      code = code_for_op[x.op]
+      if len(srcs_code) >= 1:
+        code = code.replace("A", srcs_code[0])
+      if len(srcs_code) >= 2:
+        code = code.replace("B", srcs_code[1])
+      return code
+
+    earlycode = _ast(reduceops[0], buf_names, cls.code_for_op, allow_reduce=True) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else "acc"
+    code = _ast(ast, buf_names, cls.code_for_op)
+
+    C = reduceops[0].arg if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps) else None
+    reduce_op = reduceops[0].op if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else ReduceOps.SUM
+    return ret._processing_op([(buf_names[x], x) for x in bufs], code, C, reduce_op, reduce_shape, set(buf_names[x] for x in earlybufs), earlycode, info.flops)
+
+  def _simple_processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0) -> GPUBuffer:
+    assert C is None, f"conv isn't handled by GPU anymore {C}"
+
+    # get the input/output shape and the reduce amount
+    reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
+    red = prod([s for s,n in zip(*reduce_shape) if n == 1])
+    assert red < 2**31, f"reduce must be under 2**31, {red} isn't"
+
+    # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
+    if red > 1 and prod(ret.shape) != 1:
+      assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
+
+    kernel_name = "reduce" if red > 1 else "elementwise"
+    early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
+    late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
+    views = {**early_views, **late_views}
+
+    buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None]  # type: ignore
+    buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None]  # type: ignore
+
+    # use local memory if it's a multistage reduce
+    inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
+    if inter_red > 1:
+      buf_cl.append(cl.LocalMemory(inter_red*4))
+
+    reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
+    conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
+    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
+      const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+      float acc = {GPUBuffer.start_for_op[op]};
+      int gid = get_global_id(0);
+      {reduce_loop} {{
+        {chr(10).join([f'float {name} = ' + early_views[name][2] for name in early_views])}
+        acc = {earlycode};
+      }}"""+(f"""
+      temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
+      if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
+        for (int rdx = 0; rdx < {inter_red}; rdx++) {{
+          acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
+        }}""" if inter_red != 1 else "{")+f"""
+        {chr(10).join([f'float {name} = ' + late_views[name][2] for name in late_views])}
+        output[gid] = {code};
+      }}
+    }}""")
+
+    conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
+    return ret
+
   def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0):
     if C is None or earlycode != "acc":
       # TODO: handle an opencl conv without the conv part
-      return super()._processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode, op_estimate)
+      return ret._simple_processing_op(bufs, code, C, op, reduce_shape, earlybufs, earlycode, op_estimate)
     assert earlycode == "acc"
 
     x = [x for x in bufs if x[0] == "input"][0][1]
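A note on the multistage reduce in _simple_processing_op above: when the output is small and the reduce dimension is large, the generated reduce_loop splits the work across inter_red work-items that each fold a contiguous chunk into local memory, and work-item 0 then folds the partials. A minimal Python sketch of that partitioning (illustrative helper, not tinygrad API; models a SUM reduce):

    # Sketch: how a reduce of size `red` is split across `inter_red` work-items,
    # mirroring the generated reduce_loop (illustrative, not part of this commit).
    def two_stage_reduce(values, inter_red=4):
      red = len(values)
      chunk = red // inter_red + 1  # same per-work-item bound the kernel uses
      # stage 1: each work-item `mid` reduces its slice into temp[mid] (local memory)
      temp = [sum(values[chunk*mid:min(red, chunk*(mid+1))]) for mid in range(inter_red)]
      # stage 2: after barrier(CLK_LOCAL_MEM_FENCE), work-item 0 folds the partials
      return sum(temp)

    assert two_stage_reduce(list(range(1000))) == sum(range(1000))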
@@ -7,7 +7,7 @@ from tinygrad.llops.ops_gpu import CL, GPUBuffer
 #from tinygrad.llops.ops_opencl import CLImage, OpenCLBuffer
 
 def print_objects():
-  gc.collect()
+  #gc.collect()
   tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
   tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
  lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
@@ -24,7 +24,7 @@ def print_objects():
     bb = gc.get_referrers(tb)
     for b in bb:
       if b is not gpubuffers and b is not gpubuffers_orphaned:
-        print(tb, "reference", type(b),len(b), str(b)[0:150])
+        print(tb, "\nreference", type(b), len(b), str(b)[0:150], "\n\n")
+        for x in gc.get_referrers(b):
+          print(str(x)[0:100])
     if cnt == 10:
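For context, print_objects leans on the standard gc module to discover what is keeping buffers alive. A stripped-down sketch of the same referrer-walking idea (standalone and illustrative, not the extra/introspection.py code itself):

    import gc

    # Sketch: count live instances of a class and print what refers to them,
    # the same gc-walking pattern print_objects uses (illustrative only).
    def count_and_trace(cls, limit=10):
      objs = [x for x in gc.get_objects() if isinstance(x, cls)]
      for obj in objs[:limit]:
        for ref in gc.get_referrers(obj):  # who is keeping obj alive?
          print(type(ref), str(ref)[0:150])
      return len(objs)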
@@ -35,6 +35,8 @@ def print_objects():
       if getattr(x, '_buf', None): del x._buf
+      if getattr(x, '_image', None): del x._image
 
   return len(gpubuffers_orphaned)
 
 """
 import gc
@@ -6,8 +6,8 @@ import unittest
 def model_step(lm):
   Tensor.training = True
   x = Tensor.ones(8,12,128,256, requires_grad=False)
-  loss = lm.forward(x).sum()
   optimizer = optim.SGD(get_parameters(lm), lr=0.001)
+  loss = lm.forward(x).sum()
   optimizer.zero_grad()
   loss.backward()
   del x,loss
@@ -46,6 +46,10 @@ class TestTrain(unittest.TestCase):
     Y = np.zeros((BS,6), dtype=np.int32)
     train_one_step(model,X,Y)
 
+    if Device.DEFAULT == "GPU":
+      from extra.introspection import print_objects
+      assert print_objects() == 0
+
   def test_resnet(self):
     X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
     Y = np.zeros((BS), dtype=np.int32)
@@ -4,11 +4,12 @@ from copy import copy
 import os, sys, weakref
 from tinygrad.helpers import ConvArgs, get_available_llops, prod
 from tinygrad.shapetracker import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
 from tinygrad.graph import log_op
 
 # lazy can recurse a lot
 sys.setrecursionlimit(10000)
+sys.tracebacklimit = 20
 
 OPT = int(os.getenv("OPT", "1"))
 NOCONV = int(os.getenv("NOCONV", "0"))
@@ -70,7 +71,14 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
 def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
   real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
   op_type : OpType = BinaryOps
-  psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
+  if DEBUG >= 3:
+    for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
+      if x.optype in [ProcessingOps,ReduceOps] and x.realized is None:
+        print("\nHIT", k,x)
+        for tk in k.children: print("k", tk)
+        for tx in x.children: print("x", tx)
+  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergable but requires more thought how
+  psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
   intermediate_shape = self.shape
   if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
     if psrcs[0][1].optype == ProcessingOps:
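The new prod(k.shape) == prod(x.shape) guard follows from the NOTE above: a movement root can be contiguous yet, after a SHRINK, hold fewer elements than the buffer being merged, so fusing the reduce into the elementwise op would index out of range. A toy illustration (hypothetical shapes, not from this commit):

    from math import prod  # behaves like tinygrad.helpers.prod for this

    # a (4,4) source SHRINKed to (2,2) still has a contiguous movement root,
    # but the element counts differ, so the reduce can't be fused in blindly
    k_shape, x_shape = (2, 2), (4, 4)
    assert prod(k_shape) != prod(x_shape)  # the added guard rejects this pair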
@@ -236,8 +244,9 @@ class LazyBuffer:
     if NOCONV or not getattr(x.dbuffer, "processing_op", False):
       # universal conv, just mul and reduce
       # TODO: is there any way to replace strided with other movement ops? answer: not really
-      if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1:
+      if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1 and False:
         # TODO: this doesn't belong here, ShapeTracker or lazy should be able to infer this from STRIDED
+        # TODO: this is disabled. it breaks fusion of ops without pushing PERMUTES. this is also a depthwise conv
         x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.oy, C.ox, 1, C.H, C.W))
         x = x.movement_op(MovementOps.PERMUTE, (0,1,5,3,4,2,6,7))
       else:
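The branch disabled here rewrites a 1x1, stride-1 conv as pure movement ops followed by multiply and sum ("universal conv, just mul and reduce"). The equivalence it relies on, in numpy terms (shapes illustrative, groups omitted for clarity):

    import numpy as np

    # A 1x1, stride-1 conv is a matmul over the channel axis:
    #   out[b,co,y,x] = sum_ci inp[b,ci,y,x] * w[co,ci]
    bs, cin, cout, oy, ox = 2, 3, 5, 4, 4
    inp = np.random.randn(bs, cin, oy, ox).astype(np.float32)
    w = np.random.randn(cout, cin, 1, 1).astype(np.float32)

    ref = np.einsum('biyx,oi->boyx', inp, w[:, :, 0, 0])
    # broadcast-multiply then reduce over cin: the "just mul and reduce" path
    via_mul_reduce = (inp[:, None] * w[None, :, :, 0, 0][..., None, None]).sum(axis=2)
    assert np.allclose(ref, via_mul_reduce, atol=1e-4)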
@@ -4,8 +4,8 @@ import numpy as np
 import pyopencl as cl  # type: ignore
 from collections import defaultdict
 from typing import List, Tuple, Optional, Dict, Union, Set
-from tinygrad.helpers import prod, ConvArgs, dedup
-from tinygrad.ops import DEBUG, ProcessingOps, UnaryOps, BinaryOps, ReduceOps, LazyOp, get_buffers, get_lazyops, Op, get_lazyop_info, ExplicitExecAST, GlobalCounters
+from tinygrad.helpers import prod
+from tinygrad.ops import DEBUG, ASTKernel, UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
 from tinygrad.shapetracker import ShapeTracker
 
 CLCACHE = int(os.getenv("CLCACHE", "1"))
@@ -79,6 +79,102 @@ class CLProgram:
 
 # **** end CL wrappers ****
 
+class CLASTKernel(ASTKernel):
+  def __init__(self, ast:LazyOp):
+    super().__init__(ast)
+    self.ast = ast
+
+  def compute_buf_index(self, st, buf_index, offset=0):
+    key = f"{buf_index}_{offset}"
+    # add the index if we don't have it
+    if key not in self.seen_idx:
+      idx_pieces = [str(st.offset + offset)] + [(f"idx{i}*{st}" if st != 1 else f"idx{i}") for i,(sh,st) in enumerate(zip(self.shapes[buf_index][0:self.last_reduce], self.strides[buf_index][0:self.last_reduce])) if sh != 1 and st != 0]
+      if st.needs_valid(): self.kernel.append(f"bool bufvalid{key} = true;")
+      self.kernel.append(f"int bufi{key} = " + '('+' + '.join(idx_pieces)+');\n')
+      if len(st.views) > 1:
+        extra_idx = ';'.join([v.expr for v in st.views[0:-1][::-1] if v.expr not in ['', 'idx=idx', 'valid=valid']])
+        self.kernel.append(extra_idx.replace("//", "/").replace("idx", f"bufi{key}").replace("valid", f"bufvalid{key}") + ";\n")
+      self.seen_idx.add(key)
+    return key
+
+  def store(self, buf_index, value, offset=0):
+    st = self.bufs[buf_index].st
+    if offset > 0: assert len(st.views) == 1
+    key = self.compute_buf_index(st, buf_index, offset)
+    self.kernel.append(f"data{buf_index}[bufi{key}] = {value};\n")
+
+  def load(self, buf_index, offset=0):
+    if buf_index not in self.loaded_keys:
+      st = self.bufs[buf_index].st
+      if offset > 0: assert len(st.views) == 1
+      key = self.compute_buf_index(st, buf_index, offset)
+
+      # constant folding
+      constant_fold = None
+      if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing:
+        self.bufs_to_delete.add(buf_index)
+        constant_fold = f"({self.bufs[buf_index]._backing[0]})"
+
+      ldr = f"data{buf_index}[bufi{key}]" if not constant_fold else constant_fold
+      ldr = f"(bufvalid{key} ? {ldr} : 0.0)" if st.needs_valid() else ldr
+      self.kernel.append(f"float val{key} = {ldr};\n")
+      self.loaded_keys[buf_index] = f"val{key}"
+    return self.loaded_keys[buf_index]
+
+  def ast_parse(self, x, reduce=False) -> str:
+    if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
+    if isinstance(x.op, ReduceOps) and not reduce: return "acc"
+    values = [self.ast_parse(v, reduce) for v in x.src]
+    code = GPUBuffer.code_for_op[x.op]  # TODO: replace this with a function
+    if isinstance(x.op, ReduceOps): return code.replace("A", values[0])
+    if len(values) >= 1: code = code.replace("A", values[0])
+    if len(values) >= 2: code = code.replace("B", values[1])
+    return code
+
+  def codegen(self):
+    # TODO: fetch from quick cache before processing
+    self.process()
+
+    self.bufs_to_delete : Set[int] = set()
+    self.seen_idx : Set[str] = set()
+    self.loaded_keys : Dict[int, str] = {}
+
+    self.output_shape = self.shapes[0][:self.first_reduce]
+    self.kernel : List[str] = [f"int idx{i} = get_global_id({min(3, len(self.output_shape))-1-i});\n" for i in range(min(3, len(self.output_shape)))]
+    if len(self.output_shape) > 3:
+      # compact all the dimensions into the final one
+      for i in range(len(self.output_shape)-1, 2, -1):
+        self.kernel += [f"int idx{i} = idx2 % {self.output_shape[i]};", f"idx2 = idx2 / {self.output_shape[i]};\n"]
+      self.output_shape = list(self.output_shape[0:2]) + [prod(self.output_shape[2:])]
+
+    # early ast
+    if self.reduceop:
+      full_shape = [x for x in self.shapes if x != self.shapes[0]]
+      full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
+
+      self.kernel.append(f"float acc = {GPUBuffer.start_for_op[self.reduceop.op]};\n")
+      for i in range(self.first_reduce, self.last_reduce):
+        self.kernel.append(f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n")
+      self.kernel.append("  acc = " + self.ast_parse(self.reduceop, reduce=True) + ";\n")
+      self.kernel += ["}\n"] * (self.last_reduce - self.first_reduce)
+
+    # late ast
+    process_ast = self.ast_parse(self.ast)
+    self.store(0, process_ast)
+    self.kernel.append("}")
+
+    # kernel function definition
+    function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
+    self.kernel = [f"__kernel void {function_name}(",] + [', '.join(f'__global float *data{i}' for i in range(len(self.bufs)) if i not in self.bufs_to_delete)] + [") {\n"] + self.kernel
+
+    # compile kernel
+    fxn = CLProgram(function_name, ' '.join(self.kernel))
+
+    def runner(*bufs):
+      clbufs = [x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete]
+      return fxn(self.output_shape[::-1] if len(self.output_shape) > 0 else [1], None, *clbufs, op_estimate=self.info.flops)
+    return runner
+
 class GPUBuffer(ExplicitExecAST):
   code_for_op : Dict[Op, str] = {
     UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
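OpenCL only provides three get_global_id axes, so codegen above folds any output dimensions past the third into idx2 and unpacks them with the emitted % and / lines. A sketch of that unpacking in plain Python (illustrative names, not tinygrad API):

    from math import prod

    # Sketch: recover per-dimension indices from the flattened third global id,
    # mirroring the `idx{i} = idx2 % ...; idx2 = idx2 / ...` lines codegen emits.
    def unpack_idx2(idx2, trailing_dims):
      idxs = []
      for d in reversed(trailing_dims):  # codegen peels dimensions from the end
        idxs.append(idx2 % d)
        idx2 //= d
      return list(reversed(idxs))

    # a (8, 8, 4, 5, 3) output launches as (8, 8, 60); third-axis id 53 maps back
    assert prod([4, 5, 3]) == 60
    assert unpack_idx2(53, [4, 5, 3]) == [3, 2, 2]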
@@ -117,96 +213,8 @@ class GPUBuffer(ExplicitExecAST):
     CL.enqueue_copy(data, self.contiguous().cl, is_blocking=True)
     return data
 
-  def contiguous_view_constant_fold(x, name:str, reduce:Optional[int]=None) -> Tuple[str, Optional[str], str]:
-    idx_getter = f"int valid = 1; {'long' if prod(x.shape) >= 2**31 else 'int'} idx = gid; {'idx *= '+str(reduce)+'; idx += subidx;' if reduce is not None else ''} {x.st.expr().replace('//', '/')};"
-    constant = x._backing[0] if x._base_shape == (1,) and x._backing is not None else None
-    args = (["__global const float *x"] if constant is None else []) + ["int gid"] + (["int subidx"] if reduce is not None else [])
-    return f"inline float get_{name}({','.join(args)}) {{ {idx_getter} return valid ? {constant if constant is not None else 'x[idx]'} : 0.0;}}", \
-      f"__global const float *{name}_g" if constant is None else None, \
-      f"get_{name}({name+'_g, ' if constant is None else ''}gid{', subidx' if reduce is not None else ''});"
-
   @classmethod
   def exec_ast(cls, ast:LazyOp):
-    # copied from llvm
-    bufs = dedup(get_buffers(ast))
-    reduceops = dedup([x for x in get_lazyops(ast) if isinstance(x.op, ReduceOps) or isinstance(x.op, ProcessingOps)])
-    assert len(reduceops) <= 1, f"max one reduce op in an ast, {reduceops}"
-    earlybufs = dedup(get_buffers(reduceops[0])) if len(reduceops) > 0 else []
-    reduce_shape = (earlybufs[0].shape, reduceops[0].arg) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else None
-    info = get_lazyop_info(ast)
-    ret = cls(info.shape)
-
-    buf_names : Dict[GPUBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(bufs)}
-
-    # special names for input and weight
-    if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps):
-      buf_names[reduceops[0].src[0]] = "input"
-      buf_names[reduceops[0].src[1]] = "weight"
-
-    def _ast(x: Union[GPUBuffer, LazyOp], buf_names: Dict[GPUBuffer, str], code_for_op: Dict[Op, str], allow_reduce=False) -> str:
-      if isinstance(x, GPUBuffer):
-        return buf_names[x]
-      if not allow_reduce and type(x.op) in [ProcessingOps, ReduceOps]:
-        return "acc"
-      srcs_code = [_ast(src, buf_names, code_for_op) for src in x.src]
-      code = code_for_op[x.op]
-      if len(srcs_code) >= 1:
-        code = code.replace("A", srcs_code[0])
-      if len(srcs_code) >= 2:
-        code = code.replace("B", srcs_code[1])
-      return code
-
-    earlycode = _ast(reduceops[0], buf_names, cls.code_for_op, allow_reduce=True) if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else "acc"
-    code = _ast(ast, buf_names, cls.code_for_op)
-
-    C = reduceops[0].arg if len(reduceops) > 0 and isinstance(reduceops[0].op, ProcessingOps) else None
-    reduce_op = reduceops[0].op if len(reduceops) > 0 and isinstance(reduceops[0].op, ReduceOps) else ReduceOps.SUM
-    return ret._processing_op([(buf_names[x], x) for x in bufs], code, C, reduce_op, reduce_shape, set(buf_names[x] for x in earlybufs), earlycode, info.flops)
-
-  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0) -> GPUBuffer:
-    assert C is None, f"conv isn't handled by GPU anymore {C}"
-
-    # get the input/output shape and the reduce amount
-    reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
-    red = prod([s for s,n in zip(*reduce_shape) if n == 1])
-    assert red < 2**31, f"reduce must be under 2**31, {red} isn't"
-
-    # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
-    if red > 1 and prod(ret.shape) != 1:
-      assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
-
-    kernel_name = "reduce" if red > 1 else "elementwise"
-    early_views = {name:buf.contiguous_view_constant_fold(name, red) for name, buf in bufs if name in earlybufs}
-    late_views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs if name not in earlybufs}
-    views = {**early_views, **late_views}
-
-    buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None]  # type: ignore
-    buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None]  # type: ignore
-
-    # use local memory if it's a multistage reduce
-    inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
-    if inter_red > 1:
-      buf_cl.append(cl.LocalMemory(inter_red*4))
-
-    reduce_loop = f"int mid = get_global_id(1); for (int subidx = {red//inter_red + 1} * mid; subidx < min({red}, {red//inter_red + 1} * (mid+1)); subidx++)" if inter_red > 1 else f"for (int subidx = 0; subidx < {red}; subidx++)"
-    conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
-    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
-      const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
-      float acc = {GPUBuffer.start_for_op[op]};
-      int gid = get_global_id(0);
-      {reduce_loop} {{
-        {chr(10).join([f'float {name} = ' + early_views[name][2] for name in early_views])}
-        acc = {earlycode};
-      }}"""+(f"""
-      temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
-      if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
-        for (int rdx = 0; rdx < {inter_red}; rdx++) {{
-          acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
-        }}""" if inter_red != 1 else "{")+f"""
-        {chr(10).join([f'float {name} = ' + late_views[name][2] for name in late_views])}
-        output[gid] = {code};
-      }}
-    }}""")
-
-    conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
-    return ret
+    k = CLASTKernel(ast)
+    k.codegen()(*k.bufs)
+    return k.ret
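Both the removed _ast helper and the new CLASTKernel.ast_parse build kernel expressions the same way: the code_for_op table holds C-expression templates with "A"/"B" placeholders, filled in recursively from the AST sources. A sketch of that expansion (illustrative op table keyed by strings instead of the real Op enums):

    # Sketch: how code_for_op templates expand, e.g. BinaryOps.ADD -> "(A+B)"
    # (illustrative, not part of this commit).
    code_for_op = {"ADD": "(A+B)", "MUL": "(A*B)", "NEG": "(-(A))"}

    def render(op, srcs):
      code = code_for_op[op]
      if len(srcs) >= 1: code = code.replace("A", srcs[0])
      if len(srcs) >= 2: code = code.replace("B", srcs[1])
      return code

    # ((data0 + data1) * -data2), built inside-out like ast_parse walks LazyOp.src
    inner = render("ADD", ["data0[gid]", "data1[gid]"])
    expr = render("MUL", [inner, render("NEG", ["data2[gid]"])])
    assert expr == "((data0[gid]+data1[gid])*(-(data2[gid])))"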
@@ -153,6 +153,7 @@ class ASTKernel:
         rets[j].append((shapes[j][i], strides[j][i]))
     self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
     self.first_reduce = get_first_reduce(self.shapes)  # update this if axis merged
+    self.last_reduce = len(self.shapes[0])
 
     # include the offsets (as is)
     self.offsets = [x.st.views[-1].offset for x in self.bufs]
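last_reduce complements first_reduce: together they bracket the axes that CLASTKernel.codegen loops over with `for (int idx{i} ...)`. A sketch of the convention, consistent with how it is used in this diff (the real get_first_reduce lives alongside ASTKernel in tinygrad.ops):

    # Sketch: first_reduce is the first axis where the buffer shapes disagree
    # (output vs. full input); last_reduce is the total axis count, so codegen
    # loops idx{i} for i in [first_reduce, last_reduce). Illustrative only.
    def get_first_reduce(shapes):
      for i in range(len(shapes[0])):
        if not all(s[i] == shapes[0][i] for s in shapes):
          return i
      return len(shapes[0])  # no reduce axes

    shapes = [[16, 1, 1], [16, 8, 8]]  # output shape, then input of an (8,8) reduce
    assert get_first_reduce(shapes) == 1
    assert len(shapes[0]) == 3  # last_reduce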