From ac9464f47a6c6091ee00584dc331675b393756df Mon Sep 17 00:00:00 2001 From: David Hou Date: Thu, 25 Apr 2024 07:44:43 -0700 Subject: [PATCH] allow specify number of beam workers (#4292) --- tinygrad/features/search.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index b7bd6d9f..9892b0b5 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -111,9 +111,11 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T beam: List[Tuple[Linearizer, float]] = [] seen_libs = set() - default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0, getenv("BEAM_MIN_PROGRESS",0.01) - if beam_pool is None and getenv("PARALLEL", default_parallel): - beam_pool = multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16)) + default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0 + if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)): + beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16)) + + min_progress_micros = getenv("BEAM_MIN_PROGRESS", 0.01) try: rawbufs = _ensure_buffer_alloc(rawbufs)