openpilot fixups

This commit is contained in:
George Hotz 2023-03-06 14:14:44 -08:00
parent 4b9bc1615b
commit d8dda2af3a
4 changed files with 5 additions and 55 deletions

View File

@ -290,51 +290,3 @@ class Thneed:
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
return total_runtime/1e9
return et
def optimize_local_workgroup(self):
MAX_WORKGROUP = CL.cl_ctx.devices[0].max_work_group_size
local_cl_cache = []
for prg, args in self.cl_cache:
potential_locals = [tuple(args[1])] if args[1] is not None else []
runtimes = []
args = list(args)
# NOTE: if args[1] is not None, it may use local variables and you shouldn't change this
if args[1] is None and len(args[0]) == 1:
for l1 in [args[0][0], 1, 4, 16, MAX_WORKGROUP//4, MAX_WORKGROUP]:
potential_locals.append((l1,))
if args[1] is None and len(args[0]) == 2:
for l2 in [1, 4, 16, MAX_WORKGROUP//4, MAX_WORKGROUP]:
potential_locals.append((min(MAX_WORKGROUP, args[0][0]), l2))
if args[1] is None and len(args[0]) == 3:
for l2 in [16,args[0][1],MAX_WORKGROUP]:
for l3 in [4,16,args[0][2],MAX_WORKGROUP]:
for l1 in [max(1, MAX_WORKGROUP//(l2*l3)), args[0][0], 4, 16, MAX_WORKGROUP]:
if l1 > args[0][0] or l2 > args[0][1] or l3 > args[0][2]: continue
potential_locals.append((l1, l2, l3))
for local_args in potential_locals:
if prod(local_args) > MAX_WORKGROUP: continue
args[1] = local_args
# 3 runs just in case
for i in range(3):
try:
e = prg.clprg(CL.cl_queue, *args)
except (cl.LogicError, cl.RuntimeError):
# INVALID_WORK_GROUP_SIZE
continue
CL.cl_queue.finish()
runtime = e.profile.end - e.profile.start
#print(runtime, args[0], args[1])
runtimes.append((runtime, local_args))
if len(runtimes) > 0:
args[1] = sorted(runtimes)[0][1]
else:
args[1] = None
print("couldn't optimize", args[0])
local_cl_cache.append((prg, args))
self.cl_cache = local_cl_cache

View File

@ -79,9 +79,6 @@ def compile(dat, output_fn):
from extra.thneed import Thneed
t = Thneed(cl_cache, {k:v._cl for k,v in input_rawbuffers.items()})
if getenv("OPTWG", 0):
t.optimize_local_workgroup()
# save thneed (before run)
t.save(output_fn)

View File

@ -1,3 +1,2 @@
#!/bin/bash
FLOAT16=1 DEBUGCL=1 NATIVE_EXPLOG=1 VALIDHACKS=1 OPTWG=1 IMAGE=2 GPU=1 CLCACHE=0 python3 openpilot/compile.py
FLOAT16=1 DEBUGCL=1 NATIVE_EXPLOG=1 VALIDHACKS=1 OPTLOCAL=1 IMAGE=2 GPU=1 ENABLE_METHOD_CACHE=1 python3 openpilot/compile.py

View File

@ -1,4 +1,5 @@
from tinygrad.helpers import IMAGE
from tinygrad.lazy import get_single_root
def image_conv2d_decorator(normal_conv):
if IMAGE == 0: return normal_conv
@ -32,8 +33,9 @@ def image_conv2d_decorator(normal_conv):
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
# contiguous creates the image, and early realize static weights (TODO: don't always realize)
x, w = x.contiguous(), w.contiguous().realize()
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize()
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1