mirror of https://github.com/commaai/tinygrad.git
_min_max of MUL of 2 non-positive inputs (#6454)
This commit is contained in:
parent
b7ce9a1530
commit
2105832b87
|
@ -664,5 +664,26 @@ class TestSymbolicRealWorld(unittest.TestCase):
|
||||||
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
|
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
|
||||||
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
|
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
|
||||||
|
|
||||||
|
class TestBounds(unittest.TestCase):
|
||||||
|
def test_unrolled_arange(self):
|
||||||
|
# #include <metal_stdlib>
|
||||||
|
# using namespace metal;
|
||||||
|
# kernel void r_2560_640_4(device int* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
# int gidx0 = gid.x; /* 2560 */
|
||||||
|
# int alu0 = (gidx0*(-1));
|
||||||
|
# int alu1 = max((int)((-640)),((((alu0+2559)/(-4))*(-1))+(-640)));
|
||||||
|
# int alu2 = max((int)((-640)),((((alu0+2560)/(-4))*(-1))+(-640)));
|
||||||
|
# int alu3 = max((int)((-640)),((((alu0+2561)/(-4))*(-1))+(-640)));
|
||||||
|
# int alu4 = max((int)((-640)),((((alu0+2562)/(-4))*(-1))+(-640)));
|
||||||
|
# *(data0+gidx0) = ((alu1*(-1))+(alu2*(-1))+(alu4*(-1))+(alu3*(-1))+(-1));
|
||||||
|
# }
|
||||||
|
gidx0 = Variable("gidx0", 0, 2559)
|
||||||
|
assert gidx0.vmin == 0 and gidx0.vmax == 2559
|
||||||
|
alu0 = gidx0 * -1
|
||||||
|
assert alu0.vmin == -2559 and alu0.vmax == 0
|
||||||
|
assert (alu0+2559).vmin == 0 and (alu0+2559).vmax == 2559
|
||||||
|
assert ((alu0+2559)//-4).vmin == -639 and ((alu0+2559)//-4).vmax == 0
|
||||||
|
assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 639
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -440,12 +440,15 @@ class UOp(MathTrait):
|
||||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||||
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
||||||
if self.arg is BinaryOps.MUL and (s0.vmin >= 0 or s1.vmin >= 0):
|
if self.arg is BinaryOps.MUL:
|
||||||
# handle at least one is non-negative
|
# both are non-positive
|
||||||
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
|
if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
|
||||||
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
|
# at least one is non-negative
|
||||||
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
if (s0.vmin >= 0 or s1.vmin >= 0):
|
||||||
return Lmin*Rmin, Lmax*Rmax
|
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
|
||||||
|
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
|
||||||
|
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
||||||
|
return Lmin*Rmin, Lmax*Rmax
|
||||||
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
|
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
|
||||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||||
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
|
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
|
||||||
|
|
Loading…
Reference in New Issue