pad None means (0,0) (#2273)

chenyu 2023-11-11 12:50:26 -05:00 committed by GitHub
parent c5d70c1871
commit 453f48ce02
3 changed files with 7 additions and 7 deletions

@@ -40,8 +40,8 @@ class Attention:
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
     # update the cache
-    self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
-    self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
+    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
+    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
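
The change here is purely cosmetic: None now stands in for a (0,0) pad on the axes that are left untouched, so only the sequence axis is padded out to MAX_CONTEXT. A minimal sketch of the equivalence, using hypothetical values (MAX_CONTEXT=16, start_pos=0, seqlen=4, one head of dim 8) rather than anything from the model:

    from tinygrad.tensor import Tensor

    MAX_CONTEXT, start_pos, seqlen = 16, 0, 4   # hypothetical values for illustration
    keys = Tensor.ones(1, seqlen, 1, 8)         # (bsz, seqlen, n_heads, head_dim)
    # pad only along the sequence axis; None leaves the other axes alone
    old = keys.pad(((0,0), (0, MAX_CONTEXT-start_pos-seqlen), (0,0), (0,0)))
    new = keys.pad((None, (0, MAX_CONTEXT-start_pos-seqlen), None, None))
    assert old.shape == new.shape == (1, MAX_CONTEXT, 1, 8)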

@@ -90,8 +90,8 @@ class Attention:
     values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
     # update the cache
-    self.cache_k.assign(keys.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
-    self.cache_v.assign(values.pad(((0,0),(0,MAX_CONTEXT-start_pos-seqlen),(0,0),(0,0))).contiguous()).realize()
+    self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
+    self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)

@@ -253,9 +253,9 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x else (0,s) for x,s in zip(arg, self.shape))) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
-  def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
-    ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
+  def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def pad(self, arg:Tuple[Optional[Tuple[int, int]], ...], value:float=0.0) -> Tensor:
+    ret = mlops.Pad.apply(self, arg=tuple(x if x is not None else (0,0) for x in arg)) if any(x is not None and x != (0,0) for x in arg) else self
     return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
   # ***** movement hlops *****
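
After this change, None in a pad argument is interchangeable with (0,0), mirroring what shrink already accepted; shrink's check also moves from truthiness to an explicit `is not None`, so an all-None argument now returns self instead of applying a no-op Shrink. A minimal usage sketch (the tensor and shapes are illustrative only):

    from tinygrad.tensor import Tensor

    t = Tensor.ones(2, 3)
    # pad: None now means "no padding on this axis", same as (0,0)
    assert t.pad((None, (1, 2))).shape == t.pad(((0, 0), (1, 2))).shape == (2, 6)
    # shrink: None still means "keep the whole axis"
    assert t.shrink((None, (0, 2))).shape == (2, 2)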