mirror of https://github.com/commaai/tinygrad.git
good changes from openpilot_compile2 (#2000)
* good changed from openpilot_compile2 * float32 image type was wrong * cleaner way to write that + a test
This commit is contained in:
parent
05be57f57f
commit
ffa33d743a
|
@ -19,8 +19,7 @@ def safe_numpy(t) -> np.ndarray:
|
|||
if not isinstance(t, Tensor): return t
|
||||
global numpy_cache
|
||||
if t not in numpy_cache:
|
||||
if DEBUG >= 1:
|
||||
print("numpy cache miss", t)
|
||||
if DEBUG >= 3: print("numpy cache miss", t)
|
||||
tmp = t.numpy()
|
||||
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
|
||||
assert len(numpy_cache[t].shape) > 0
|
||||
|
@ -95,9 +94,6 @@ def get_run_onnx(onnx_model: ModelProto):
|
|||
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
|
||||
print(inp)
|
||||
raise Exception("no data")
|
||||
if DEBUG >= 1:
|
||||
print("realize", inp.name)
|
||||
tensors[inp.name].realize()
|
||||
|
||||
# preparse the attributes
|
||||
attribute_dict = {}
|
||||
|
@ -130,13 +126,6 @@ def get_run_onnx(onnx_model: ModelProto):
|
|||
if shape: # if only input_tensor is not variable type
|
||||
input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
|
||||
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
|
||||
for _,v in input_tensors.items():
|
||||
if isinstance(v, Tensor):
|
||||
v.realize()
|
||||
elif isinstance(v, list):
|
||||
for v_ in v: v_.realize()
|
||||
else:
|
||||
raise Exception(f"unknown input type: {type(v)}")
|
||||
else:
|
||||
raise Exception(f"no data for {inp.name} with shape {shape}")
|
||||
|
||||
|
|
|
@ -322,5 +322,10 @@ class TestSchedule(unittest.TestCase):
|
|||
out = x.permute(0,2,3,1).contiguous()
|
||||
check_schedule(out, 2, filter_loadops=False)
|
||||
|
||||
def test_double_from(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
out = x.to('cpu')
|
||||
check_schedule(out, 0, filter_loadops=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -86,6 +86,25 @@ class TestSafetensors(unittest.TestCase):
|
|||
for k in f.keys():
|
||||
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
|
||||
|
||||
def test_huggingface_enet_safetensors(self):
|
||||
# test a real file
|
||||
fn = fetch_as_file("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
|
||||
state_dict = safe_load(fn)
|
||||
assert len(state_dict.keys()) == 244
|
||||
assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
|
||||
assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
|
||||
assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
|
||||
|
||||
def test_metadata(self):
|
||||
metadata = {"hello": "world"}
|
||||
safe_save({}, temp('metadata.safetensors'), metadata)
|
||||
import struct
|
||||
with open(temp('metadata.safetensors'), 'rb') as f:
|
||||
dat = f.read()
|
||||
sz = struct.unpack(">Q", dat[0:8])[0]
|
||||
import json
|
||||
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
||||
|
||||
class TestDiskTensor(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
pathlib.Path(temp("dt1")).unlink(missing_ok=True)
|
||||
|
|
|
@ -3,8 +3,8 @@ import itertools, math, os
|
|||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps
|
||||
from tinygrad.codegen.kernel import Kernel, LocalBuffer, LinearizerOptions
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
||||
class OptimizedKernel(Kernel):
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None, var_vals=None):
|
||||
|
@ -62,6 +62,19 @@ class OptimizedKernel(Kernel):
|
|||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
if self.bufs[0].dtype.name.startswith('image'):
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: Tuple[int, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], "get_contraction was wrong?"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# TODO: does this always preserve the reduce dimension, NO
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
|
@ -78,7 +91,7 @@ class OptimizedKernel(Kernel):
|
|||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** GPU simplifiers ********************
|
||||
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
|
||||
|
@ -354,14 +367,6 @@ class OptimizedKernel(Kernel):
|
|||
# simplify (sets first_reduce)
|
||||
self.simplify_ones()
|
||||
|
||||
# use more opencl indexing if the output buffer is an image and we have room
|
||||
if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
|
||||
base_shape = self.bufs[0].dtype.shape
|
||||
if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
|
||||
if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
|
||||
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
|
||||
self.simplify_ones()
|
||||
|
||||
# no more opt if we are grouping
|
||||
if self.group_for_reduce: return
|
||||
|
||||
|
|
|
@ -129,6 +129,12 @@ class dtypes:
|
|||
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
|
||||
|
||||
|
|
|
@ -191,8 +191,13 @@ class LazyBuffer:
|
|||
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
|
||||
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
|
||||
|
||||
def copy_to_device(self, device:str) -> LazyBuffer:
|
||||
# back off a FROM if it's a double FROM
|
||||
if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
|
||||
return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
|
||||
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
|
||||
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
|
||||
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
|
||||
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
|
||||
# this will turn into nothing, it's based and a copy
|
||||
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
import numpy as np
|
||||
from tinygrad.helpers import prod, IMAGE, ImageDType, getenv, dtypes
|
||||
from tinygrad.helpers import prod, IMAGE, getenv, dtypes
|
||||
from tinygrad.lazy import get_single_root
|
||||
|
||||
FLOAT16 = getenv("FLOAT16", 0)
|
||||
base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "imagef", np.float32)
|
||||
|
||||
def image_dot(self, w):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
|
@ -27,6 +23,8 @@ def image_dot(self, w):
|
|||
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
|
||||
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
||||
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
||||
rcout = cout//groups
|
||||
x, w = self, weight.reshape(groups, rcout, cin, H, W)
|
||||
|
@ -56,7 +54,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
|||
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE >= 2: x,w = x.cast(ImageDType(*base_image_type, shape=x.shape)), w.cast(ImageDType(*base_image_type, shape=w.shape))
|
||||
if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape))
|
||||
x, w = x.contiguous(), w.contiguous()
|
||||
if get_single_root(w.lazydata).realized: w.realize()
|
||||
|
||||
|
@ -86,7 +84,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
|||
|
||||
# reshape to image and cast back to image
|
||||
ret = ret.reshape(bs*oy, ox*cout//4, 4)
|
||||
if IMAGE >= 2: ret = ret.cast(ImageDType(*base_image_type, shape=ret.shape))
|
||||
if IMAGE >= 2: ret = ret.cast(base_image_type(ret.shape))
|
||||
if IMAGE >= 3: ret = ret.contiguous()
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os, json, pathlib, zipfile, pickle
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Union, List
|
||||
from typing import Dict, Union, List, Optional, Any
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
@ -12,15 +12,16 @@ inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
|||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
|
||||
metadata = json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
|
||||
headers = json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in headers.items() if k != "__metadata__"}
|
||||
|
||||
def safe_save(tensors:Dict[str, Tensor], fn:str):
|
||||
metadata, offset = {}, 0
|
||||
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
|
||||
headers, offset = {}, 0
|
||||
if metadata: headers['__metadata__'] = metadata
|
||||
for k,v in tensors.items():
|
||||
metadata[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
|
||||
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
|
||||
offset += v.nbytes()
|
||||
j = json.dumps(metadata, separators=(',', ':'))
|
||||
j = json.dumps(headers, separators=(',', ':'))
|
||||
j += "\x20"*((8-len(j)%8)%8)
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
|
|
|
@ -50,6 +50,7 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
|
|||
while len(schedule):
|
||||
op,out,buffers = schedule.pop(0)
|
||||
log_schedule_item(op, out, buffers)
|
||||
assert all(x.realized for x in buffers), "can't run schedule, some buffers aren't realized"
|
||||
if DEBUG >= 3:
|
||||
from extra.utils import print_tree # type: ignore
|
||||
print_tree(op)
|
||||
|
@ -68,10 +69,12 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
|
|||
|
||||
def _realize_empty(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
|
||||
def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
if DEBUG >= 2: print(f"*** rand {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
rng = np.random.default_rng(buffer.op.arg)
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
|
||||
|
||||
|
@ -80,7 +83,7 @@ def _realize_rand(buffer: LazyBuffer) -> None:
|
|||
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
|
||||
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
|
||||
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size} dtype {src.realized.dtype}")
|
||||
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
|
@ -95,6 +98,7 @@ def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
|||
# *** n op LoadOps ***
|
||||
|
||||
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
|
||||
if DEBUG >= 2: print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
|
||||
buffer.realized = buffer.op.arg(buffer, *inputs)
|
||||
|
||||
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
|
||||
|
|
|
@ -61,17 +61,20 @@ class Tensor:
|
|||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, (int, float)):
|
||||
self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
|
||||
return
|
||||
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
|
||||
elif data.__class__ is list:
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, np.ndarray):
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
if data.shape == ():
|
||||
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
|
||||
else:
|
||||
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
else: raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data.contiguous())
|
||||
# data is a LazyBuffer, but it might be on the wrong device
|
||||
self.lazydata = data if data.device == device else data.copy_to_device(device)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
|
||||
|
|
Loading…
Reference in New Issue