fix nonstatic weights

This commit is contained in:
George Hotz 2022-10-20 17:04:14 -07:00
parent 59143bbb3b
commit 1bec4651b3
4 changed files with 10 additions and 6 deletions

View File

@ -62,7 +62,12 @@ def preprocessing_op(x,w,C):
w = w.contiguous_op()
# early realize on the weights
w.realize().image
bw = w
while getattr(bw, 'op', None) and len(bw.op.src) == 1:
bw = bw.op.src[0]
if bw.realized:
# weights are static
w.realize().image
return x,w,C
def postprocessing_op(ret, C, C_initial):

View File

@ -87,10 +87,7 @@ def get_run_onnx(onnx_model):
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul":
if inp[1].lazydata.realized is not None:
print("WARNING: matmul weights are not static")
ret = inp[0].matmul(inp[1])
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt.get('min', -3.4e38), opt.get('max', 3.4e38))))

View File

@ -281,7 +281,8 @@ class Thneed:
local_cl_cache = []
for prg, args in self.cl_cache:
args = list(args)
if args[1] is None and len(args[0]) == 2:
# TODO: WTF is wrong with to_image?
if args[1] is None and len(args[0]) == 2 and 'to_image' not in prg.name:
args[1] = [min(MAX_WORKGROUP, args[0][0]), 1]
try:
e = prg.clprg(CL().cl_queue, *args)

View File

@ -93,6 +93,7 @@ def compile(dat, output_fn):
# real run
inputs, np_inputs = get_random_input_tensors(input_shapes)
print("***** REAL RUN *****")
tinygrad_out = run_onnx(inputs)['outputs']
# note, since CL.CACHE is enabled, it doesn't actually run the kernels