implement MAX in other dtypes (#2572)

This commit is contained in:
qazal 2023-12-02 18:21:59 -05:00 committed by GitHub
parent 065495e0c9
commit fa1d4dd14b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -18,7 +18,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
# TODO: this should be casted
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS),
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.select(builder.icmp_signed(">", x, y), x, y),
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y) if isinstance(x.type, ir.FloatType) else builder.urem(x,y),
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z),