mirror of https://github.com/commaai/tinygrad.git

add assert to catch issue in attention

This commit is contained in:
parent 26c78ccf7d
commit 50c95c7d9a
@@ -4,8 +4,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
 from tinygrad.nn import batch_normalize
-
-MAX_CONVS = int(os.getenv("MAX_CONVS", -1))
+from tinygrad.ops import DEBUG
 
 def get_run_onnx(onnx_model):
   def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
@@ -38,6 +37,8 @@ def get_run_onnx(onnx_model):
       print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
       print(inp)
       raise Exception("no data")
+    if DEBUG >= 1:
+      print("realize", inp.name)
     tensors[inp.name].realize()
 
   def run_onnx(inputs={}, debug=False):
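For context on the DEBUG gate added above: in tinygrad of this era, DEBUG is an integer parsed from the environment in tinygrad.ops, so the new realize logging is switched on from the shell. A minimal sketch of the pattern; the parsing line is an assumption from memory, not part of this diff:

import os
DEBUG = int(os.getenv("DEBUG", "0"))    # assumed: roughly how tinygrad.ops derives DEBUG
if DEBUG >= 1:
  print("realize", "some_initializer")  # the runner now logs each initializer it realizes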
@@ -60,7 +61,6 @@ def get_run_onnx(onnx_model):
       else:
         raise Exception(f"no data for {inp.name} with shape {shape}")
 
-    conv_count = 0
     for num,n in enumerate(onnx_model.graph.node):
       if debug: print(f"{num}: op {n.op_type}")
       inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else input_tensors[x]) for x in n.input]
@@ -71,7 +71,9 @@ def get_run_onnx(onnx_model):
       elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
       elif n.op_type == "Tanh": ret = inp[0].tanh()
       elif n.op_type == "Softmax": ret = inp[0].softmax()
-      elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
+      elif n.op_type == "MatMul":
+        assert inp[1].lazydata.realized is not None
+        ret = inp[0].matmul(inp[1])
       # one liners
       elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
       elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt.get('min', -3.4e38), opt.get('max', 3.4e38))))
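The hunk above is the change named in the commit message: before dispatching a MatMul, the runner now asserts that the right-hand operand (presumably the attention weights in the failing case) is backed by a realized buffer. A minimal sketch of what lazydata.realized means in tinygrad of this era, with an illustrative tensor name that is not part of the commit:

from tinygrad.tensor import Tensor

w = Tensor.randn(4, 4)                  # lazy: nothing has been computed yet
assert w.lazydata.realized is None      # the new MatMul assert would fire on w here
w.realize()                             # force evaluation of the lazy buffer
assert w.lazydata.realized is not None  # now w is safe as a MatMul input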
@@ -110,10 +112,6 @@ def get_run_onnx(onnx_model):
         else:
           x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
         ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
-        conv_count += 1
-        if conv_count == MAX_CONVS:
-          ret.numpy()
-          break
       elif n.op_type in ["Add", "Sub", "Mul"]:
         # TODO: add this to tinygrad? i don't think it's in torch
         if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):

@@ -5,7 +5,7 @@ import struct
 import json
 import traceback
 import numpy as np
-from tinygrad.llops.ops_gpu import CL, CLProgram, CLBuffer
+from tinygrad.llops.ops_gpu import CL, CLProgram
 from tinygrad.helpers import prod
 import pyopencl as cl
 import networkx as nx
@@ -185,7 +185,7 @@ class Thneed:
         })
         if needs_load:
           data = np.empty(a.size//4, dtype=np.float32)
-          CL.enqueue_copy(data, a, is_blocking=True)
+          cl.enqueue_copy(CL().cl_queue, data, a, is_blocking=True)
           weights.append(data.tobytes())
       elif isinstance(a, cl.Image):
         needs_load = a in self.buffers_to_save
@@ -195,7 +195,7 @@ class Thneed:
         buf = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
 
         # zero out the buffer
-        CL.enqueue_copy(buf, b'\x00'*buf.size, is_blocking=True)
+        cl.enqueue_copy(CL().cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
 
         CLProgram("from_image_strided", """
           __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
@@ -215,7 +215,7 @@ class Thneed:
 
         if needs_load:
           data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
-          CL.enqueue_copy(data, buf, is_blocking=True)
+          cl.enqueue_copy(CL().cl_queue, data, buf, is_blocking=True)
           if FLOAT16: data = data.astype(np.float16)
           weights.append(data.tobytes())
         else:
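The three hunks above all make the same substitution in the Thneed weight-saving path: the CL.enqueue_copy wrapper is replaced by pyopencl's module-level cl.enqueue_copy with an explicit command queue. A self-contained pyopencl sketch of that call pattern; the context, queue, and buffer size are illustrative, not from the commit:

import numpy as np
import pyopencl as cl

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
buf = cl.Buffer(ctx, cl.mem_flags.READ_WRITE, size=16*4)

# zero out the buffer, as the image path above does
cl.enqueue_copy(queue, buf, b'\x00'*buf.size, is_blocking=True)

# blocking read back to host, as the needs_load branches above do
data = np.empty(16, dtype=np.float32)
cl.enqueue_copy(queue, data, buf, is_blocking=True)
assert (data == 0).all()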
@@ -85,11 +85,11 @@ def compile(dat, output_fn):
   et = time.monotonic()
   print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
 
-  # realize all non GCed tensors (fix for batchnorm folding)
-  import gc
-  gc.collect()
-  for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
-    x.realize()
+  # realize all non GCed tensors (fix for batchnorm folding)
+  import gc
+  gc.collect()
+  for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
+    x.realize()
 
   # real run
   inputs, np_inputs = get_random_input_tensors(input_shapes)
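On the gc block above: per its own comment it is a fix for batchnorm folding, realizing every Tensor still reachable on the GC heap rather than only the ones the compiler holds by name. A standalone sketch of the idiom, assuming tinygrad's Tensor:

import gc
from tinygrad.tensor import Tensor

gc.collect()  # first drop tensors kept alive only by reference cycles
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
  x.realize()  # ensure every surviving tensor is backed by a real buffer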
@@ -108,6 +108,9 @@ def compile(dat, output_fn):
   CL.CACHE = None
   t.optimize_local_workgroup()
 
+  # save thneed (before run)
+  t.save(output_fn)
+
   print(f"buffers to save: {len(t.buffers_to_save)}, outputs: {t.outputs}")
   t.run()
 
@@ -116,9 +119,6 @@ def compile(dat, output_fn):
   CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
   np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
 
-  # save thneed
-  t.save(output_fn)
-
   # float32 only (fix this)
   FLOAT16 = int(os.getenv("FLOAT16", 0))
   if FLOAT16 == 0:
@@ -132,6 +132,20 @@ def compile(dat, output_fn):
     _, new_np_inputs = get_random_input_tensors(input_shapes)
     new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
 
+    # try old thneed with a different input
+    for k,v in t.inputs.items():
+      CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
+
+    t.run()
+    old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
+    CL.enqueue_copy(old_thneed_out, t.outputs[0], is_blocking=True)
+
+    # compare thneed (rerun) with torch
+    np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
+
+    # load thneed and try that
+    _, new_np_inputs = get_random_input_tensors(input_shapes)
+    new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
     nt = Thneed()
     nt.load(output_fn)
 
@@ -142,6 +156,8 @@ def compile(dat, output_fn):
     nt.run()
     new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
     CL.enqueue_copy(new_thneed_out, nt.outputs[0], is_blocking=True)
 
+    # compare torch to thneed
+    np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
     print("thneed self-test passed!")
   except ModuleNotFoundError:
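The two hunks above extend the self-test: the already-built thneed is rerun on fresh inputs and compared against torch, and the copy reloaded from disk gets the same comparison. A condensed sketch of the round trip, assuming Thneed and CL are the ones this file already uses (the import path and function wrapper are illustrative, not from the commit):

import numpy as np
from tinygrad.llops.ops_gpu import CL
from extra.thneed import Thneed  # assumed import path

def thneed_selftest(t, output_fn, new_np_inputs, new_torch_out, out_shape):
  # feed fresh inputs to the existing thneed and rerun it
  for k,v in t.inputs.items():
    CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
  t.run()
  out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(out_shape)
  CL.enqueue_copy(out, t.outputs[0], is_blocking=True)
  np.testing.assert_allclose(new_torch_out, out, atol=1e-4, rtol=1e-2)

  # reload from disk and check the loaded copy the same way
  nt = Thneed()
  nt.load(output_fn)
  nt.run()
  out2 = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(out_shape)
  CL.enqueue_copy(out2, nt.outputs[0], is_blocking=True)
  np.testing.assert_allclose(new_torch_out, out2, atol=1e-4, rtol=1e-2)
  print("thneed self-test passed!")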