support symbols in shrink (#1611)

This commit is contained in:
chenyu 2023-08-22 09:08:21 -07:00 committed by GitHub
parent 718ced296c
commit 89e13f2f04
7 changed files with 57 additions and 12 deletions
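
In short: shrink can now take Variable bounds, so a slice offset stays symbolic until the ShapeTracker's var_vals binds it to a concrete value. A minimal sketch of the new capability, mirroring the tests in this commit (assumes a default device that can realize tensors):

  from tinygrad.tensor import Tensor
  from tinygrad.shape.symbolic import Variable

  vi = Variable("i", 1, 10)                 # symbolic start index with known bounds
  a = Tensor.rand(7, 11)
  sym = a.shrink(((3, 5), (vi, vi+2)))      # rows 3:5, columns vi:vi+2 -- shape stays concrete (2, 2)
  sym.lazydata.st.var_vals[vi] = 4          # bind i=4 before realizing
  print(sym.numpy())                        # matches a.shrink(((3, 5), (4, 6))).numpy()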

View File

@@ -151,15 +151,18 @@ class Transformer:
   def __call__(self, tokens:Tensor, start_pos:int):
     _bsz, seqlen = tokens.shape
-    # get only the part we are using.
-    # NOTE: if you remove contiguous here, it breaks because you can't put different ShapeTrackers into the compiled JIT
-    # NOTE: realize is not enough, since the realized buffer will have an offset that the kernel doesn't know about
-    # TODO: check that we didn't do this in the JIT and confirm the ShapeTrackers match the template
-    # TODO: support Variables in shrink
-    freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous()
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
     do_jit = getenv("JIT") and mask is None
+    # get only the part of freqs_cis that we are using.
+    if do_jit:
+      pos = Variable("pos", 1, 1024)
+      assert seqlen == 1, "seqlen > 1 not supported for JIT"
+      freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen), (0, self.freqs_cis.shape[2]), (0, self.freqs_cis.shape[3]), (0, self.freqs_cis.shape[4])))
+      freqs_cis.lazydata.st.var_vals[pos] = start_pos
+    else:
+      freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, self.freqs_cis.shape[2]), (0, self.freqs_cis.shape[3]), (0, self.freqs_cis.shape[4])))
     h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
     h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
     return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
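
The practical payoff of the new JIT branch: slicing with a bound Variable keeps the ShapeTracker identical across calls, so one compiled kernel is replayed with a fresh pos binding each step, instead of contiguous() copying the slice every time. A sketch of that pattern in isolation (step and table are hypothetical stand-ins, not the Transformer code):

  from tinygrad.tensor import Tensor
  from tinygrad.shape.symbolic import Variable
  from tinygrad.jit import TinyJit

  @TinyJit
  def step(x): return (x + 1).realize()     # stand-in for the jitted model call

  table = Tensor.rand(1, 16, 4)             # stand-in for freqs_cis
  pos = Variable("pos", 1, 10)
  for start_pos in range(1, 5):
    row = table.shrink(((0, 1), (pos, pos+1), (0, 4)))  # same symbolic views every call
    row.lazydata.st.var_vals[pos] = start_pos           # only the binding changes
    out = step(row)                                     # after warmup, replays one cached kernel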

View File

@@ -52,6 +52,15 @@ class TestJit(unittest.TestCase):
     with self.assertRaises(AssertionError):
       add(a, bad)
 
+  def test_jit_shape_views_mismatch(self):
+    @TinyJit
+    def add(a): return (a+1).realize()
+    with self.assertRaises(AssertionError):
+      for i in range(1,5):
+        # a has an offset that the kernel doesn't know about
+        a = Tensor.randn(10, 10).realize()[:, i:i+2]
+        add(a)
+
   def test_jit_duplicate_fail(self):
     # the jit doesn't support duplicate arguments
     @TinyJit
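
Why this test must fail: the JIT captures kernels against a ShapeTracker template, and a buffer sliced at a different offset has the same shape but different views, so replaying the captured kernel would read the wrong memory. A small sketch of the distinction the new check relies on (assuming View equality compares offsets, as the views comparison in tinygrad/jit.py below implies):

  from tinygrad.tensor import Tensor

  base = Tensor.randn(10, 10).realize()
  a, b = base[:, 1:3], base[:, 2:4]        # both shape (10, 2), same underlying buffer
  print(a.shape == b.shape)                # True: the old shape check would pass
  print(a.lazydata.st.views == b.lazydata.st.views)  # False: the offsets differ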

View File

@@ -151,5 +151,19 @@ class TestSymbolicJit(unittest.TestCase):
     with self.assertRaises(AssertionError):
       add(a, bad)
+
+  def test_shrink(self):
+    # shrink is a movement, so we pair it with a simple function to test the JIT interaction
+    def f(a): return (a+1).realize()
+    jf = TinyJit(f)
+    vi = Variable("i", 1, 10)
+    for i in range(1, 5):
+      a = Tensor.rand(7, 11)
+      symbolic = a.shrink(((3,5),(vi,vi+2)))
+      symbolic.lazydata.st.var_vals[vi] = i
+      symbolic = jf(symbolic).numpy()
+      expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+    assert len(jf.jit_cache) == 1
 
 if __name__ == '__main__':
   unittest.main()
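
The assert at the end is the point of the test: rebinding vi means the cache never grows past one kernel. For contrast, a sketch of what happens without the symbolic index, where every i produces a different view offset and the views assertion fires on replay (the same failure mode test_jit_shape_views_mismatch pins down):

  from tinygrad.tensor import Tensor
  from tinygrad.jit import TinyJit

  def f(t): return (t + 1).realize()

  jf = TinyJit(f)
  a = Tensor.rand(7, 11)
  try:
    for i in range(1, 5):
      jf(a.shrink(((3, 5), (i, i+2))))     # concrete shrink: offset changes each call
  except AssertionError as e:
    print("JIT rejected mismatched views:", e)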

View File

@@ -112,5 +112,15 @@ class TestSymbolicOps(unittest.TestCase):
     expected = f(a, b).numpy()
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+
+  def test_shrink(self):
+    vi = Variable("i", 1, 10)
+    for i in range(1, 5):
+      a = Tensor.rand(7, 11)
+      symbolic = a.shrink(((3,5),(vi,vi+2)))
+      symbolic.lazydata.st.var_vals[vi] = i
+      symbolic = symbolic.numpy()
+      expected = a.shrink(((3,5),(i,i+2))).numpy()
+      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
 if __name__ == '__main__':
   unittest.main()

View File

@@ -127,6 +127,12 @@ class TestSymbolicExpand(unittest.TestCase):
     a = a + 1
     assert a.shape == (3, vi)
 
+class TestSymbolicShrink(unittest.TestCase):
+  def test_shrink_symbols(self):
+    vi = Variable("i", 1, 5)
+    t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
+    assert t.shape == (2, 1)
+
 class TestSymbolicShapeExpr(unittest.TestCase):
   def test_symbolic_expr_idxs(self):
     # taken from symbolic shape llama
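
The shape stays concrete because the extent (vi+1) - vi folds to the constant 1 in symbolic arithmetic; only the offset is symbolic. A quick check of that folding (relying on Node's sum simplification and the NumNode __eq__ in tinygrad/shape/symbolic.py):

  from tinygrad.shape.symbolic import Variable

  vi = Variable("i", 1, 5)
  extent = (vi + 1) - vi
  assert extent == 1                       # folds to NumNode(1): the shrunken dim has size 1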

View File

@@ -4,7 +4,8 @@ from tinygrad.helpers import DEBUG, DType, merge_dicts
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, RawBuffer
-from tinygrad.shape.symbolic import Variable, Node
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.symbolic import Variable
 
 JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
@@ -14,7 +15,7 @@ class TinyJit:
     self.cnt: int = 0
     self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
     self.ret: Any = None
-    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], Tuple[Union[Node, int],...], DType]]= {}  # (kernel_number, buffer_number) -> (input_name, expected_shape, expected_type)
+    self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]]= {}  # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
 
   # add support for instance methods
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -22,13 +23,13 @@ class TinyJit:
   def __call__(self, *args, **kwargs) -> Any:
     if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs)  # only jit on supported device
     # NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
-    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, Tuple[Union[Node, int],...]]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.shape) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+    input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.lazydata.st) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
     assert len(input_rawbuffers) != 0, "no inputs to JIT"
     assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
     if self.cnt >= 2:
       var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
-      for (j,i),(input_name, expected_shape, expected_type) in self.input_replace.items():
-        assert input_rawbuffers[input_name][1] == expected_shape and input_rawbuffers[input_name][0].dtype == expected_type, f"shape or type mismatch in JIT, <{input_rawbuffers[input_name][1]}, {input_rawbuffers[input_name][0].dtype}> != <{expected_shape}, {expected_type}>"
+      for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
+        assert input_rawbuffers[input_name][1].views == expected_st.views and input_rawbuffers[input_name][0].dtype == expected_type, f"ShapeTracker.views or type mismatch in JIT, <{input_rawbuffers[input_name][1].views}, {input_rawbuffers[input_name][0].dtype}> != <{expected_st.views}, {expected_type}>"
         self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
       for prg, pargs, variables in self.jit_cache:  # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
         for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
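
A note on the replay path above: var_vals is gathered from every Tensor argument's ShapeTracker (sorted by Variable key for a deterministic order), then each cached kernel copies out just the Variables it actually uses. A schematic re-statement of that loop with toy values (not new behavior; the bindings here are made up):

  from tinygrad.helpers import merge_dicts
  from tinygrad.shape.symbolic import Variable

  v1, v2 = Variable("i", 1, 10), Variable("j", 1, 10)
  var_vals = dict(sorted(merge_dicts([{v1: 3}, {v2: 7}]).items(), key=lambda kv: kv[0].key))
  variables = {v1: 0}                      # a cached kernel that only uses i
  for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
  assert variables == {v1: 3}              # j is ignored; i gets its new binding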

View File

@@ -35,6 +35,7 @@ class Node:
   def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
   def __radd__(self, b:int): return self+b
   def __sub__(self, b:Union[Node,int]): return self+-b
+  def __rsub__(self, b:int): return -self+b
   def __le__(self, b:Union[Node,int]): return self < (b+1)
   def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
   def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
@@ -154,6 +155,7 @@ class NumNode(Node):
     self.b:int = num
     self.min, self.max = num, num
   def __int__(self): return self.b
+  def __index__(self): return self.b
   def __eq__(self, other): return self.b == other
   def __hash__(self): return self.hash  # needed with __eq__ override
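
Both one-line additions are load-bearing for symbolic shrink: __rsub__ makes int - Node arithmetic work when computing bounds and extents, and __index__ lets a NumNode stand in wherever Python demands a true integer, such as a list or slice index. A quick sketch (Variable.num is the existing NumNode constructor, as used in __add__ above):

  from tinygrad.shape.symbolic import Variable

  vi = Variable("i", 1, 10)
  node = 3 - vi                            # __rsub__ desugars to (-vi) + 3
  assert (node.min, node.max) == (-7, 2)

  three = Variable.num(3)                  # a NumNode
  assert [10, 20, 30, 40][three] == 40     # __index__ lets Python treat it as an int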