fix Tensor.triu / Tensor.triu with boolean input (#4941)

`where(self, 0)` incorrectly upcasted the output. `where(self, False)` is correct but looks unnatural, so added a cast at the end. Pattern matcher can fold the cast into where branches
This commit is contained in:
chenyu 2024-06-12 20:16:13 -04:00 committed by GitHub
parent cc90b3ef9f
commit fae08c4d48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@ -301,6 +301,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(5,3,3)], lambda x: x.tril())
helper_test_op([(5,0,3)], lambda x: x.tril())
helper_test_op([(5,3,3)], lambda x: x.tril(1))
helper_test_op(None, lambda x: x.tril(), vals=[[[True] * 3] * 3], forward_only=True)
def test_triu(self):
helper_test_op([(3,3)], lambda x: x.triu())
helper_test_op([(3,3)], lambda x: x.triu(1))
@ -308,6 +309,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(5,3,3)], lambda x: x.triu())
helper_test_op([(5,0,3)], lambda x: x.triu())
helper_test_op([(5,3,3)], lambda x: x.triu(1))
helper_test_op(None, lambda x: x.triu(), vals=[[[True] * 3] * 3], forward_only=True)
def test_maximum(self):
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)

View File

@ -1777,7 +1777,7 @@ class Tensor:
print(t.triu(k=1).numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0).cast(self.dtype)
def tril(self, k:int=0) -> Tensor:
"""
Returns the lower triangular part of the tensor, the other elements are set to 0.
@ -1790,7 +1790,7 @@ class Tensor:
print(t.tril().numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self).cast(self.dtype)
# ***** unary ops *****