mirror of https://github.com/commaai/tinygrad.git
fix Tensor.triu / Tensor.triu with boolean input (#4941)
`where(self, 0)` incorrectly upcasted the output. `where(self, False)` is correct but looks unnatural, so added a cast at the end. Pattern matcher can fold the cast into where branches
This commit is contained in:
parent
cc90b3ef9f
commit
fae08c4d48
|
@ -301,6 +301,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,0,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
helper_test_op(None, lambda x: x.tril(), vals=[[[True] * 3] * 3], forward_only=True)
|
||||
def test_triu(self):
|
||||
helper_test_op([(3,3)], lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1))
|
||||
|
@ -308,6 +309,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,0,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
helper_test_op(None, lambda x: x.triu(), vals=[[[True] * 3] * 3], forward_only=True)
|
||||
|
||||
def test_maximum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
|
||||
|
|
|
@ -1777,7 +1777,7 @@ class Tensor:
|
|||
print(t.triu(k=1).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0).cast(self.dtype)
|
||||
def tril(self, k:int=0) -> Tensor:
|
||||
"""
|
||||
Returns the lower triangular part of the tensor, the other elements are set to 0.
|
||||
|
@ -1790,7 +1790,7 @@ class Tensor:
|
|||
print(t.tril().numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self).cast(self.dtype)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
|
|
Loading…
Reference in New Issue