mirror of https://github.com/commaai/tinygrad.git

support symbols in shrink (#1611)
commit 89e13f2f04 (parent 718ced296c)
@@ -151,15 +151,18 @@ class Transformer:
 
   def __call__(self, tokens:Tensor, start_pos:int):
     _bsz, seqlen = tokens.shape
-    # get only the part we are using.
-    # NOTE: if you remove contiguous here, it breaks because you can't put different ShapeTrackers into the compiled JIT
-    # NOTE: realize is not enough, since the realized buffer will have an offset that the kernel doesn't know about
-    # TODO: check that we didn't do this in the JIT and confirm the ShapeTrackers match the template
-    # TODO: support Variables in shrink
-    freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous()
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
     do_jit = getenv("JIT") and mask is None
+
+    # get only the part of freqs_cis that we are using.
+    if do_jit:
+      pos = Variable("pos", 1, 1024)
+      assert seqlen == 1, "seqlen > 1 not supported for JIT"
+      freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
+      freqs_cis.lazydata.st.var_vals[pos] = start_pos
+    else:
+      freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
 
     h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
     h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
     return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
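The JIT branch above shrinks freqs_cis with a symbolic Variable("pos", ...) instead of the concrete start_pos and then binds the concrete value through lazydata.st.var_vals, so every decode step reuses the same compiled kernel. A minimal sketch of that pattern, mirroring the new test_shrink tests further down (Variable, Tensor.shrink and lazydata.st.var_vals are the tinygrad API at this revision):

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10)             # symbolic column offset with known bounds
a = Tensor.rand(7, 11)
sym = a.shrink(((3, 5), (vi, vi+2)))  # symbolic shrink: shape is (2, 2), the offset stays "i"
sym.lazydata.st.var_vals[vi] = 4      # bind the concrete offset for this call
print(sym.numpy())                    # matches a.shrink(((3,5),(4,6))).numpy()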
@@ -52,6 +52,15 @@ class TestJit(unittest.TestCase):
     with self.assertRaises(AssertionError):
       add(a, bad)
 
+  def test_jit_shape_views_mismatch(self):
+    @TinyJit
+    def add(a): return (a+1).realize()
+    with self.assertRaises(AssertionError):
+      for i in range(1,5):
+        # a has an offset that the kernel doesn't know about
+        a = Tensor.randn(10, 10).realize()[:, i:i+2]
+        add(a)
+
   def test_jit_duplicate_fail(self):
     # the jit doesn't support duplicate arguments
     @TinyJit
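The new test works because two slices of the same buffer can agree on shape while their ShapeTrackers do not: the view offset moves with i, which is exactly what the reworked JIT assert on ShapeTracker.views compares. A small illustration of that distinction (a sketch against the same tinygrad revision, not part of the commit):

from tinygrad.tensor import Tensor

base = Tensor.randn(10, 10).realize()
a1, a2 = base[:, 1:3], base[:, 2:4]
assert a1.shape == a2.shape == (10, 2)               # shapes match...
assert a1.lazydata.st.views != a2.lazydata.st.views  # ...but the views carry different offsets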
@@ -151,5 +151,19 @@ class TestSymbolicJit(unittest.TestCase):
     with self.assertRaises(AssertionError):
       add(a, bad)
 
+  def test_shrink(self):
+    # shrink is a movement, so we pair it with a simple function to test the JIT interaction
+    def f(a): return (a+1).realize()
+    jf = TinyJit(f)
+    vi = Variable("i", 1, 10)
+    for i in range(1, 5):
+      a = Tensor.rand(7, 11)
+      symbolic = a.shrink(((3,5),(vi,vi+2)))
+      symbolic.lazydata.st.var_vals[vi] = i
+      symbolic = jf(symbolic).numpy()
+      expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+    assert len(jf.jit_cache) == 1
+
 if __name__ == '__main__':
   unittest.main()
@@ -112,5 +112,15 @@ class TestSymbolicOps(unittest.TestCase):
     expected = f(a, b).numpy()
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
+  def test_shrink(self):
+    vi = Variable("i", 1, 10)
+    for i in range(1, 5):
+      a = Tensor.rand(7, 11)
+      symbolic = a.shrink(((3,5),(vi,vi+2)))
+      symbolic.lazydata.st.var_vals[vi] = i
+      symbolic = symbolic.numpy()
+      expected = a.shrink(((3,5),(i,i+2))).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+
 if __name__ == '__main__':
   unittest.main()
@@ -127,6 +127,12 @@ class TestSymbolicExpand(unittest.TestCase):
     a = a + 1
     assert a.shape == (3, vi)
 
+class TestSymbolicShrink(unittest.TestCase):
+  def test_shrink_symbols(self):
+    vi = Variable("i", 1, 5)
+    t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
+    assert t.shape == (2, 1)
+
 class TestSymbolicShapeExpr(unittest.TestCase):
   def test_symbolic_expr_idxs(self):
     # taken from symbolic shape llama
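TestSymbolicShrink expects a concrete shape of (2, 1) even though the bounds of the second axis are symbolic: the new dimension is the difference of the shrink bounds, and (vi+1) - vi collapses to a plain constant in the symbolic engine. A short sketch of that arithmetic, assuming the simplification the test relies on:

from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 5)
width = (vi + 1) - vi  # difference of the shrink bounds (vi, vi+1)
assert width == 1      # like terms cancel, so the shrunk dimension is a concrete 1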
@@ -4,7 +4,8 @@ from tinygrad.helpers import DEBUG, DType, merge_dicts
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, RawBuffer
-from tinygrad.shape.symbolic import Variable, Node
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.symbolic import Variable
 
 JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
 
@@ -14,7 +15,7 @@ class TinyJit:
     self.cnt: int = 0
     self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
     self.ret: Any = None
-    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], Tuple[Union[Node, int],...], DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shape, expected_type)
+    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
 
   # add support for instance methods
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -22,13 +23,13 @@ class TinyJit:
   def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
     # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
-    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, Tuple[Union[Node, int],...]]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.shape) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
     assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
     if self.cnt >= 2:
       var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
-      for (j,i),(input_name, expected_shape, expected_type) in self.input_replace.items():
-        assert input_rawbuffers[input_name][1] == expected_shape and input_rawbuffers[input_name][0].dtype == expected_type, f"shape or type mismatch in JIT, <{input_rawbuffers[input_name][1]}, {input_rawbuffers[input_name][0].dtype}> != <{expected_shape}, {expected_type}>"
+      for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
+        assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
         self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
       for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
         for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
@@ -35,6 +35,7 @@ class Node:
   def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
   def __radd__(self, b:int): return self+b
   def __sub__(self, b:Union[Node,int]): return self+-b
+  def __rsub__(self, b:int): return -self+b
   def __le__(self, b:Union[Node,int]): return self < (b+1)
   def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
   def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
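The added __rsub__ makes reversed subtraction work: with an int on the left, Python falls back to Node.__rsub__, so an expression like 10 - vi now builds a symbolic node instead of raising TypeError. A quick sketch using only the operators shown in this hunk:

from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10)
n = 10 - vi          # int - Node dispatches to Node.__rsub__, i.e. -vi + 10
print(n.min, n.max)  # interval bounds of the result (expected 0 and 9 for i in [1, 10])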
@@ -154,6 +155,7 @@ class NumNode(Node):
     self.b:int = num
     self.min, self.max = num, num
   def __int__(self): return self.b
+  def __index__(self): return self.b
   def __eq__(self, other): return self.b == other
   def __hash__(self): return self.hash # needed with __eq__ override
 
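NumNode.__index__ lets a fully simplified symbolic constant stand in wherever Python insists on a true integer (operator.index, range, sequence indexing), complementing the existing __int__. A small sketch, assuming NumNode is importable from tinygrad.shape.symbolic:

import operator
from tinygrad.shape.symbolic import NumNode

three = NumNode(3)
assert operator.index(three) == 3       # __index__ hands back the underlying int
assert list(range(three)) == [0, 1, 2]  # range() accepts any object implementing __index__
assert "abcdef"[three] == "d"           # so does sequence indexing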