Optimize PTX loops (#4263)

* Optimize PTX loops

* Update assembly.py
This commit is contained in:
Szymon Ożóg 2024-04-23 10:20:14 +02:00 committed by GitHub
parent 967638f0d5
commit 6c25f1abf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View File

@@ -338,7 +338,7 @@ class UOpGraph:
self.replace_op(u, new)
return True
def uoptimize(self):
def optimize_loops(self):
# get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent"
acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list)
for u in self.uops:
@@ -356,6 +356,9 @@ class UOpGraph:
while self.uops_optimization(get_recursive_parents): pass
self.simplify_phi_loops(get_recursive_parents)
def uoptimize(self):
self.optimize_loops()
# (recursively) remove childless uops
# TODO: remove DEFINE_GLOBAL from here
self.remove_childless(set(x for x in self.uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.STORE}))

View File

@@ -83,7 +83,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
@@ -92,6 +91,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
def kk(*s: str): kernel.append("\n".join(s))