ptx cleanup (#3893)

This commit is contained in:
Szymon Ożóg 2024-03-24 22:54:45 +01:00 committed by GitHub
parent 2e39f57594
commit 2d0bfdf01c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 4 deletions

View File

@ -66,10 +66,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
def ptr_ar(root):
root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
if root.vin[0].dtype.itemsize > 1:
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
else: ptr = root.vin[1]
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
else:
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
@ -98,6 +96,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
if rew := matcher.rewrite(u): replace[u] = rew
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
def kk(*s: str): kernel.append("\n".join(s))