allow specify number of beam workers (#4292)

This commit is contained in:
David Hou 2024-04-25 07:44:43 -07:00 committed by GitHub
parent 74a1be88f5
commit ac9464f47a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 3 deletions

View File

@ -111,9 +111,11 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
beam: List[Tuple[Linearizer, float]] = []
seen_libs = set()
default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
if beam_pool is None and getenv("PARALLEL", default_parallel):
beam_pool = multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
min_progress_micros = getenv("BEAM_MIN_PROGRESS", 0.01)
try:
rawbufs = _ensure_buffer_alloc(rawbufs)