_min_max of MUL of 2 non-positive inputs (#6454)

This commit is contained in:
chenyu 2024-09-10 07:13:01 -04:00 committed by GitHub
parent b7ce9a1530
commit 2105832b87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 6 deletions

View File

@ -664,5 +664,26 @@ class TestSymbolicRealWorld(unittest.TestCase):
# NOTE: this used to have 13,151,129,600 in the output which is out of int32 range. # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)" assert idx.render() == "((((1+lidx5)%16)*49)+(((1+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+2207744+lidx3)"
class TestBounds(unittest.TestCase):
def test_unrolled_arange(self):
# #include <metal_stdlib>
# using namespace metal;
# kernel void r_2560_640_4(device int* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
# int gidx0 = gid.x; /* 2560 */
# int alu0 = (gidx0*(-1));
# int alu1 = max((int)((-640)),((((alu0+2559)/(-4))*(-1))+(-640)));
# int alu2 = max((int)((-640)),((((alu0+2560)/(-4))*(-1))+(-640)));
# int alu3 = max((int)((-640)),((((alu0+2561)/(-4))*(-1))+(-640)));
# int alu4 = max((int)((-640)),((((alu0+2562)/(-4))*(-1))+(-640)));
# *(data0+gidx0) = ((alu1*(-1))+(alu2*(-1))+(alu4*(-1))+(alu3*(-1))+(-1));
# }
gidx0 = Variable("gidx0", 0, 2559)
assert gidx0.vmin == 0 and gidx0.vmax == 2559
alu0 = gidx0 * -1
assert alu0.vmin == -2559 and alu0.vmax == 0
assert (alu0+2559).vmin == 0 and (alu0+2559).vmax == 2559
assert ((alu0+2559)//-4).vmin == -639 and ((alu0+2559)//-4).vmax == 0
assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 639
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -440,12 +440,15 @@ class UOp(MathTrait):
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL and (s0.vmin >= 0 or s1.vmin >= 0): if self.arg is BinaryOps.MUL:
# handle at lease one is non-negative # both are non-positive
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin) if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin) # at lease one is non-negative
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}" if (s0.vmin >= 0 or s1.vmin >= 0):
return Lmin*Rmin, Lmax*Rmax Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
return Lmin*Rmin, Lmax*Rmax
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg