remove bitcast backward in function.py (#7031)

bitcast cannot backward
This commit is contained in:
chenyu 2024-10-13 10:08:27 -04:00 committed by GitHub
parent ace834ef7b
commit 13575f080a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 1 deletions

View File

@ -2144,6 +2144,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):

View File

@ -21,7 +21,8 @@ class Cast(Function):
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.bitcast(self.input_dtype) if self.bitcast else grad_output.cast(self.input_dtype)
if self.bitcast: raise RuntimeError("bitcast cannot backward")
return grad_output.cast(self.input_dtype)
# ************* unary ops *************