mirror of https://github.com/commaai/tinygrad.git
parent
ace834ef7b
commit
13575f080a
|
@ -2144,6 +2144,9 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
|
||||
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
|
||||
|
||||
def test_bitcast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
@unittest.skip('this is broken for negative numbers')
|
||||
def test_cast(self):
|
||||
|
|
|
@ -21,7 +21,8 @@ class Cast(Function):
|
|||
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.bitcast(self.input_dtype) if self.bitcast else grad_output.cast(self.input_dtype)
|
||||
if self.bitcast: raise RuntimeError("bitcast cannot backward")
|
||||
return grad_output.cast(self.input_dtype)
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
|
|
Loading…
Reference in New Issue