mirror of https://github.com/commaai/tinygrad.git
allow specify number of beam workers (#4292)
This commit is contained in:
parent
74a1be88f5
commit
ac9464f47a
|
@ -111,9 +111,11 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
||||||
beam: List[Tuple[Linearizer, float]] = []
|
beam: List[Tuple[Linearizer, float]] = []
|
||||||
seen_libs = set()
|
seen_libs = set()
|
||||||
|
|
||||||
default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
|
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
|
||||||
if beam_pool is None and getenv("PARALLEL", default_parallel):
|
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
|
||||||
beam_pool = multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
|
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
|
||||||
|
|
||||||
|
min_progress_micros = getenv("BEAM_MIN_PROGRESS", 0.01)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||||
|
|
Loading…
Reference in New Issue