mirror of https://github.com/commaai/tinygrad.git
pad None means (0,0) (#2273)
parent c5d70c1871
commit 453f48ce02
@@ -40,8 +40,8 @@ class Attention:
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

     # update the cache
-    self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
-    self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
+    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
+    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()

     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
@@ -90,8 +90,8 @@ class Attention:
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)

     # update the cache
-    self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
-    self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
+    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
+    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()

     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
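
With this change the explicit all-axes pad spec and the None shorthand are interchangeable in the cache updates above; a minimal sketch of the equivalence (the shapes below are illustrative, not taken from the commit):

    from tinygrad.tensor import Tensor

    keys = Tensor.ones(1, 5, 8, 64)                 # (bsz, seqlen, n_heads, head_dim), illustrative shape
    old = keys.pad(((0,0), (0,3), (0,0), (0,0)))    # old style: explicit (0,0) no-op pads on every other axis
    new = keys.pad((None, (0,3), None, None))       # new style: None stands in for (0,0)
    assert old.shape == new.shape == (1, 8, 8, 64)  # only the sequence axis is padded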
@@ -253,9 +253,9 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x else (0,s) for x,s in zip(arg, self.shape))) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
-  def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
-    ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
+  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def pad(self, arg:Tuple[Optional[Tuple[int, int]], ...], value:float=0.0) -> Tensor:
+    ret = mlops.Pad.apply(self, arg=tuple(x if x is not None else (0,0) for x in arg)) if any(x is not None and x != (0,0) for x in arg) else self
     return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)

   # ***** movement hlops *****
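
The tensor.py change makes None a valid entry in both pad and shrink specs: None means (0,0) for pad and the full extent (0, dim) for shrink. A minimal usage sketch under those semantics (shapes are illustrative):

    from tinygrad.tensor import Tensor

    t = Tensor.ones(2, 3)
    p = t.pad((None, (1, 1)))     # None == (0,0): pad only the last axis by one on each side
    assert p.shape == (2, 5)
    s = p.shrink((None, (1, 4)))  # None == (0, dim): keep the first axis whole
    assert s.shape == (2, 3)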