diff --git a/extra/models/llama.py b/extra/models/llama.py
index 08630112..e1a0a136 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -101,7 +101,7 @@ class TransformerBlock:
 
   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-    return (h + self.feed_forward(self.ffn_norm(h).half())).realize()
+    return h + self.feed_forward(self.ffn_norm(h).half())
 
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index e4030f62..9e1a00ee 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -94,11 +94,14 @@ class TinyJit(Generic[ReturnType]):
 
   def __call__(self, *args, **kwargs) -> ReturnType:
     # all inputs (except const) are realized
-    input_tensors: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = { cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor }  # noqa: E501
-    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_tensors.items()]) #noqa: E501
+    input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor }  # noqa: E501
+    Tensor.corealize(input_tensors.values())
+    input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()}
+    expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501
 
     # get rawbuffers
-    lbs: List[LazyBuffer] = [v for v in input_tensors.values() if isinstance(v, LazyBuffer)] + flatten([mlb.lbs for mlb in input_tensors.values() if isinstance(mlb, MultiLazyBuffer)]) #noqa: E501
+    lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \
+      flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)])
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
     assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
 
@@ -113,6 +116,7 @@ class TinyJit(Generic[ReturnType]):
                  for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \
         f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}"
       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
       for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
     elif self.cnt == 1:
       # jit capture
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 355cfb19..22411db0 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -238,6 +238,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
+  if DEBUG >= 1 and len(schedule) > 0: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals
 
 def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
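
Note (not part of the diff): the JIT now batches input realization through Tensor.corealize instead of calling .realize().lazydata on each input, which is also why TransformerBlock.__call__ no longer realizes its own output. A minimal sketch of the difference, using only the public tinygrad Tensor API seen above; the tensors and shapes are made up for illustration:

    from tinygrad import Tensor

    a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
    outs = [a @ b, a + b, (a * b).sum()]

    # realizing one at a time builds and runs a separate schedule per output:
    #   for t in outs: t.realize()
    # corealize schedules all outputs together, so shared work is scheduled
    # once and the resulting kernels can be dispatched as a single batch
    Tensor.corealize(outs)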