mirror of https://github.com/commaai/tinygrad.git
Const pad support to pad2d and slice (#1392)
* slice to pad2d migrate * Gain line * Mypy happy * Mypy happy * Revert * whitespace
This commit is contained in:
parent
ab9e4a2e93
commit
8889821547
|
@ -608,7 +608,9 @@ class TestOps(unittest.TestCase):
|
|||
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
|
||||
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4)))
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5))
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5))
|
||||
def test_pad(self):
|
||||
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2))))
|
||||
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
|
||||
|
|
|
@ -251,10 +251,10 @@ class Tensor:
|
|||
# ***** movement hlops *****
|
||||
|
||||
# NOTE: using slice is discouraged and things should migrate to pad and shrink
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
|
||||
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
|
||||
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
|
||||
return self.pad(padding).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
|
||||
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
||||
# - A slice i:j returns the elements with indices in [i, j)
|
||||
|
@ -377,9 +377,9 @@ class Tensor:
|
|||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]]):
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
|
||||
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc)
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
|
||||
@property
|
||||
def T(self) -> Tensor: return self.transpose()
|
||||
|
|
Loading…
Reference in New Issue