diff --git a/RELEASES.md b/RELEASES.md index 3acc674e..8ebc06cc 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,3 +1,11 @@ +Carrot2-v9 (2025-12-03) +======================== +* ST model +* fix CasperEV FCA11 +* fix DriverMonitoring alert (for USA) +* apply livePose +* update sensor code + Carrot2-v9 (2025-10-17) ======================== * Nuggets In Dijon model diff --git a/SConstruct b/SConstruct index a4d1739f..54f04cd9 100644 --- a/SConstruct +++ b/SConstruct @@ -330,12 +330,9 @@ Export('env', 'qt_env', 'arch', 'real_arch') # Build common module SConscript(['common/SConscript']) -Import('_common', '_gpucommon') - +Import('_common') common = [_common, 'json11', 'zmq'] -gpucommon = [_gpucommon] - -Export('common', 'gpucommon') +Export('common') # Build messaging (cereal + msgq + socketmaster + their dependencies) # Enable swaglog include in submodules @@ -365,7 +362,6 @@ SConscript([ ]) if arch != "Darwin": SConscript([ - 'system/sensord/SConscript', 'system/logcatd/SConscript', ]) diff --git a/cereal/services.py b/cereal/services.py index 2c3c80de..76c9faab 100644 --- a/cereal/services.py +++ b/cereal/services.py @@ -52,7 +52,7 @@ _services: dict[str, tuple] = { "clocks": (True, 0.1, 1), "ubloxRaw": (True, 20.), "livePose": (True, 20., 4), - "liveLocationKalman": (True, 20., 5), + #"liveLocationKalman": (True, 20., 5), "liveParameters": (True, 20., 5), "cameraOdometry": (True, 20., 10), "thumbnail": (True, 1 / 60., 1), diff --git a/common/SConscript b/common/SConscript index 829db6ee..0891b790 100644 --- a/common/SConscript +++ b/common/SConscript @@ -4,22 +4,13 @@ common_libs = [ 'params.cc', 'swaglog.cc', 'util.cc', - 'i2c.cc', 'watchdog.cc', - 'ratekeeper.cc' -] - -if arch != "Darwin": - common_libs.append('gpio.cc') - -_common = env.Library('common', common_libs, LIBS="json11") - -files = [ + 'ratekeeper.cc', 'clutil.cc', ] -_gpucommon = env.Library('gpucommon', files) -Export('_common', '_gpucommon') +_common = env.Library('common', common_libs, LIBS="json11") +Export('_common') if GetOption('extras'): env.Program('tests/test_common', diff --git a/common/constants.py b/common/constants.py new file mode 100644 index 00000000..7ca425c4 --- /dev/null +++ b/common/constants.py @@ -0,0 +1,23 @@ +import numpy as np + +# conversions +class CV: + # Speed + MPH_TO_KPH = 1.609344 + KPH_TO_MPH = 1. / MPH_TO_KPH + MS_TO_KPH = 3.6 + KPH_TO_MS = 1. / MS_TO_KPH + MS_TO_MPH = MS_TO_KPH * KPH_TO_MPH + MPH_TO_MS = MPH_TO_KPH * KPH_TO_MS + MS_TO_KNOTS = 1.9438 + KNOTS_TO_MS = 1. / MS_TO_KNOTS + + # Angle + DEG_TO_RAD = np.pi / 180. + RAD_TO_DEG = 1. 
/ DEG_TO_RAD + + # Mass + LB_TO_KG = 0.453592 + + +ACCELERATION_DUE_TO_GRAVITY = 9.81 # m/s^2 diff --git a/common/filter_simple.py b/common/filter_simple.py index b690b8b2..7918768b 100644 --- a/common/filter_simple.py +++ b/common/filter_simple.py @@ -20,6 +20,23 @@ class FirstOrderFilter: self.x = x return self.x + +class BounceFilter(FirstOrderFilter): + def __init__(self, x0, rc, dt, initialized=True, bounce=2): + self.velocity = FirstOrderFilter(0.0, 0.15, dt) + self.bounce = bounce + super().__init__(x0, rc, dt, initialized) + + def update(self, x): + super().update(x) + scale = self.dt / (1.0 / 60.0) # tuned at 60 fps + self.velocity.x += (x - self.x) * self.bounce * scale * self.dt + self.velocity.update(0.0) + if abs(self.velocity.x) < 1e-5: + self.velocity.x = 0.0 + self.x += self.velocity.x + return self.x + class MyMovingAverage: def __init__(self, window_size, value=None): self.window_size = window_size diff --git a/common/gpio.cc b/common/gpio.cc deleted file mode 100644 index dd7ba34b..00000000 --- a/common/gpio.cc +++ /dev/null @@ -1,84 +0,0 @@ -#include "common/gpio.h" - -#include - -#ifdef __APPLE__ -int gpio_init(int pin_nr, bool output) { - return 0; -} - -int gpio_set(int pin_nr, bool high) { - return 0; -} - -int gpiochip_get_ro_value_fd(const char* consumer_label, int gpiochiop_id, int pin_nr) { - return 0; -} - -#else - -#include -#include - -#include -#include -#include - -#include "common/util.h" -#include "common/swaglog.h" - -int gpio_init(int pin_nr, bool output) { - char pin_dir_path[50]; - int pin_dir_path_len = snprintf(pin_dir_path, sizeof(pin_dir_path), - "/sys/class/gpio/gpio%d/direction", pin_nr); - if (pin_dir_path_len <= 0) { - return -1; - } - const char *value = output ? "out" : "in"; - return util::write_file(pin_dir_path, (void*)value, strlen(value)); -} - -int gpio_set(int pin_nr, bool high) { - char pin_val_path[50]; - int pin_val_path_len = snprintf(pin_val_path, sizeof(pin_val_path), - "/sys/class/gpio/gpio%d/value", pin_nr); - if (pin_val_path_len <= 0) { - return -1; - } - return util::write_file(pin_val_path, (void*)(high ? "1" : "0"), 1); -} - -int gpiochip_get_ro_value_fd(const char* consumer_label, int gpiochiop_id, int pin_nr) { - - // Assumed that all interrupt pins are unexported and rights are given to - // read from gpiochip0. - std::string gpiochip_path = "/dev/gpiochip" + std::to_string(gpiochiop_id); - int fd = open(gpiochip_path.c_str(), O_RDONLY); - if (fd < 0) { - LOGE("Error opening gpiochip0 fd"); - return -1; - } - - // Setup event - struct gpioevent_request rq; - rq.lineoffset = pin_nr; - rq.handleflags = GPIOHANDLE_REQUEST_INPUT; - - /* Requesting both edges as the data ready pulse from the lsm6ds sensor is - very short(75us) and is mostly detected as falling edge instead of rising. - So if it is detected as rising the following falling edge is skipped. 
*/ - rq.eventflags = GPIOEVENT_REQUEST_BOTH_EDGES; - - strncpy(rq.consumer_label, consumer_label, std::size(rq.consumer_label) - 1); - int ret = util::safe_ioctl(fd, GPIO_GET_LINEEVENT_IOCTL, &rq); - if (ret == -1) { - LOGE("Unable to get line event from ioctl : %s", strerror(errno)); - close(fd); - return -1; - } - - close(fd); - return rq.fd; -} - -#endif diff --git a/common/gpio.h b/common/gpio.h deleted file mode 100644 index 89cdedd6..00000000 --- a/common/gpio.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -// Pin definitions -#ifdef QCOM2 - #define GPIO_HUB_RST_N 30 - #define GPIO_UBLOX_RST_N 32 - #define GPIO_UBLOX_SAFEBOOT_N 33 - #define GPIO_GNSS_PWR_EN 34 /* SCHEMATIC LABEL: GPIO_UBLOX_PWR_EN */ - #define GPIO_STM_RST_N 124 - #define GPIO_STM_BOOT0 134 - #define GPIO_BMX_ACCEL_INT 21 - #define GPIO_BMX_GYRO_INT 23 - #define GPIO_BMX_MAGN_INT 87 - #define GPIO_LSM_INT 84 - #define GPIOCHIP_INT 0 -#else - #define GPIO_HUB_RST_N 0 - #define GPIO_UBLOX_RST_N 0 - #define GPIO_UBLOX_SAFEBOOT_N 0 - #define GPIO_GNSS_PWR_EN 0 /* SCHEMATIC LABEL: GPIO_UBLOX_PWR_EN */ - #define GPIO_STM_RST_N 0 - #define GPIO_STM_BOOT0 0 - #define GPIO_BMX_ACCEL_INT 0 - #define GPIO_BMX_GYRO_INT 0 - #define GPIO_BMX_MAGN_INT 0 - #define GPIO_LSM_INT 0 - #define GPIOCHIP_INT 0 -#endif - -int gpio_init(int pin_nr, bool output); -int gpio_set(int pin_nr, bool high); - -int gpiochip_get_ro_value_fd(const char* consumer_label, int gpiochiop_id, int pin_nr); diff --git a/common/i2c.cc b/common/i2c.cc deleted file mode 100644 index 3d6c79ef..00000000 --- a/common/i2c.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include "common/i2c.h" - -#include -#include -#include - -#include -#include -#include - -#include "common/swaglog.h" -#include "common/util.h" - -#define UNUSED(x) (void)(x) - -#ifdef QCOM2 -// TODO: decide if we want to install libi2c-dev everywhere -extern "C" { - #include - #include -} - -I2CBus::I2CBus(uint8_t bus_id) { - char bus_name[20]; - snprintf(bus_name, 20, "/dev/i2c-%d", bus_id); - - i2c_fd = HANDLE_EINTR(open(bus_name, O_RDWR)); - if (i2c_fd < 0) { - throw std::runtime_error("Failed to open I2C bus"); - } -} - -I2CBus::~I2CBus() { - if (i2c_fd >= 0) { - close(i2c_fd); - } -} - -int I2CBus::read_register(uint8_t device_address, uint register_address, uint8_t *buffer, uint8_t len) { - std::lock_guard lk(m); - - int ret = 0; - - ret = HANDLE_EINTR(ioctl(i2c_fd, I2C_SLAVE, device_address)); - if (ret < 0) { goto fail; } - - ret = i2c_smbus_read_i2c_block_data(i2c_fd, register_address, len, buffer); - if ((ret < 0) || (ret != len)) { goto fail; } - -fail: - return ret; -} - -int I2CBus::set_register(uint8_t device_address, uint register_address, uint8_t data) { - std::lock_guard lk(m); - - int ret = 0; - - ret = HANDLE_EINTR(ioctl(i2c_fd, I2C_SLAVE, device_address)); - if (ret < 0) { goto fail; } - - ret = i2c_smbus_write_byte_data(i2c_fd, register_address, data); - if (ret < 0) { goto fail; } - -fail: - return ret; -} - -#else - -I2CBus::I2CBus(uint8_t bus_id) { - UNUSED(bus_id); - i2c_fd = -1; -} - -I2CBus::~I2CBus() {} - -int I2CBus::read_register(uint8_t device_address, uint register_address, uint8_t *buffer, uint8_t len) { - UNUSED(device_address); - UNUSED(register_address); - UNUSED(buffer); - UNUSED(len); - return -1; -} - -int I2CBus::set_register(uint8_t device_address, uint register_address, uint8_t data) { - UNUSED(device_address); - UNUSED(register_address); - UNUSED(data); - return -1; -} -#endif diff --git a/common/i2c.h b/common/i2c.h deleted file mode 100644 index ca0d4635..00000000 
--- a/common/i2c.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include -#include - -#include - -class I2CBus { - private: - int i2c_fd; - std::mutex m; - - public: - I2CBus(uint8_t bus_id); - ~I2CBus(); - - int read_register(uint8_t device_address, uint register_address, uint8_t *buffer, uint8_t len); - int set_register(uint8_t device_address, uint register_address, uint8_t data); -}; diff --git a/common/mock/__init__.py b/common/mock/__init__.py index 86d44d1b..9890d73c 100644 --- a/common/mock/__init__.py +++ b/common/mock/__init__.py @@ -13,7 +13,7 @@ from openpilot.common.realtime import Ratekeeper MOCK_GENERATOR = { - "livePose": generate_livePose + "livePose": generate_livePose, "liveLocationKalman": generate_liveLocationKalman } diff --git a/common/params_keys.h b/common/params_keys.h index 846c095e..0a15db5c 100644 --- a/common/params_keys.h +++ b/common/params_keys.h @@ -183,6 +183,7 @@ inline static std::unordered_map keys = { {"AutoNaviSpeedSafetyFactor", PERSISTENT}, {"AutoNaviCountDownMode", PERSISTENT}, {"TurnSpeedControlMode", PERSISTENT}, + {"CarrotSmartSpeedControl", PERSISTENT}, {"MapTurnSpeedFactor", PERSISTENT}, {"ModelTurnSpeedFactor", PERSISTENT}, {"StoppingAccel", PERSISTENT}, @@ -281,6 +282,9 @@ inline static std::unordered_map keys = { {"MuteDoor", PERSISTENT}, {"MuteSeatbelt", PERSISTENT}, {"CarrotException", CLEAR_ON_MANAGER_START}, + {"CarrotSpeed", PERSISTENT}, + {"CarrotSpeedViz", PERSISTENT}, + {"CarrotSpeedTable", PERSISTENT}, {"CarName", PERSISTENT}, {"EVTable", PERSISTENT}, {"LongPitch", PERSISTENT}, diff --git a/common/utils.py b/common/utils.py new file mode 100644 index 00000000..89c0601f --- /dev/null +++ b/common/utils.py @@ -0,0 +1,118 @@ +import io +import os +import tempfile +import contextlib +import subprocess +import time +import functools +from subprocess import Popen, PIPE, TimeoutExpired +import zstandard as zstd +from openpilot.common.swaglog import cloudlog + +LOG_COMPRESSION_LEVEL = 10 # little benefit up to level 15. level ~17 is a small step change + + +class CallbackReader: + """Wraps a file, but overrides the read method to also + call a callback function with the number of bytes read so far.""" + def __init__(self, f, callback, *args): + self.f = f + self.callback = callback + self.cb_args = args + self.total_read = 0 + + def __getattr__(self, attr): + return getattr(self.f, attr) + + def read(self, *args, **kwargs): + chunk = self.f.read(*args, **kwargs) + self.total_read += len(chunk) + self.callback(*self.cb_args, self.total_read) + return chunk + + +@contextlib.contextmanager +def atomic_write_in_dir(path: str, mode: str = 'w', buffering: int = -1, encoding: str | None = None, newline: str | None = None, + overwrite: bool = False): + """Write to a file atomically using a temporary file in the same directory as the destination file.""" + dir_name = os.path.dirname(path) + + if not overwrite and os.path.exists(path): + raise FileExistsError(f"File '{path}' already exists. 
To overwrite it, set 'overwrite' to True.") + + with tempfile.NamedTemporaryFile(mode=mode, buffering=buffering, encoding=encoding, newline=newline, dir=dir_name, delete=False) as tmp_file: + yield tmp_file + tmp_file_name = tmp_file.name + os.replace(tmp_file_name, path) + + +def get_upload_stream(filepath: str, should_compress: bool) -> tuple[io.BufferedIOBase, int]: + if not should_compress: + file_size = os.path.getsize(filepath) + file_stream = open(filepath, "rb") + return file_stream, file_size + + # Compress the file on the fly + compressed_stream = io.BytesIO() + compressor = zstd.ZstdCompressor(level=LOG_COMPRESSION_LEVEL) + + with open(filepath, "rb") as f: + compressor.copy_stream(f, compressed_stream) + compressed_size = compressed_stream.tell() + compressed_stream.seek(0) + return compressed_stream, compressed_size + + +# remove all keys that end in DEPRECATED +def strip_deprecated_keys(d): + for k in list(d.keys()): + if isinstance(k, str): + if k.endswith('DEPRECATED'): + d.pop(k) + elif isinstance(d[k], dict): + strip_deprecated_keys(d[k]) + return d + + +def run_cmd(cmd: list[str], cwd=None, env=None) -> str: + return subprocess.check_output(cmd, encoding='utf8', cwd=cwd, env=env).strip() + + +def run_cmd_default(cmd: list[str], default: str = "", cwd=None, env=None) -> str: + try: + return run_cmd(cmd, cwd=cwd, env=env) + except subprocess.CalledProcessError: + return default + + +@contextlib.contextmanager +def managed_proc(cmd: list[str], env: dict[str, str]): + proc = Popen(cmd, env=env, stdout=PIPE, stderr=PIPE) + try: + yield proc + finally: + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=5) + except TimeoutExpired: + proc.kill() + + +def retry(attempts=3, delay=1.0, ignore_failure=False): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for _ in range(attempts): + try: + return func(*args, **kwargs) + except Exception: + cloudlog.exception(f"{func.__name__} failed, trying again") + time.sleep(delay) + + if ignore_failure: + cloudlog.error(f"{func.__name__} failed after retry") + else: + raise Exception(f"{func.__name__} failed after retry") + return wrapper + return decorator diff --git a/launch_chffrplus.sh b/launch_chffrplus.sh index e67eec5a..ffd2fe28 100755 --- a/launch_chffrplus.sh +++ b/launch_chffrplus.sh @@ -88,6 +88,12 @@ function launch { echo "shapely installing." pip install shapely fi + if python -c "import kaitaistruct" > /dev/null 2>&1; then + echo "kaitaistruct already installed." + else + echo "kaitaistruct installing." 
+    pip install kaitaistruct
+  fi

  # events language init
  #LANG=$(cat ${PARAMS_ROOT}/d/LanguageSetting)
diff --git a/opendbc_repo/opendbc/car/hyundai/carcontroller.py b/opendbc_repo/opendbc/car/hyundai/carcontroller.py
index 6489aaf9..b5a44cbb 100644
--- a/opendbc_repo/opendbc/car/hyundai/carcontroller.py
+++ b/opendbc_repo/opendbc/car/hyundai/carcontroller.py
@@ -333,15 +333,17 @@ class CarController(CarControllerBase):
       if self.CP.carFingerprint in CAN_GEARS["send_mdps12"]:  # send mdps12 to LKAS to prevent LKAS error
         can_sends.append(hyundaican.create_mdps12(self.packer, self.frame, CS.mdps12))
 
+    casper_opt = self.CP.carFingerprint in (CAR.HYUNDAI_CASPER_EV,)
     if self.frame % 2 == 0 and self.CP.openpilotLongitudinalControl:
       self.hyundai_jerk.make_jerk(self.CP, CS, accel, actuators, hud_control)
       self.hyundai_jerk.check_carrot_cruise(CC, CS, hud_control, stopping, accel, actuators.aTarget)
       #jerk = 3.0 if actuators.longControlState == LongCtrlState.pid else 1.0
       use_fca = self.CP.flags & HyundaiFlags.USE_FCA.value
       if camera_scc:
+
         can_sends.extend(hyundaican.create_acc_commands_scc(self.packer, CC.enabled, accel, self.hyundai_jerk, int(self.frame / 2),
                                                             hud_control, set_speed_in_units, stopping,
-                                                        CC.cruiseControl.override, use_fca, CS, self.soft_hold_mode))
+                                                        CC.cruiseControl.override, casper_opt, CS, self.soft_hold_mode))
       else:
         can_sends.extend(hyundaican.create_acc_commands(self.packer, CC.enabled, accel, self.hyundai_jerk, int(self.frame / 2),
                                                       hud_control, set_speed_in_units, stopping,
@@ -355,8 +357,10 @@ class CarController(CarControllerBase):
     # 5 Hz ACC options
     if self.frame % 20 == 0 and self.CP.openpilotLongitudinalControl:
       if camera_scc:
-        #if CS.scc13 is not None:
-        #  can_sends.append(hyundaican.create_acc_opt_copy(CS, self.packer))
+        if CS.scc13 is not None:
+          if casper_opt:
+            #can_sends.append(hyundaican.create_acc_opt_copy(CS, self.packer))
+            pass
         pass
       else:
         can_sends.extend(hyundaican.create_acc_opt(self.packer, self.CP))
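The divisor checks above follow from openpilot's 100 Hz carcontroller loop: `self.frame % 2` gates the ACC command path to 50 Hz, and `self.frame % 20` gives the "5 Hz ACC options" rate. A minimal standalone sketch of this scheduling pattern (the message names are illustrative, not the real send list):

```python
# Minimal sketch of frame-divisor scheduling, assuming a 100 Hz control loop
# (openpilot's carcontroller rate). Message names here are illustrative only.
LOOP_HZ = 100

def messages_for_frame(frame: int) -> list[str]:
    msgs = []
    if frame % 2 == 0:       # 100 / 2  = 50 Hz: ACC/SCC commands
        msgs.append("SCC commands")
    if frame % 20 == 0:      # 100 / 20 =  5 Hz: ACC option messages
        msgs.append("ACC options")
    return msgs

# one simulated second: the commands are queued 50 times, the options 5 times
counts: dict[str, int] = {}
for frame in range(LOOP_HZ):
    for m in messages_for_frame(frame):
        counts[m] = counts.get(m, 0) + 1
print(counts)  # {'SCC commands': 50, 'ACC options': 5}
```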
cp_cruise.vl["SCC14"] if self.SCC14 else None + self.fca11 = can_parsers[self.FCA11_bus].vl["FCA11"] if self.FCA11 else None cluSpeed = cp.vl["CLU11"]["CF_Clu_Vanz"] decimal = cp.vl["CLU11"]["CF_Clu_VanzDecimal"] if 0. < decimal < 0.5: diff --git a/opendbc_repo/opendbc/car/hyundai/hyundaican.py b/opendbc_repo/opendbc/car/hyundai/hyundaican.py index dad31979..eaaccf95 100644 --- a/opendbc_repo/opendbc/car/hyundai/hyundaican.py +++ b/opendbc_repo/opendbc/car/hyundai/hyundaican.py @@ -210,6 +210,14 @@ def create_acc_commands_scc(packer, enabled, accel, jerk, idx, hud_control, set_ values["ObjDistStat"] = objGap2 commands.append(packer.make_can_msg("SCC14", 0, values)) + if CS.fca11 is not None and use_fca: # CASPER_EV의 경우 FCA11에서 fail이 간헐적 발생함.. 그냥막자.. 원인불명.. + values = copy.copy(CS.fca11) + if values["FCA_Failinfo"] != 0: + values["FCA_Status"] = 2 + values["FCA_Failinfo"] = 0 + fca11_dat = packer.make_can_msg("FCA11", 0, values)[1] + values["CR_FCA_ChkSum"] = hyundai_checksum(fca11_dat[:7]) + commands.append(packer.make_can_msg("FCA11", 0, values)) # Only send FCA11 on cars where it exists on the bus if False: #use_fca: # note that some vehicles most likely have an alternate checksum/counter definition @@ -227,6 +235,10 @@ def create_acc_commands_scc(packer, enabled, accel, jerk, idx, hud_control, set_ return commands def create_acc_opt_copy(CS, packer): + values = copy.copy(CS.scc13) + if values["NEW_SIGNAL_1"] == 255: + values["NEW_SIGNAL_1"] = 218 + values["NEW_SIGNAL_2"] = 0 return packer.make_can_msg("SCC13", 0, CS.scc13) def create_acc_commands(packer, enabled, accel, jerk, idx, hud_control, set_speed, stopping, long_override, use_fca, CP, CS, soft_hold_mode): diff --git a/opendbc_repo/opendbc/car/hyundai/hyundaicanfd.py b/opendbc_repo/opendbc/car/hyundai/hyundaicanfd.py index bb142b30..e080226b 100644 --- a/opendbc_repo/opendbc/car/hyundai/hyundaicanfd.py +++ b/opendbc_repo/opendbc/car/hyundai/hyundaicanfd.py @@ -517,10 +517,11 @@ def create_ccnc_messages(CP, packer, CAN, frame, CC, CS, hud_control, disp_angle values["LKA_ICON"] = 4 if lat_active else 3 if lat_enabled else 0 values["FCA_ALT_ICON"] = 0 - if values["ALERTS_2"] in [1, 2, 5]: + if values["ALERTS_2"] in [1, 2, 5, 10, 22]: # 10,22: 운전자모니터 알람/경고 values["ALERTS_2"] = 0 values["DAW_ICON"] = 0 + values["SOUNDS_1"] = 0 # 운전자모니터경고음. values["SOUNDS_2"] = 0 # 2: STEER중지 경고후에도 사운드가 나옴. values["SOUNDS_4"] = 0 # 차선변경알림? 에이 그냥0으로.. 
diff --git a/opendbc_repo/opendbc/safety/safety/safety_hyundai.h b/opendbc_repo/opendbc/safety/safety/safety_hyundai.h index cbaa854b..cb16ed9f 100644 --- a/opendbc_repo/opendbc/safety/safety/safety_hyundai.h +++ b/opendbc_repo/opendbc/safety/safety/safety_hyundai.h @@ -200,6 +200,7 @@ uint32_t last_ts_scc12_from_op = 0; uint32_t last_ts_scc13_from_op = 0; uint32_t last_ts_mdps12_from_op = 0; uint32_t last_ts_fca11_from_op = 0; +uint32_t last_ts_fca12_from_op = 0; static bool hyundai_tx_hook(const CANPacket_t *to_send) { const TorqueSteeringLimits HYUNDAI_STEERING_LIMITS = HYUNDAI_LIMITS(512, 10, 10); @@ -210,7 +211,7 @@ static bool hyundai_tx_hook(const CANPacket_t *to_send) { int addr = GET_ADDR(to_send); // FCA11: Block any potential actuation - if (addr == 0x38D) { + if (false && addr == 0x38D) { int CR_VSM_DecCmd = GET_BYTE(to_send, 1); bool FCA_CmdAct = GET_BIT(to_send, 20U); bool CF_VSM_DecCmdAct = GET_BIT(to_send, 31U); @@ -277,16 +278,19 @@ static bool hyundai_tx_hook(const CANPacket_t *to_send) { tx = false; } } + uint32_t now = microsecond_timer_get(); if(addr == 832) - last_ts_lkas11_from_op = (tx == 0 ? 0 : microsecond_timer_get()); + last_ts_lkas11_from_op = (tx == 0 ? 0 : now); else if(addr == 1057) - last_ts_scc12_from_op = (tx == 0 ? 0 : microsecond_timer_get()); + last_ts_scc12_from_op = (tx == 0 ? 0 : now); else if(addr == 593) - last_ts_mdps12_from_op = (tx == 0 ? 0 : microsecond_timer_get()); - else if(addr == 909) - last_ts_fca11_from_op = (tx == 0 ? 0 : microsecond_timer_get()); + last_ts_mdps12_from_op = (tx == 0 ? 0 : now); + else if (addr == 909) + last_ts_fca11_from_op = (tx == 0 ? 0 : now); + else if (addr == 1155) + last_ts_fca12_from_op = (tx == 0 ? 0 : now); else if(addr == 1290) - last_ts_scc13_from_op = (tx == 0 ? 0 : microsecond_timer_get()); + last_ts_scc13_from_op = (tx == 0 ? 
0 : now); return tx; } @@ -313,9 +317,10 @@ static int hyundai_fwd_hook(int bus_num, int addr) { bool is_lfahda_msg = addr == 1157; bool is_scc_msg = addr == 1056 || addr == 1057 || addr == 905; bool is_scc13_msg = addr == 1290; - bool is_fca_msg = addr == 909 || addr == 1155; + bool is_fca11_msg = addr == 909; + bool is_fca12_msg = addr == 1155; - bool block_msg = is_lkas_msg || is_lfahda_msg || is_scc_msg || is_scc13_msg; //|| is_fca_msg; + bool block_msg = is_lkas_msg || is_lfahda_msg || is_scc_msg || is_scc13_msg || is_fca11_msg || is_fca12_msg; if (!block_msg) { bus_fwd = 0; } @@ -330,11 +335,15 @@ static int hyundai_fwd_hook(int bus_num, int addr) { bus_fwd = 0; } else if (is_scc13_msg) { - if (now - last_ts_scc13_from_op >= 400000) + if (now - last_ts_scc13_from_op >= 800000) bus_fwd = 0; } - else if(is_fca_msg) { - if(now - last_ts_fca11_from_op >= 400000) + else if (is_fca11_msg) { + if (now - last_ts_fca11_from_op >= 400000) + bus_fwd = 0; + } + else if (is_fca12_msg) { + if (now - last_ts_fca12_from_op >= 400000) bus_fwd = 0; } } diff --git a/restart.sh b/restart.sh index 09263207..d198531b 100755 --- a/restart.sh +++ b/restart.sh @@ -1,2 +1,2 @@ git pull -tmux kill-session -t comma; rm -f /tmp/safe_staging_overlay.lock; sleep 1;tmux new -s comma -d "/data/openpilot/launch_openpilot.sh" +tmux kill-session -t comma; rm -f /tmp/safe_staging_overlay.lock; sleep 1;tmux new -s comma -d "bash -lc '/data/openpilot/launch_openpilot.sh'" diff --git a/selfdrive/car/cruise.py b/selfdrive/car/cruise.py index fbb5ecf8..c029677c 100644 --- a/selfdrive/car/cruise.py +++ b/selfdrive/car/cruise.py @@ -247,6 +247,7 @@ class VCruiseCarrot: self.autoGasSyncSpeed = self.params.get_bool("AutoGasSyncSpeed") * unit_factor self.autoSpeedUptoRoadSpeedLimit = self.params.get_float("AutoSpeedUptoRoadSpeedLimit") * 0.01 self.autoRoadSpeedAdjust = self.params.get_float("AutoRoadSpeedAdjust") * 0.01 + self.smartSpeedControl = self.params.get_int("CarrotSmartSpeedControl") useLaneLineSpeed = self.params.get_int("UseLaneLineSpeed") * unit_factor if self.useLaneLineSpeed != useLaneLineSpeed: @@ -441,6 +442,20 @@ class VCruiseCarrot: return button_kph, button_type, self.long_pressed def _carrot_command(self, v_cruise_kph, button_type, long_pressed): + carrot_speed = self.params_memory.get_int("CarrotSpeed") + if carrot_speed != 0: + if carrot_speed > 0: + if self.smartSpeedControl in [1,3]: + v_cruise_kph = max(carrot_speed, v_cruise_kph) + else: + if self.smartSpeedControl == 3: + v_cruise_kph = -carrot_speed + #elif self.smartSpeedControl == 1: + # v_cruise_kph = max(-carrot_speed, v_cruise_kph) + elif self.smartSpeedControl == 2: + v_cruise_kph = min(-carrot_speed, v_cruise_kph) + self.params_memory.put_int_nonblocking("CarrotSpeed", 0) + self._add_log(f"Carrot speed set to {v_cruise_kph}") if self.carrot_cmd_index_last != self.carrot_cmd_index: self.carrot_cmd_index_last = self.carrot_cmd_index print(f"Carrot command(cruise.py): {self.carrot_cmd} {self.carrot_arg}") diff --git a/selfdrive/carrot/carrot_man.py b/selfdrive/carrot/carrot_man.py index b423149d..30d08b4e 100644 --- a/selfdrive/carrot/carrot_man.py +++ b/selfdrive/carrot/carrot_man.py @@ -22,6 +22,9 @@ from openpilot.selfdrive.navd.helpers import Coordinate from opendbc.car.common.conversions import Conversions as CV from openpilot.selfdrive.carrot.carrot_serv import CarrotServ +from openpilot.selfdrive.carrot.carrot_speed import CarrotSpeed + +from openpilot.common.gps import get_gps_location_service try: from shapely.geometry import LineString @@ 
-185,7 +188,8 @@ class CarrotMan:
     print("************************************************CarrotMan init************************************************")
     self.params = Params()
     self.params_memory = Params("/dev/shm/params")
-    self.sm = messaging.SubMaster(['deviceState', 'carState', 'controlsState', 'longitudinalPlan', 'modelV2', 'selfdriveState', 'carControl', 'navRouteNavd', 'liveLocationKalman', 'navInstruction'])
+    self.gps_location_service = get_gps_location_service(self.params)
+    self.sm = messaging.SubMaster(['deviceState', 'carState', 'controlsState', 'radarState', 'longitudinalPlan', 'modelV2', 'selfdriveState', 'carControl', 'navRouteNavd', self.gps_location_service, 'navInstruction'])
     self.pm = messaging.PubMaster(['carrotMan', "navRoute", "navInstructionCarrot"])
 
     self.carrot_serv = CarrotServ()
@@ -258,9 +262,22 @@
       sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
       frame = 0
       self.save_toggle_values()
-      rk = Ratekeeper(10, print_delay_threshold=None)
+
+      carrot_speed = CarrotSpeed(neighbor_ring=2)
+      self.params_memory.put_int_nonblocking("CarrotSpeed", 0)
+
+      rk = Ratekeeper(20, print_delay_threshold=None)
       carrotIndex_last = self.carrot_serv.carrotIndex
+      phone_gps_frame = self.carrot_serv.phone_gps_frame
+      carrot_speed_active_count = 0
+      self.v_cruise_last = 0
+      self.long_active = False
+      self.v_cruise_change = 0
+      self._last_vt = 0.0
+      self.gas_pressed_count = 0
+      self._last_viz_t = 0.0
+
       while self.is_running:
         try:
           self.sm.update(0)
@@ -273,7 +290,16 @@
           #print("coords=", coords)
           #print("curvatures=", curvatures)
-          self.carrot_serv.update_navi(remote_ip, self.sm, self.pm, vturn_speed, coords, distances, route_speed)
+          self.carrot_serv.update_navi(remote_ip, self.sm, self.pm, vturn_speed, coords, distances, route_speed, self.gps_location_service)
+
+          if phone_gps_frame != self.carrot_serv.phone_gps_frame:
+            phone_gps_frame = self.carrot_serv.phone_gps_frame
+            carrot_speed_active_count = 10
+          else:
+            carrot_speed_active_count -= 1
+
+          if carrot_speed_active_count > 0:
+            self.carrot_speed_serv(carrot_speed, frame)
 
           if frame % 20 == 0 or remote_addr is not None:
             try:
@@ -318,6 +344,71 @@
           traceback.print_exc()
           time.sleep(1)
 
+  def carrot_speed_serv(self, carrot_speed, frame):
+    v_ego = a_ego = 0.0
+    gas_pressed = False
+    if self.sm.alive['carState'] and self.sm.alive['carControl']:
+      CS = self.sm['carState']
+      CC = self.sm['carControl']
+      v_ego = CS.vEgo
+      a_ego = CS.aEgo
+      gas_pressed = CS.gasPressed
+      v_ego_kph = v_ego * 3.6
+      if gas_pressed:
+        self.gas_pressed_count = 200
+        self.v_cruise_change = 0
+      elif self._last_vt == CS.vCruise:
+        self.v_cruise_last = CS.vCruise
+      elif self.long_active and CC.longActive and self.gas_pressed_count == 0:
+        if self.v_cruise_last < CS.vCruise:  # the set speed increased
+          self.v_cruise_change = 100
+        elif self.v_cruise_last > CS.vCruise:  # the set speed decreased
+          if v_ego_kph < CS.vCruise:  # driving slower than the set speed
+            self.v_cruise_change = 100
+          else:  # driving faster than the set speed
+            self.v_cruise_change = -100
+
+        if self.v_cruise_change != 0:
+          self.gas_pressed_count = 0
+      else:
+        self.v_cruise_change = 0
+      self.long_active = CC.longActive
+      self.v_cruise_last = CS.vCruise
+    else:
+      self.v_cruise_change = 0
+
+    now = time.monotonic()
+    heading = self.carrot_serv.bearing  #nPosAnglePhone
+    lat, lon = self.carrot_serv.vpPosPointLat, self.carrot_serv.vpPosPointLon  #self.carrot_serv.estimate_position(self.carrot_serv.phone_latitude, self.carrot_serv.phone_longitude, heading, v_ego, now - self.carrot_serv.last_update_gps_time_phone)
+    vt = carrot_speed.query_target_dist(lat, lon, heading, 0.0)
+    if self.v_cruise_change != 0:
+      carrot_speed.add_sample(lat, lon, heading, self.v_cruise_last if self.v_cruise_change > 0 else (- self.v_cruise_last))
+      if self.v_cruise_change > 0:
+        self.v_cruise_change -= 1
+      if self.v_cruise_change < 0:
+        self.v_cruise_change += 1
+    else:
+      if self.gas_pressed_count > 0:
+        vt = max(vt, self.v_cruise_last)
+        carrot_speed.add_sample(lat, lon, heading, vt)
+
+      self.params_memory.put_int_nonblocking("CarrotSpeed", int(vt))
+
+    self._last_vt = vt
+    if gas_pressed and a_ego < -0.5:  #self._last_vt < 0.0:
+      carrot_speed.invalidate_last_hit(window_s=2.0, action="clear")
+    self.gas_pressed_count = max(0, self.gas_pressed_count - 1)
+
+    if now - self._last_viz_t > 0.5:  # 2 Hz
+      self._last_viz_t = now
+      viz_json = carrot_speed.export_cells_around(lat, lon, heading, ring=2, max_points=64)
+      # better to write this to memory-backed Params (not to disk)
+      self.params_memory.put_nonblocking("CarrotSpeedViz", viz_json)
+
+    carrot_speed.maybe_save()
+
+
   def carrot_navi_route(self):
     if self.carrot_serv.active_carrot > 1:
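`carrot_speed_serv` above and `VCruiseCarrot._carrot_command` (in the cruise.py hunk earlier) communicate through a one-slot mailbox in memory-backed Params: the producer writes a signed km/h value to "CarrotSpeed", and the consumer applies it according to `CarrotSmartSpeedControl` and clears the slot. A minimal sketch of both ends, assuming openpilot's `Params` with the `/dev/shm/params` store used throughout this fork:

```python
# Sketch of the "CarrotSpeed" mailbox between carrot_man.py (producer) and
# cruise.py (consumer); assumes openpilot's Params backed by /dev/shm.
from openpilot.common.params import Params

params_memory = Params("/dev/shm/params")

def publish_target(v_kph: float) -> None:
    # producer: positive = learned accel target, negative = learned decel target
    params_memory.put_int_nonblocking("CarrotSpeed", int(v_kph))

def consume_target(v_cruise_kph: float, mode: int) -> float:
    # consumer: apply once, then zero the slot so the value is not replayed
    carrot_speed = params_memory.get_int("CarrotSpeed")
    if carrot_speed > 0 and mode in (1, 3):      # accel replay
        v_cruise_kph = max(carrot_speed, v_cruise_kph)
    elif carrot_speed < 0 and mode == 3:         # decel replay (forced)
        v_cruise_kph = -carrot_speed
    elif carrot_speed < 0 and mode == 2:         # decel replay (capped)
        v_cruise_kph = min(-carrot_speed, v_cruise_kph)
    if carrot_speed != 0:
        params_memory.put_int_nonblocking("CarrotSpeed", 0)
    return v_cruise_kph
```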
diff --git a/selfdrive/carrot/carrot_serv.py b/selfdrive/carrot/carrot_serv.py
index 85d1d45b..2402aec1 100644
--- a/selfdrive/carrot/carrot_serv.py
+++ b/selfdrive/carrot/carrot_serv.py
@@ -19,6 +19,7 @@ from openpilot.common.filter_simple import MyMovingAverage
 from openpilot.system.hardware import PC, TICI
 from openpilot.selfdrive.navd.helpers import Coordinate
 from opendbc.car.common.conversions import Conversions as CV
+from openpilot.common.gps import get_gps_location_service
 
 nav_type_mapping = {
   12: ("turn", "left", 1),
@@ -145,8 +146,11 @@ class CarrotServ:
     self.bearing = 0.0
     self.gps_valid = False
-    self.gps_accuracy_phone = 0.0
+    self.phone_gps_accuracy = 0.0
     self.gps_accuracy_device = 0.0
+    self.phone_latitude = 0.0
+    self.phone_longitude = 0.0
+    self.phone_gps_frame = 0
 
     self.totalDistance = 0
     self.xSpdLimit = 0
@@ -641,15 +645,14 @@ class CarrotServ:
       self.xSpdType = -1
       self.xSpdDist = 0
 
-  def _update_gps(self, v_ego, sm):
-    llk = 'liveLocationKalman'
-    location = sm[llk]
+  def _update_gps(self, v_ego, sm, gps_service):
+    gps = sm[gps_service]
     #print(f"location = {sm.valid[llk]}, {sm.updated[llk]}, {sm.recv_frame[llk]}, {sm.recv_time[llk]}")
     if not sm.updated['carState'] or not sm.updated['carControl']: # or not sm.updated[llk]:
       return self.nPosAngle
     CS = sm['carState']
     CC = sm['carControl']
-    self.gps_valid = (location.status == log.LiveLocationKalman.Status.valid) and location.positionGeodetic.valid
+    self.gps_valid = sm.updated[gps_service] and gps.hasFix
 
     now = time.monotonic()
     gps_updated_phone = (now - self.last_update_gps_time_phone) < 3
@@ -658,8 +661,8 @@
     bearing = self.nPosAngle
     if gps_updated_phone:
       self.bearing_offset = 0.0
-    elif sm.valid[llk]:
-      bearing = math.degrees(location.calibratedOrientationNED.value[2])
+    elif self.gps_valid:
+      bearing = self.nPosAngle = gps.bearingDeg
       if self.gps_valid:
         self.bearing_offset = 0.0
     elif self.active_carrot > 0:
@@ -669,13 +672,13 @@
     #print(f"bearing = {bearing:.1f}, posA=={self.nPosAngle:.1f}, posP=={self.nPosAnglePhone:.1f}, offset={self.bearing_offset:.1f}, {gps_updated_phone}, {gps_updated_navi}")
     gpsDelayTimeAdjust = 0.0
     if gps_updated_navi:
-      gpsDelayTimeAdjust = 1.0
+      gpsDelayTimeAdjust = 0 #1.0
     external_gps_update_timedout = not (gps_updated_phone or gps_updated_navi)
     #print(f"gps_valid = {self.gps_valid}, bearing = {bearing:.1f}, pos = {location.positionGeodetic.value[0]:.6f}, {location.positionGeodetic.value[1]:.6f}")
     if self.gps_valid and external_gps_update_timedout:  # the internal GPS is working and no GPS signal comes from carrotman
-      self.vpPosPointLatNavi = location.positionGeodetic.value[0]
-      self.vpPosPointLonNavi = location.positionGeodetic.value[1]
+      self.vpPosPointLatNavi = gps.latitude
+      self.vpPosPointLonNavi = gps.longitude
       self.last_calculate_gps_time = now #sm.recv_time[llk]
     elif gps_updated_navi:  # GPS data is being received from carrot navi
       if abs(self.bearing_measured - bearing) < 0.1:
@@ -853,7 +856,7 @@
         self.xSpdDist = distance
         self.xSpdType = xSpdType
 
-  def update_navi(self, remote_ip, sm, pm, vturn_speed, coords, distances, route_speed):
+  def update_navi(self, remote_ip, sm, pm, vturn_speed, coords, distances, route_speed, gps_service):
     self.debugText = ""
     self.update_params()
@@ -874,7 +877,7 @@
     road_speed_limit_changed = True if self.nRoadLimitSpeed != self.nRoadLimitSpeed_last else False
     self.nRoadLimitSpeed_last = self.nRoadLimitSpeed
     #self.bearing = self.nPosAngle #self._update_gps(v_ego, sm)
-    self.bearing = self._update_gps(v_ego, sm)
+    self.bearing = self._update_gps(v_ego, sm, gps_service)
 
     self.xSpdDist = max(self.xSpdDist - delta_dist, -1000)
     self.xDistToTurn = self.xDistToTurn - delta_dist
@@ -1281,15 +1284,20 @@
         # if there is no navi data for 3 seconds, update from the phone GPS
         if "latitude" in json:
           self.nPosAnglePhone = float(json.get("heading", self.nPosAngle))
+          self.phone_latitude = float(json.get("latitude", self.vpPosPointLatNavi))
+          self.phone_longitude = float(json.get("longitude", self.vpPosPointLonNavi))
+          self.phone_gps_accuracy = float(json.get("accuracy", 0))
+          if self.phone_gps_accuracy < 15.0:
+            self.phone_gps_frame += 1
           if (now - self.last_update_gps_time_navi) > 3.0:
-            self.vpPosPointLatNavi = float(json.get("latitude", self.vpPosPointLatNavi))
-            self.vpPosPointLonNavi = float(json.get("longitude", self.vpPosPointLonNavi))
+            self.vpPosPointLatNavi = self.phone_latitude
+            self.vpPosPointLonNavi = self.phone_longitude
+
             self.nPosAngle = self.nPosAnglePhone # self.nPosSpeed = self.ve  # TODO speed from v_ego
-            self.last_update_gps_time_phone = self.last_calculate_gps_time = now
-            self.gps_accuracy_phone = float(json.get("accuracy", 0))
+            self.last_update_gps_time_phone = self.last_calculate_gps_time = now
           self.nPosSpeed = float(json.get("gps_speed", 0))
-          print(f"phone gps: {self.vpPosPointLatNavi}, {self.vpPosPointLonNavi}, {self.gps_accuracy_phone}, {self.nPosSpeed}")
+          print(f"phone gps: {self.vpPosPointLatNavi}, {self.vpPosPointLonNavi}, {self.phone_gps_accuracy}, {self.nPosSpeed}")
 
 import traceback
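The `_update_gps` rework above boils down to a source-priority rule: a fresh phone fix (updated within the last 3 s) wins, the device GPS (`hasFix`) is the fallback, and otherwise the last known position is held. A simplified sketch of that selection order (function and argument names are hypothetical, not the module's API):

```python
import time

def pick_position(now, phone, device, last, phone_ts, device_has_fix):
    """Return the position to use, mirroring the 3 s freshness window above."""
    if (now - phone_ts) < 3.0:
        return phone   # external phone GPS wins while it keeps updating
    if device_has_fix:
        return device  # fall back to the device GPS fix
    return last        # otherwise hold the last known position

now = time.monotonic()
print(pick_position(now, phone=(37.40, 127.10), device=(37.41, 127.11),
                    last=(37.39, 127.09), phone_ts=now - 1.0, device_has_fix=True))
# -> (37.4, 127.1): the phone fix is 1 s old, so it wins
```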
diff --git a/selfdrive/carrot/carrot_speed.py b/selfdrive/carrot/carrot_speed.py
new file mode 100644
index 00000000..da6c1425
--- /dev/null
+++ b/selfdrive/carrot/carrot_speed.py
@@ -0,0 +1,401 @@
+# -*- coding: utf-8 -*-
+"""
+CarrotSpeedTable v2.1 (Params backend, JSON+gzip, 1e-4° grid, 8 buckets)
+- storage key: "CarrotSpeedTable"
+- format (JSON): {"format":"v5","dir_buckets":8,"cells":{"gy,gx":[[v,ts],...]} }
+- gzip save/load supported (on by default); legacy uncompressed v2 files still load
+- grid: lat/lon each snapped to 1e-4° (about 9-11 m at Korean latitudes)
+- store: a single signed speed is written into the matching cell
+  * input > 0: overwrite if the stored value is None/negative/a smaller positive (keep the larger +)
+  * input < 0: overwrite if the stored value is None/positive/less negative (keep the smaller -)
+- query: forward lookahead cell first, then neighbor search (ring=1)
+  * own cell: no age filter
+  * neighbors: only stale data is used (age >= 120 s)
+- no cleanup pass: stale data is kept
+"""
+
+import json, math, threading, time, gzip
+from typing import Optional, Tuple, Dict, List
+from openpilot.common.params import Params
+
+
+# ---------- geo helpers ----------
+
+def quantize_1e4(lat: float, lon: float) -> Tuple[int, int]:
+  gy = int(math.floor(lat * 1e4 + 0.5))
+  gx = int(math.floor(lon * 1e4 + 0.5))
+  return gy, gx
+
+def heading_to_bucket(heading_deg: float) -> int:
+  # fixed at 8 buckets
+  step = 45.0  # 360/8
+  i = int((heading_deg % 360.0) // step)
+  if i < 0: return 0
+  if i > 7: return 7
+  return i
+
+DIR_8 = {
+  0: ( 1,  0),  # N
+  1: ( 1,  1),  # NE
+  2: ( 0,  1),  # E
+  3: (-1,  1),  # SE
+  4: (-1,  0),  # S
+  5: (-1, -1),  # SW
+  6: ( 0, -1),  # W
+  7: ( 1, -1),  # NW
+}
+
+def project_point(lat: float, lon: float, heading_deg: float, distance_m: float) -> Tuple[float, float]:
+  if distance_m <= 0.0:
+    return lat, lon
+  R = 6_371_000.0
+  h = math.radians(heading_deg)
+  dlat = (distance_m * math.cos(h)) / R
+  dlon = (distance_m * math.sin(h)) / (R * math.cos(math.radians(lat)))
+  return lat + math.degrees(dlat), lon + math.degrees(dlon)
+
+def _is_gzip(data: bytes) -> bool:
+  return len(data) >= 2 and data[0] == 0x1F and data[1] == 0x8B
+
+
+# ---------- main class ----------
+
+class CarrotSpeed:
+  KEY = "CarrotSpeedTable"
+
+  def __init__(self,
+               neighbor_ring: int = 1,
+               neighbor_old_threshold_s: int = 120,
+               use_gzip: bool = True,
+               gzip_level: int = 5):
+    # fixed spec
+    self.buckets = 8
+
+    # parameters
+    self.neighbor_ring = max(0, int(neighbor_ring))
+    self.neighbor_old_threshold_s = int(neighbor_old_threshold_s)
+    self.use_gzip = bool(use_gzip)
+    self.gzip_level = int(gzip_level)
+
+    # internal state
+    self._lock = threading.RLock()
+    # _cells[(gy,gx)] = [[value or None, ts(int seconds) or None] * 8]
+    self._cells: Dict[Tuple[int, int], List[List[Optional[float]]]] = {}
+    self._dirty = False
+    self._last_save = 0
+    self._params = Params()
+
+    self._load_from_params_if_exists()
+
+    self._last_hit = None  # (gy, gx, b, ts_when_read)
+    self._last_hit_read_ms = 0  # milliseconds
+
+  # ----- internal utils -----
+
+  def _ensure_cell(self, gy: int, gx: int) -> List[List[Optional[float]]]:
+    arr = self._cells.get((gy, gx))
+    if arr is None:
+      arr = [[None, None] for _ in range(self.buckets)]  # [v, ts]
+      self._cells[(gy, gx)] = arr
+    return arr
+
+  def _now(self) -> int:
+    # integer seconds
+    return int(time.time())
+
+  def _age(self, ts: Optional[float]) -> Optional[int]:
+    if ts is None:
+      return None
+    return self._now() - int(ts)
+
+  def _neighbor_indices(self, gy: int, gx: int) -> List[Tuple[int, int]]:
+    r = self.neighbor_ring
+    if r <= 0:
+      return []
+    out = []
+    for dy in range(-r, r + 1):
+      for dx in range(-r, r + 1):
+        if dy == 0 and dx == 0:
+          continue
+        out.append((gy + dy, gx + dx))
+    return out
+
+  def _neighbors_8(self, gy, gx):
+    for dy in (-1, 0, 1):
+      for dx in (-1, 0, 1):
+        if dy == 0 and dx == 0:
+          continue
+        yield gy + dy, gx + dx
+
+  def _try_cell_bucket_old(self, arr, b):
+    v, ts = arr[b]
+    if v is None or ts is None:
+      return None, None
+    if self._now() - int(ts) < self.neighbor_old_threshold_s:
+      return None, None
+    return float(v), b
+
+  # ----- public API -----
+
+  def export_cells_around(self, lat: float, lon: float,
+                          heading_deg: float,
+                          ring: int = 1, max_points: int = 64) -> str:
+    """
+    From the grid around the current lat/lon (within `ring`), return the cells
+    that hold a value as a JSON list of (lat, lon, speed) points,
+    ready to be written straight into Params("CarrotSpeedViz").
+    """
+    gy0, gx0 = quantize_1e4(lat, lon)
+    b0 = heading_to_bucket(heading_deg)
+    pts = []
+
+    with self._lock:
+      for dy in range(-ring, ring + 1):
+        for dx in range(-ring, ring + 1):
+          gy = gy0 + dy
+          gx = gx0 + dx
+          arr = self._cells.get((gy, gx))
+          if not arr:
+            continue
+
+          # the exact bucket (b0) first
+          v, ts = arr[b0]
+          if v is not None:
+            cell_lat = gy * 1e-4  # quantize_1e4 rounds, so the cell center is gy * 1e-4
+            cell_lon = gx * 1e-4
+            pts.append([cell_lat, cell_lon, float(v)])
+            if len(pts) >= max_points:
+              return json.dumps({"pts": pts}, separators=(",",":"))
+            continue
+
+          # if empty, try the buckets to the left/right
+          for b in ((b0 - 1) % self.buckets, (b0 + 1) % self.buckets):
+            v, ts = arr[b]
+            if v is None:
+              continue
+            cell_lat = gy * 1e-4
+            cell_lon = gx * 1e-4
+            pts.append([cell_lat, cell_lon, float(v)])
+            if len(pts) >= max_points:
+              return json.dumps({"pts": pts}, separators=(",",":"))
+
+    return json.dumps({"pts": pts}, separators=(",",":"))
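To make the grid and bucket math above concrete: one cell spans 1e-4° per axis, i.e. about 11 m of latitude and roughly 9 m of longitude near 37°N, and headings collapse into eight 45° sectors. A quick standalone check of both helpers:

```python
import math

def quantize_1e4(lat, lon):
    # round to the nearest 1e-4-degree cell index (same as the module above)
    return int(math.floor(lat * 1e4 + 0.5)), int(math.floor(lon * 1e4 + 0.5))

def heading_to_bucket(heading_deg):
    return min(7, int((heading_deg % 360.0) // 45.0))

print(quantize_1e4(37.39871, 127.10362))  # (373987, 1271036)
print(heading_to_bucket(100.0))           # 2 -> east  (DIR_8[2] == (0, 1))
print(heading_to_bucket(350.0))           # 7 -> north-west
# cell size: 1e-4 deg of latitude is ~11.1 m; longitude shrinks with cos(lat)
print(111_320 * 1e-4, 111_320 * 1e-4 * math.cos(math.radians(37)))  # ~11.1 ~8.9
```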
+  def add_sample(self, lat: float, lon: float, heading_deg: float, speed_signed: float):
+    """
+    Store a single signed speed.
+    - the same speed is written to the base cell (current position) plus the
+      cells 1 and 2 steps to the left/right of the heading
+    - within each cell, bucket b and b±1 are all updated with the same value
+    - > 0: replaces an existing negative/None; averaged into an existing positive
+    - < 0: always overwrites with the new negative (sudden deceleration wins)
+    - == 0: ignored
+    """
+    v_in = round(float(speed_signed), 1)
+    if v_in == 0.0:
+      return
+
+    # snap the current position to the grid
+    gy0, gx0 = quantize_1e4(lat, lon)
+    b = heading_to_bucket(heading_deg)
+    now = self._now()
+
+    # forward grid vector for this heading bucket
+    dy_f, dx_f = DIR_8[b]
+
+    # cells 1 and 2 steps left/right of the heading (no project_point needed)
+    # left  = forward vector rotated +90°: (dy,dx) -> (dx,-dy)
+    # right = forward vector rotated -90°: (dy,dx) -> (-dx,dy)
+    dy_l1, dx_l1 = dx_f, -dy_f
+    dy_r1, dx_r1 = -dx_f, dy_f
+
+    dy_l2, dx_l2 = 2 * dy_l1, 2 * dx_l1
+    dy_r2, dx_r2 = 2 * dy_r1, 2 * dx_r1
+
+    # cells to write: center + 1 left/right + 2 left/right
+    target_cells = {
+      (gy0, gx0),
+      (gy0 + dy_l1, gx0 + dx_l1),
+      (gy0 + dy_r1, gx0 + dx_r1),
+      (gy0 + dy_l2, gx0 + dx_l2),
+      (gy0 + dy_r2, gx0 + dx_r2),
+    }
+
+    with self._lock:
+      for gy, gx in target_cells:
+        arr = self._ensure_cell(gy, gx)
+
+        # update buckets b, b-1 and b+1 with the same policy
+        for off in (0, -1, +1):
+          bi = (b + off) % self.buckets
+          v_old, ts_old = arr[bi]
+
+          if v_old is None:
+            # first write into this bucket
+            arr[bi] = [v_in, now]
+          else:
+            if v_in > 0.0:
+              # acceleration info: average into an existing positive, replace a negative
+              if v_old < 0.0:
+                # negative -> positive: take the new positive (keep the old ts)
+                arr[bi] = [v_in, ts_old]
+              else:
+                new_val = round((v_old + v_in) / 2.0, 1)
+                arr[bi] = [new_val, ts_old]
+            else:
+              # deceleration info: always overwrite with the new negative (keep the old ts)
+              arr[bi] = [v_in, ts_old]
+
+      self._dirty = True
+
+  def query_target(self, lat: float, lon: float, heading_deg: float, v_ego: float,
+                   lookahead_s: float = 2.0) -> float:
+    dist = max(0.0, float(v_ego) * float(lookahead_s))
+    return self.query_target_dist(lat, lon, heading_deg, dist)
+
+  def query_target_dist(self, lat: float, lon: float, heading_deg: float, dist: float) -> float:
+    b = heading_to_bucket(heading_deg)
+
+    cand_ds = [dist]
+    for off in (3.0, -3.0):
+      d2 = dist + off
+      if d2 >= 0.0:
+        cand_ds.append(d2)
+
+    with self._lock:
+      for d in cand_ds:
+        y, x = project_point(lat, lon, heading_deg, d)
+        gy, gx = quantize_1e4(y, x)
+
+        arr = self._cells.get((gy, gx))
+        if not arr:
+          continue
+
+        v, b_sel = self._try_cell_bucket_old(arr, b)
+        if v is not None:
+          now_sec = int(time.time())
+          self._last_hit = (gy, gx, b_sel, now_sec)
+          self._last_hit_read_ms = int(time.time() * 1000)
+          return v
+
+    return 0.0
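`query_target_dist` above probes the cell at the projected point with ±3 m fallback probes, and `query_target` derives that distance as v_ego times a 2 s lookahead. A worked projection using the same spherical-earth approximation as `project_point` (the coordinates are arbitrary sample values):

```python
import math

def project_point(lat, lon, heading_deg, distance_m):
    # same small-offset spherical projection as in the module above
    R = 6_371_000.0
    h = math.radians(heading_deg)
    dlat = distance_m * math.cos(h) / R
    dlon = distance_m * math.sin(h) / (R * math.cos(math.radians(lat)))
    return lat + math.degrees(dlat), lon + math.degrees(dlon)

# at 20 m/s heading due east, the 2 s lookahead probes ~40 m ahead,
# with fallback probes at 43 m and 37 m
v_ego, lookahead_s = 20.0, 2.0
for d in (v_ego * lookahead_s, v_ego * lookahead_s + 3.0, v_ego * lookahead_s - 3.0):
    print(d, project_point(37.4000, 127.1000, 90.0, d))
```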
+  def invalidate_last_hit(self, window_s: float = 2.0, action: str = "clear") -> bool:
+    if self._last_hit is None:
+      return False
+    gy, gx, b, read_ts = self._last_hit
+    now = int(time.time())
+    if (now - int(read_ts)) > window_s:
+      return False
+
+    with self._lock:
+      arr = self._cells.get((gy, gx))
+      if not arr:
+        return False
+
+      # invalidate buckets b, b-1 and b+1
+      for off in (0, -1, +1):
+        bi = (b + off) % self.buckets
+        v, ts = arr[bi]
+
+        if action == "clear":
+          if v is not None and v < 0.0:
+            arr[bi] = [None, None]
+        else:  # "age_bump"
+          if v is not None:
+            arr[bi] = [v, now]
+          else:
+            # nothing stored: just skip this bucket
+            pass
+
+      self._dirty = True
+
+    return True
+
+  def maybe_save(self, interval_s: int = 60) -> None:
+    now = self._now()
+    if (not self._dirty) or (now - self._last_save < interval_s):
+      return
+    self.save()
+
+  def save(self) -> None:
+    payload = self._encode_payload()
+    self._params.put_nonblocking(self.KEY, payload)
+    self._last_save = self._now()
+    self._dirty = False
+
+  def close(self) -> None:
+    try:
+      if self._dirty:
+        self.save()
+    except Exception:
+      pass
+
+  # ----- serialization -----
+
+  def _encode_payload(self) -> bytes:
+    with self._lock:
+      cells = {}
+      for (gy, gx), arr in self._cells.items():
+        key = f"{gy},{gx}"
+        # arr: [[v, ts], ...] (ts is an int or None)
+        cells[key] = [[None if v is None else float(v),
+                       None if ts is None else int(ts)] for (v, ts) in arr]
+      obj = {"format": "v5", "dir_buckets": self.buckets, "cells": cells}
+    raw = json.dumps(obj, separators=(",", ":")).encode("utf-8")
+    if self.use_gzip:
+      return gzip.compress(raw, compresslevel=self.gzip_level)
+    return raw
+
+  def _load_from_params_if_exists(self) -> None:
+    raw = self._params.get(self.KEY)
+    if not raw:
+      return
+    try:
+      data_bytes = raw
+      if _is_gzip(data_bytes):
+        data_bytes = gzip.decompress(data_bytes)
+      data = json.loads(data_bytes.decode("utf-8"))
+
+      # wrong format version: delete the key and start clean
+      if data.get("format") != "v5":
+        self._params.remove(self.KEY)
+        with self._lock:
+          self._cells = {}
+          self._dirty = False
+        return
+
+      buckets = int(data.get("dir_buckets", 8))
+      if buckets != 8:
+        # bucket-count mismatch: also delete and start clean
+        self._params.remove(self.KEY)
+        with self._lock:
+          self._cells = {}
+          self._dirty = False
+        return
+
+      restored: Dict[Tuple[int, int], List[List[Optional[float]]]] = {}
+      for key, arr in data.get("cells", {}).items():
+        gy, gx = map(int, key.split(","))
+        fixed: List[List[Optional[float]]] = []
+        if isinstance(arr, list) and len(arr) == 8:
+          for pair in arr:
+            if isinstance(pair, list) and len(pair) == 2:
+              v, ts = pair
+              v2 = None if v is None else float(v)
+              # force ts to int
+              ts2 = None if ts is None else int(ts)
+              fixed.append([v2, ts2])
+            else:
+              fixed.append([None, None])
+        else:
+          fixed = [[None, None] for _ in range(8)]
+        restored[(gy, gx)] = fixed
+
+      with self._lock:
+        self._cells = restored
+        self._dirty = False
+
+    except Exception:
+      # parse failure: reset to a safe empty state
+      self._params.remove(self.KEY)
+      with self._lock:
+        self._cells = {}
+        self._dirty = False
diff --git a/selfdrive/carrot_settings.json b/selfdrive/carrot_settings.json
index fc2f1a2f..b9818b8a 100644
--- a/selfdrive/carrot_settings.json
+++ b/selfdrive/carrot_settings.json
@@ -1744,6 +1744,19 @@
     "default": 1,
     "unit": 1
   },
+  {
+    "group": "감속제어",
+    "name": "CarrotSmartSpeedControl",
+    "title": "스마트속도제어(속도재생)",
+    "descr": "0: 속도제어안함\n 1: 가속만, 2: 감속만, 3: 모두",
+    "egroup": "SPEED",
+    "etitle": "Smart Speed Control(Replay)",
+    "edescr": "0:not use, 1:accel, 2:decel, 3:all",
+    "min": 0,
+    "max": 3,
+    "default": 0,
+    "unit": 1
+  },
   {
     "group": "감속제어",
     "name": "MapTurnSpeedFactor",
diff --git a/selfdrive/controls/controlsd.py b/selfdrive/controls/controlsd.py
index 5ed5fc95..2a9f8df1 
100644 --- a/selfdrive/controls/controlsd.py +++ b/selfdrive/controls/controlsd.py @@ -48,7 +48,7 @@ class Controls: self.sm = messaging.SubMaster(['liveDelay', 'liveParameters', 'liveTorqueParameters', 'modelV2', 'selfdriveState', 'liveCalibration', 'livePose', 'longitudinalPlan', 'carState', 'carOutput', - 'carrotMan', 'lateralPlan', 'radarState', 'liveLocationKalman', + 'carrotMan', 'lateralPlan', 'radarState', 'driverMonitoringState', 'onroadEvents', 'driverAssistance'], poll='selfdriveState') self.pm = messaging.PubMaster(['carControl', 'controlsState']) @@ -180,7 +180,7 @@ class Controls: actuators.curvature = float(self.desired_curvature) steer, steeringAngleDeg, lac_log = self.LaC.update(CC.latActive, CS, self.VM, lp, self.steer_limited_by_controls, self.desired_curvature, - self.sm['liveLocationKalman'], curvature_limited, + CC, curvature_limited, model_data=self.sm['modelV2']) actuators.torque = float(steer) actuators.steeringAngleDeg = float(steeringAngleDeg) @@ -237,23 +237,16 @@ class Controls: # Orientation and angle rates can be useful for carcontroller # Only calibrated (car) frame is relevant for the carcontroller - #if self.calibrated_pose is not None: - # CC.orientationNED = self.calibrated_pose.orientation.xyz.tolist() - # CC.angularVelocity = self.calibrated_pose.angular_velocity.xyz.tolist() + if self.calibrated_pose is not None: + CC.orientationNED = self.calibrated_pose.orientation.xyz.tolist() + CC.angularVelocity = self.calibrated_pose.angular_velocity.xyz.tolist() - orientation_value = list(self.sm['liveLocationKalman'].calibratedOrientationNED.value) - if len(orientation_value) > 2: - CC.orientationNED = orientation_value - angular_rate_value = list(self.sm['liveLocationKalman'].angularVelocityCalibrated.value) - if len(angular_rate_value) > 2: - CC.angularVelocity = angular_rate_value - - acceleration_value = list(self.sm['liveLocationKalman'].accelerationCalibrated.value) - if len(acceleration_value) > 2: - if abs(acceleration_value[0]) > 16.0: - print("Collision detected. disable openpilot, restart") - self.params.put_bool("OpenpilotEnabledToggle", False) - self.params.put_int("SoftRestartTriggered", 1) + #acceleration_value = list(self.sm['liveLocationKalman'].accelerationCalibrated.value) + #if len(acceleration_value) > 2: + # if abs(acceleration_value[0]) > 16.0: + # print("Collision detected. 
disable openpilot, restart") + # self.params.put_bool("OpenpilotEnabledToggle", False) + # self.params.put_int("SoftRestartTriggered", 1) CC.cruiseControl.override = CC.enabled and not CC.longActive and self.CP.openpilotLongitudinalControl CC.cruiseControl.cancel = CS.cruiseState.enabled and (not CC.enabled or not self.CP.pcmCruise) diff --git a/selfdrive/controls/lib/latcontrol.py b/selfdrive/controls/lib/latcontrol.py index ce0292bd..2f48cbf7 100644 --- a/selfdrive/controls/lib/latcontrol.py +++ b/selfdrive/controls/lib/latcontrol.py @@ -17,7 +17,7 @@ class LatControl(ABC): self.steer_max = 1.0 @abstractmethod - def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, llk, curvature_limited, model_data=None): + def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, CC, curvature_limited, model_data=None): pass def reset(self): diff --git a/selfdrive/controls/lib/latcontrol_angle.py b/selfdrive/controls/lib/latcontrol_angle.py index 7fe69913..a68eb1b3 100644 --- a/selfdrive/controls/lib/latcontrol_angle.py +++ b/selfdrive/controls/lib/latcontrol_angle.py @@ -17,7 +17,7 @@ class LatControlAngle(LatControl): #self.factor = 0.5 #print("Angle factor", self.factor) - def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, llk, curvature_limited, model_data=None): + def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, CC, curvature_limited, model_data=None): angle_log = log.ControlsState.LateralAngleState.new_message() if not active: diff --git a/selfdrive/controls/lib/latcontrol_pid.py b/selfdrive/controls/lib/latcontrol_pid.py index 9c1724ab..1a03afbe 100644 --- a/selfdrive/controls/lib/latcontrol_pid.py +++ b/selfdrive/controls/lib/latcontrol_pid.py @@ -17,7 +17,7 @@ class LatControlPID(LatControl): super().reset() self.pid.reset() - def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, llk, curvature_limited, model_data=None): + def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, CC, curvature_limited, model_data=None): pid_log = log.ControlsState.LateralPIDState.new_message() pid_log.steeringAngleDeg = float(CS.steeringAngleDeg) pid_log.steeringRateDeg = float(CS.steeringRateDeg) diff --git a/selfdrive/controls/lib/latcontrol_torque.py b/selfdrive/controls/lib/latcontrol_torque.py index ba26c7ed..a920369b 100644 --- a/selfdrive/controls/lib/latcontrol_torque.py +++ b/selfdrive/controls/lib/latcontrol_torque.py @@ -139,7 +139,7 @@ class LatControlTorque(LatControl): self.torque_params.latAccelOffset = latAccelOffset self.torque_params.friction = friction - def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, llk, curvature_limited, model_data=None): + def update(self, active, CS, VM, params, steer_limited_by_controls, desired_curvature, CC, curvature_limited, model_data=None): self.frame += 1 if self.frame % 10 == 0: lateralTorqueCustom = self.params.get_int("LateralTorqueCustom") @@ -181,7 +181,7 @@ class LatControlTorque(LatControl): actual_curvature_rate = -VM.calc_curvature(math.radians(CS.steeringRateDeg), CS.vEgo, 0.0) actual_lateral_jerk = actual_curvature_rate * CS.vEgo ** 2 else: - actual_curvature_llk = llk.angularVelocityCalibrated.value[2] / CS.vEgo + actual_curvature_llk = CC.angularVelocity[2] / CS.vEgo #llk.angularVelocityCalibrated.value[2] / CS.vEgo actual_curvature = np.interp(CS.vEgo, [2.0, 5.0], [actual_curvature_vm, actual_curvature_llk]) 
     curvature_deadzone = 0.0
     desired_lateral_accel = desired_curvature * CS.vEgo ** 2
@@ -219,8 +219,8 @@
       # update past data
       pitch = 0
       roll = params.roll
-      if len(llk.calibratedOrientationNED.value) > 1:
-        pitch = self.pitch.update(llk.calibratedOrientationNED.value[1])
+      if len(CC.orientationNED) > 1:
+        pitch = self.pitch.update(CC.orientationNED[1])
         roll = roll_pitch_adjust(roll, pitch)
       self.roll_deque.append(roll)
       self.lateral_accel_desired_deque.append(desired_lateral_accel)
diff --git a/selfdrive/controls/lib/lateral_planner.py b/selfdrive/controls/lib/lateral_planner.py
index b7f4772e..3fbccad5 100644
--- a/selfdrive/controls/lib/lateral_planner.py
+++ b/selfdrive/controls/lib/lateral_planner.py
@@ -95,12 +95,16 @@ class LateralPlanner:
     # clip speed , lateral planning is not possible at 0 speed
     measured_curvature = sm['controlsState'].curvature
-    v_ego_car = sm['carState'].vEgo
+    v_ego_car = max(sm['carState'].vEgo, MIN_SPEED)
+    speed_kph = v_ego_car * 3.6
+
     self.v_ego = v_ego_car
     self.curve_speed = sm['carrotMan'].vTurnSpeed
 
     # Parse model predictions
     md = sm['modelV2']
+    model_active = False
     if len(md.position.x) == TRAJECTORY_SIZE and len(md.orientation.x) == TRAJECTORY_SIZE:
+      model_active = True
       self.path_xyz = np.column_stack([md.position.x, md.position.y, md.position.z])
       self.t_idxs = np.array(md.position.t)
       self.plan_yaw = np.array(md.orientation.z)
@@ -125,9 +129,9 @@
     if self.useLaneLineSpeedApply == 0 or self.laneless_only:
       self.useLaneLineMode = False
-    elif self.v_ego*3.6 >= self.useLaneLineSpeedApply + 2:
+    elif speed_kph >= self.useLaneLineSpeedApply + 2:
       self.useLaneLineMode = True
-    elif self.v_ego*3.6 < self.useLaneLineSpeedApply - 2:
+    elif speed_kph < self.useLaneLineSpeedApply - 2:
       self.useLaneLineMode = False
 
     # Turn off lanes during lane change
@@ -143,10 +147,15 @@
       self.LP.lane_width_left = md.meta.laneWidthLeft
       self.LP.lane_width_right = md.meta.laneWidthRight
       self.LP.curvature = measured_curvature
-      self.path_xyz, self.lanelines_active = self.LP.get_d_path(sm['carState'], self.v_ego, self.t_idxs, self.path_xyz, self.curve_speed)
-
-      #if self.LP.lanefull_mode:
-      #  self.plan_yaw, self.plan_yaw_rate = self.LP.calculate_plan_yaw_and_yaw_rate(self.path_xyz)
+      self.path_xyz, self.lanelines_active = self.LP.get_d_path(sm['carState'], v_ego_car, self.t_idxs, self.path_xyz, self.curve_speed)
+
+      if self.lanelines_active:
+        self.plan_yaw, self.plan_yaw_rate = yaw_from_path_no_scipy(
+          self.path_xyz, self.v_plan,
+          smooth_window=5,
+          clip_rate=2.0,
+          align_first_yaw=None  #md.orientation.z[0]  # initial alignment
+        )
 
       self.latDebugText = self.LP.debugText
       #self.lanelines_active = True if self.LP.d_prob > 0.3 and self.LP.lanefull_mode else False
@@ -212,12 +221,14 @@
     lateralPlan.psis = self.lat_mpc.x_sol[0:CONTROL_N, 2].tolist()
     lateralPlan.distances = self.lat_mpc.x_sol[0:CONTROL_N, 0].tolist()
 
+    v_div = np.maximum(self.v_plan[:CONTROL_N], 6.0)
     if len(self.v_plan) == TRAJECTORY_SIZE:
-      lateralPlan.curvatures = (self.lat_mpc.x_sol[0:CONTROL_N, 3] / self.v_plan[0:CONTROL_N]).tolist()
+      lateralPlan.curvatures = (self.lat_mpc.x_sol[0:CONTROL_N, 3] / v_div).tolist()
     else:
       lateralPlan.curvatures = (self.lat_mpc.x_sol[0:CONTROL_N, 3] / self.v_ego).tolist()
-    lateralPlan.curvatureRates = [float(x.item() / self.v_ego) for x in self.lat_mpc.u_sol[0:CONTROL_N - 1]] + [0.0]
+    v_div2 = max(self.v_ego, 6.0)
+    lateralPlan.curvatureRates = [float(x.item() / v_div2) for x in self.lat_mpc.u_sol[0:CONTROL_N - 1]] + [0.0]
     lateralPlan.mpcSolutionValid = bool(plan_solution_valid)
     lateralPlan.solverExecutionTime = self.lat_mpc.solve_time
@@ -263,3 +274,74 @@
     pm.send('lateralPlan', plan_send)
 
+def smooth_moving_avg(arr, window=5):
+  if window < 2:
+    return arr
+  if window % 2 == 0:
+    window += 1
+  pad = window // 2
+  arr_pad = np.pad(arr, (pad, pad), mode='edge')
+  kernel = np.ones(window) / window
+  return np.convolve(arr_pad, kernel, mode='same')[pad:-pad]
+
+def yaw_from_path_no_scipy(path_xyz, v_plan, smooth_window=5,
+                           clip_rate=2.0, align_first_yaw=None):
+
+  v0 = float(np.asarray(v_plan)[0]) if len(v_plan) else 0.0
+  # at low speed (<= 6 m/s), widen the smoothing window
+  if v0 <= 6.0:
+    smooth_window = max(smooth_window, 9)  # 9-11 recommended
+
+  N = path_xyz.shape[0]
+  x = path_xyz[:, 0].astype(float)
+  y = path_xyz[:, 1].astype(float)
+
+  if N < 5:
+    return np.zeros(N, np.float32), np.zeros(N, np.float32)
+
+  # 1) arc length s
+  dx = np.diff(x)
+  dy = np.diff(y)
+  ds_seg = np.sqrt(dx*dx + dy*dy)
+  ds_seg[ds_seg < 0.05] = 0.05
+  s = np.zeros(N, float)
+  s[1:] = np.cumsum(ds_seg)
+  if s[-1] < 0.5:  # with a total arc length under 0.5 m the derivatives mean little
+    return np.zeros(N, np.float32), np.zeros(N, np.float32)
+
+  # 2) smoothing (moving average)
+  x_smooth = smooth_moving_avg(x, smooth_window)
+  y_smooth = smooth_moving_avg(y, smooth_window)
+
+  # 3) first and second derivatives w.r.t. s
+  dx_ds = np.gradient(x_smooth, s)
+  dy_ds = np.gradient(y_smooth, s)
+  d2x_ds2 = np.gradient(dx_ds, s)
+  d2y_ds2 = np.gradient(dy_ds, s)
+
+  # 4) yaw = atan2(dy/ds, dx/ds)
+  yaw = np.unwrap(np.arctan2(dy_ds, dx_ds))
+
+  # 5) curvature kappa
+  denom = (dx_ds*dx_ds + dy_ds*dy_ds)**1.5
+  denom[denom < 1e-9] = 1e-9
+  kappa = (dx_ds * d2y_ds2 - dy_ds * d2x_ds2) / denom
+
+  # 6) yaw_rate = kappa * v
+  v = np.asarray(v_plan, float)
+  yaw_rate = kappa * v
+  if v0 <= 6.0:
+    # damp small oscillations with a moving average (window 5-7)
+    yaw_rate = smooth_moving_avg(yaw_rate, window=7)
+
+  # 7) initial yaw alignment (optional)
+  if align_first_yaw is not None:
+    bias = yaw[0] - float(align_first_yaw)
+    yaw = yaw - bias
+
+  # 8) stabilize
+  yaw = np.where(np.isfinite(yaw), yaw, 0.0)
+  yaw_rate = np.where(np.isfinite(yaw_rate), yaw_rate, 0.0)
+  yaw_rate = np.clip(yaw_rate, -abs(clip_rate), abs(clip_rate))
+
+  return yaw.astype(np.float32), yaw_rate.astype(np.float32)
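`yaw_from_path_no_scipy` above recovers heading purely from path geometry: yaw is the arc-length tangent angle, curvature comes from the planar formula κ = (x′y″ − y′x″)/(x′² + y′²)^{3/2}, and yaw rate is κ·v. A quick sanity check of that gradient recipe on a circular arc, where κ must equal 1/R (a standalone re-derivation, not the planner code itself):

```python
import numpy as np

# On a circle of radius R, kappa = 1/R and yaw_rate = v/R; check the
# gradient-based recovery used above on 30 points of a 50 m-radius arc.
R, v = 50.0, 10.0
theta = np.linspace(0, 0.6, 30)
x, y = R * np.sin(theta), R * (1 - np.cos(theta))

s = np.zeros_like(x)
s[1:] = np.cumsum(np.hypot(np.diff(x), np.diff(y)))   # arc length
dx, dy = np.gradient(x, s), np.gradient(y, s)
d2x, d2y = np.gradient(dx, s), np.gradient(dy, s)

kappa = (dx * d2y - dy * d2x) / (dx * dx + dy * dy) ** 1.5
print(np.round(kappa[5:-5], 4))    # ~0.02 = 1/R away from the endpoints
print(np.round(kappa[15] * v, 3))  # yaw rate ~0.2 rad/s at v = 10 m/s
```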
1]] + [0.0] lateralPlan.mpcSolutionValid = bool(plan_solution_valid) lateralPlan.solverExecutionTime = self.lat_mpc.solve_time @@ -263,3 +274,74 @@ class LateralPlanner: pm.send('lateralPlan', plan_send) +def smooth_moving_avg(arr, window=5): + if window < 2: + return arr + if window % 2 == 0: + window += 1 + pad = window // 2 + arr_pad = np.pad(arr, (pad, pad), mode='edge') + kernel = np.ones(window) / window + return np.convolve(arr_pad, kernel, mode='same')[pad:-pad] + +def yaw_from_path_no_scipy(path_xyz, v_plan, smooth_window=5, + clip_rate=2.0, align_first_yaw=None): + + v0 = float(np.asarray(v_plan)[0]) if len(v_plan) else 0.0 + # 저속(≤6 m/s)에서는 창을 크게 + if v0 <= 6.0: + smooth_window = max(smooth_window, 9) # 9~11 권장 + + N = path_xyz.shape[0] + x = path_xyz[:, 0].astype(float) + y = path_xyz[:, 1].astype(float) + + if N < 5: + return np.zeros(N, np.float32), np.zeros(N, np.float32) + + # 1) s(호길이) 계산 + dx = np.diff(x) + dy = np.diff(y) + ds_seg = np.sqrt(dx*dx + dy*dy) + ds_seg[ds_seg < 0.05] = 0.05 + s = np.zeros(N, float) + s[1:] = np.cumsum(ds_seg) + if s[-1] < 0.5: # 총 호길이 < 0.5m면 미분 결과 의미가 약함 + return np.zeros(N, np.float32), np.zeros(N, np.float32) + + # 2) smoothing (이동평균) + x_smooth = smooth_moving_avg(x, smooth_window) + y_smooth = smooth_moving_avg(y, smooth_window) + + # 3) 1·2차 도함수(s축 미분) + dx_ds = np.gradient(x_smooth, s) + dy_ds = np.gradient(y_smooth, s) + d2x_ds2 = np.gradient(dx_ds, s) + d2y_ds2 = np.gradient(dy_ds, s) + + # 4) yaw = atan2(dy/ds, dx/ds) + yaw = np.unwrap(np.arctan2(dy_ds, dx_ds)) + + # 5) 곡률 kappa = ... + denom = (dx_ds*dx_ds + dy_ds*dy_ds)**1.5 + denom[denom < 1e-9] = 1e-9 + kappa = (dx_ds * d2y_ds2 - dy_ds * d2x_ds2) / denom + + # 6) yaw_rate = kappa * v + v = np.asarray(v_plan, float) + yaw_rate = kappa * v + if v0 <= 6.0: + # 이동평균으로 미세 요동 감쇄(창 5~7) + yaw_rate = smooth_moving_avg(yaw_rate, window=7) + + # 7) 초기 yaw 정렬 (선택) + if align_first_yaw is not None: + bias = yaw[0] - float(align_first_yaw) + yaw = yaw - bias + + # 8) 안정화 + yaw = np.where(np.isfinite(yaw), yaw, 0.0) + yaw_rate = np.where(np.isfinite(yaw_rate), yaw_rate, 0.0) + yaw_rate = np.clip(yaw_rate, -abs(clip_rate), abs(clip_rate)) + + return yaw.astype(np.float32), yaw_rate.astype(np.float32) diff --git a/selfdrive/locationd/.gitignore b/selfdrive/locationd/.gitignore index 11b9f127..1a8c7238 100644 --- a/selfdrive/locationd/.gitignore +++ b/selfdrive/locationd/.gitignore @@ -1,3 +1,2 @@ params_learner paramsd -locationd diff --git a/selfdrive/locationd/SConscript b/selfdrive/locationd/SConscript index 2bd90783..e8eeff7e 100644 --- a/selfdrive/locationd/SConscript +++ b/selfdrive/locationd/SConscript @@ -1,6 +1,4 @@ -Import('env', 'arch', 'common', 'messaging', 'rednose', 'transformations') - -loc_libs = [messaging, common, 'pthread', 'dl'] +Import('env', 'rednose') # build ekf models rednose_gen_dir = 'models/generated' @@ -14,13 +12,6 @@ pose_ekf = env.RednoseCompileFilter( extra_gen_artifacts=[], gen_script_deps=rednose_gen_deps, ) -live_ekf = env.RednoseCompileFilter( - target='live', - filter_gen_script='models/live_kf.py', - output_dir=rednose_gen_dir, - extra_gen_artifacts=['live_kf_constants.h'], - gen_script_deps=rednose_gen_deps, -) car_ekf = env.RednoseCompileFilter( target='car', filter_gen_script='models/car_kf.py', @@ -28,17 +19,3 @@ car_ekf = env.RednoseCompileFilter( extra_gen_artifacts=[], gen_script_deps=rednose_gen_deps, ) - -# locationd build -locationd_sources = ["locationd.cc", "models/live_kf.cc"] - -lenv = env.Clone() -# ekf filter libraries need to be linked, 
diff --git a/selfdrive/locationd/.gitignore b/selfdrive/locationd/.gitignore
index 11b9f127..1a8c7238 100644
--- a/selfdrive/locationd/.gitignore
+++ b/selfdrive/locationd/.gitignore
@@ -1,3 +1,2 @@
 params_learner
 paramsd
-locationd
diff --git a/selfdrive/locationd/SConscript b/selfdrive/locationd/SConscript
index 2bd90783..e8eeff7e 100644
--- a/selfdrive/locationd/SConscript
+++ b/selfdrive/locationd/SConscript
@@ -1,6 +1,4 @@
-Import('env', 'arch', 'common', 'messaging', 'rednose', 'transformations')
-
-loc_libs = [messaging, common, 'pthread', 'dl']
+Import('env', 'rednose')

 # build ekf models
 rednose_gen_dir = 'models/generated'
@@ -14,13 +12,6 @@ pose_ekf = env.RednoseCompileFilter(
   extra_gen_artifacts=[],
   gen_script_deps=rednose_gen_deps,
 )
-live_ekf = env.RednoseCompileFilter(
-  target='live',
-  filter_gen_script='models/live_kf.py',
-  output_dir=rednose_gen_dir,
-  extra_gen_artifacts=['live_kf_constants.h'],
-  gen_script_deps=rednose_gen_deps,
-)
 car_ekf = env.RednoseCompileFilter(
   target='car',
   filter_gen_script='models/car_kf.py',
@@ -28,17 +19,3 @@ car_ekf = env.RednoseCompileFilter(
   extra_gen_artifacts=[],
   gen_script_deps=rednose_gen_deps,
 )
-
-# locationd build
-locationd_sources = ["locationd.cc", "models/live_kf.cc"]
-
-lenv = env.Clone()
-# ekf filter libraries need to be linked, even if no symbols are used
-if (arch != "Darwin"):
-  lenv["LINKFLAGS"] += ["-Wl,--no-as-needed"]
-
-lenv["LIBPATH"].append(Dir(rednose_gen_dir).abspath)
-lenv["RPATH"].append(Dir(rednose_gen_dir).abspath)
-locationd = lenv.Program("locationd", locationd_sources, LIBS=["live", "ekf_sym"] + loc_libs + transformations)
-lenv.Depends(locationd, rednose)
-lenv.Depends(locationd, live_ekf)
diff --git a/selfdrive/locationd/locationd.cc b/selfdrive/locationd/locationd.cc
deleted file mode 100644
index 5cfcfa50..00000000
--- a/selfdrive/locationd/locationd.cc
+++ /dev/null
@@ -1,774 +0,0 @@
-#include "selfdrive/locationd/locationd.h"
-
-#include
-#include
-
-#include
-#include
-#include
-
-using namespace EKFS;
-using namespace Eigen;
-
-ExitHandler do_exit;
-const double ACCEL_SANITY_CHECK = 100.0;  // m/s^2
-const double ROTATION_SANITY_CHECK = 10.0;  // rad/s
-const double TRANS_SANITY_CHECK = 200.0;  // m/s
-const double CALIB_RPY_SANITY_CHECK = 0.5;  // rad (+- 30 deg)
-const double ALTITUDE_SANITY_CHECK = 10000;  // m
-const double MIN_STD_SANITY_CHECK = 1e-5;  // m or rad
-const double VALID_TIME_SINCE_RESET = 1.0;  // s
-const double VALID_POS_STD = 50.0;  // m
-const double MAX_RESET_TRACKER = 5.0;
-const double SANE_GPS_UNCERTAINTY = 1500.0;  // m
-const double INPUT_INVALID_THRESHOLD = 0.5;  // same as reset tracker
-const double RESET_TRACKER_DECAY = 0.99995;
-const double DECAY = 0.9993;  // ~10 secs to resume after a bad input
-const double MAX_FILTER_REWIND_TIME = 0.8;  // s
-const double YAWRATE_CROSS_ERR_CHECK_FACTOR = 30;
-
-// TODO: GPS sensor time offsets are empirically calculated
-// They should be replaced with synced time from a real clock
-const double GPS_QUECTEL_SENSOR_TIME_OFFSET = 0.630;  // s
-const double GPS_UBLOX_SENSOR_TIME_OFFSET = 0.095;  // s
-const float GPS_POS_STD_THRESHOLD = 50.0;
-const float GPS_VEL_STD_THRESHOLD = 5.0;
-const float GPS_POS_ERROR_RESET_THRESHOLD = 300.0;
-const float GPS_POS_STD_RESET_THRESHOLD = 2.0;
-const float GPS_VEL_STD_RESET_THRESHOLD = 0.5;
-const float GPS_ORIENTATION_ERROR_RESET_THRESHOLD = 1.0;
-const int GPS_ORIENTATION_ERROR_RESET_CNT = 3;
-
-const bool DEBUG = getenv("DEBUG") != nullptr && std::string(getenv("DEBUG")) != "0";
-
-static VectorXd floatlist2vector(const capnp::List::Reader& floatlist) {
-  VectorXd res(floatlist.size());
-  for (int i = 0; i < floatlist.size(); i++) {
-    res[i] = floatlist[i];
-  }
-  return res;
-}
-
-static Vector4d quat2vector(const Quaterniond& quat) {
-  return Vector4d(quat.w(), quat.x(), quat.y(), quat.z());
-}
-
-static Quaterniond vector2quat(const VectorXd& vec) {
-  return Quaterniond(vec(0), vec(1), vec(2), vec(3));
-}
-
-static void init_measurement(cereal::LiveLocationKalman::Measurement::Builder meas, const VectorXd& val, const VectorXd& std, bool valid) {
-  meas.setValue(kj::arrayPtr(val.data(), val.size()));
-  meas.setStd(kj::arrayPtr(std.data(), std.size()));
-  meas.setValid(valid);
-}
-
-
-static MatrixXdr rotate_cov(const MatrixXdr& rot_matrix, const MatrixXdr& cov_in) {
-  // To rotate a covariance matrix, the cov matrix needs to multiplied left and right by the transform matrix
-  return ((rot_matrix * cov_in) * rot_matrix.transpose());
-}
-
-static VectorXd rotate_std(const MatrixXdr& rot_matrix, const VectorXd& std_in) {
-  // Stds cannot be rotated like values, only covariances can be rotated
-  return rotate_cov(rot_matrix, std_in.array().square().matrix().asDiagonal()).diagonal().array().sqrt();
-}
-
-Localizer::Localizer(LocalizerGnssSource gnss_source) {
-  this->kf = std::make_unique();
-
this->reset_kalman(); - - this->calib = Vector3d(0.0, 0.0, 0.0); - this->device_from_calib = MatrixXdr::Identity(3, 3); - this->calib_from_device = MatrixXdr::Identity(3, 3); - - for (int i = 0; i < POSENET_STD_HIST_HALF * 2; i++) { - this->posenet_stds.push_back(10.0); - } - - VectorXd ecef_pos = this->kf->get_x().segment(STATE_ECEF_POS_START); - this->converter = std::make_unique((ECEF) { .x = ecef_pos[0], .y = ecef_pos[1], .z = ecef_pos[2] }); - this->configure_gnss_source(gnss_source); -} - -void Localizer::build_live_location(cereal::LiveLocationKalman::Builder& fix) { - VectorXd predicted_state = this->kf->get_x(); - MatrixXdr predicted_cov = this->kf->get_P(); - VectorXd predicted_std = predicted_cov.diagonal().array().sqrt(); - - VectorXd fix_ecef = predicted_state.segment(STATE_ECEF_POS_START); - ECEF fix_ecef_ecef = { .x = fix_ecef(0), .y = fix_ecef(1), .z = fix_ecef(2) }; - VectorXd fix_ecef_std = predicted_std.segment(STATE_ECEF_POS_ERR_START); - VectorXd vel_ecef = predicted_state.segment(STATE_ECEF_VELOCITY_START); - VectorXd vel_ecef_std = predicted_std.segment(STATE_ECEF_VELOCITY_ERR_START); - VectorXd fix_pos_geo_vec = this->get_position_geodetic(); - VectorXd orientation_ecef = quat2euler(vector2quat(predicted_state.segment(STATE_ECEF_ORIENTATION_START))); - VectorXd orientation_ecef_std = predicted_std.segment(STATE_ECEF_ORIENTATION_ERR_START); - MatrixXdr orientation_ecef_cov = predicted_cov.block(STATE_ECEF_ORIENTATION_ERR_START, STATE_ECEF_ORIENTATION_ERR_START); - MatrixXdr device_from_ecef = euler2rot(orientation_ecef).transpose(); - VectorXd calibrated_orientation_ecef = rot2euler((this->calib_from_device * device_from_ecef).transpose()); - - VectorXd acc_calib = this->calib_from_device * predicted_state.segment(STATE_ACCELERATION_START); - MatrixXdr acc_calib_cov = predicted_cov.block(STATE_ACCELERATION_ERR_START, STATE_ACCELERATION_ERR_START); - VectorXd acc_calib_std = rotate_cov(this->calib_from_device, acc_calib_cov).diagonal().array().sqrt(); - VectorXd ang_vel_calib = this->calib_from_device * predicted_state.segment(STATE_ANGULAR_VELOCITY_START); - - MatrixXdr vel_angular_cov = predicted_cov.block(STATE_ANGULAR_VELOCITY_ERR_START, STATE_ANGULAR_VELOCITY_ERR_START); - VectorXd ang_vel_calib_std = rotate_cov(this->calib_from_device, vel_angular_cov).diagonal().array().sqrt(); - - VectorXd vel_device = device_from_ecef * vel_ecef; - VectorXd device_from_ecef_eul = quat2euler(vector2quat(predicted_state.segment(STATE_ECEF_ORIENTATION_START))).transpose(); - MatrixXdr condensed_cov(STATE_ECEF_ORIENTATION_ERR_LEN + STATE_ECEF_VELOCITY_ERR_LEN, STATE_ECEF_ORIENTATION_ERR_LEN + STATE_ECEF_VELOCITY_ERR_LEN); - condensed_cov.topLeftCorner() = - predicted_cov.block(STATE_ECEF_ORIENTATION_ERR_START, STATE_ECEF_ORIENTATION_ERR_START); - condensed_cov.topRightCorner() = - predicted_cov.block(STATE_ECEF_ORIENTATION_ERR_START, STATE_ECEF_VELOCITY_ERR_START); - condensed_cov.bottomRightCorner() = - predicted_cov.block(STATE_ECEF_VELOCITY_ERR_START, STATE_ECEF_VELOCITY_ERR_START); - condensed_cov.bottomLeftCorner() = - predicted_cov.block(STATE_ECEF_VELOCITY_ERR_START, STATE_ECEF_ORIENTATION_ERR_START); - VectorXd H_input(device_from_ecef_eul.size() + vel_ecef.size()); - H_input << device_from_ecef_eul, vel_ecef; - MatrixXdr HH = this->kf->H(H_input); - MatrixXdr vel_device_cov = (HH * condensed_cov) * HH.transpose(); - VectorXd vel_device_std = vel_device_cov.diagonal().array().sqrt(); - - VectorXd vel_calib = this->calib_from_device * vel_device; - VectorXd vel_calib_std = 
rotate_cov(this->calib_from_device, vel_device_cov).diagonal().array().sqrt(); - - VectorXd orientation_ned = ned_euler_from_ecef(fix_ecef_ecef, orientation_ecef); - VectorXd orientation_ned_std = rotate_cov(this->converter->ecef2ned_matrix, orientation_ecef_cov).diagonal().array().sqrt(); - VectorXd calibrated_orientation_ned = ned_euler_from_ecef(fix_ecef_ecef, calibrated_orientation_ecef); - VectorXd nextfix_ecef = fix_ecef + vel_ecef; - VectorXd ned_vel = this->converter->ecef2ned((ECEF) { .x = nextfix_ecef(0), .y = nextfix_ecef(1), .z = nextfix_ecef(2) }).to_vector() - converter->ecef2ned(fix_ecef_ecef).to_vector(); - - VectorXd accDevice = predicted_state.segment(STATE_ACCELERATION_START); - VectorXd accDeviceErr = predicted_std.segment(STATE_ACCELERATION_ERR_START); - - VectorXd angVelocityDevice = predicted_state.segment(STATE_ANGULAR_VELOCITY_START); - VectorXd angVelocityDeviceErr = predicted_std.segment(STATE_ANGULAR_VELOCITY_ERR_START); - - Vector3d nans = Vector3d(NAN, NAN, NAN); - - // TODO fill in NED and Calibrated stds - // write measurements to msg - init_measurement(fix.initPositionGeodetic(), fix_pos_geo_vec, nans, this->gps_mode); - init_measurement(fix.initPositionECEF(), fix_ecef, fix_ecef_std, this->gps_mode); - init_measurement(fix.initVelocityECEF(), vel_ecef, vel_ecef_std, this->gps_mode); - init_measurement(fix.initVelocityNED(), ned_vel, nans, this->gps_mode); - init_measurement(fix.initVelocityDevice(), vel_device, vel_device_std, true); - init_measurement(fix.initAccelerationDevice(), accDevice, accDeviceErr, true); - init_measurement(fix.initOrientationECEF(), orientation_ecef, orientation_ecef_std, this->gps_mode); - init_measurement(fix.initCalibratedOrientationECEF(), calibrated_orientation_ecef, nans, this->calibrated && this->gps_mode); - init_measurement(fix.initOrientationNED(), orientation_ned, orientation_ned_std, this->gps_mode); - init_measurement(fix.initCalibratedOrientationNED(), calibrated_orientation_ned, nans, this->calibrated && this->gps_mode); - init_measurement(fix.initAngularVelocityDevice(), angVelocityDevice, angVelocityDeviceErr, true); - init_measurement(fix.initVelocityCalibrated(), vel_calib, vel_calib_std, this->calibrated); - init_measurement(fix.initAngularVelocityCalibrated(), ang_vel_calib, ang_vel_calib_std, this->calibrated); - init_measurement(fix.initAccelerationCalibrated(), acc_calib, acc_calib_std, this->calibrated); - if (DEBUG) { - init_measurement(fix.initFilterState(), predicted_state, predicted_std, true); - } - - double old_mean = 0.0, new_mean = 0.0; - int i = 0; - for (double x : this->posenet_stds) { - if (i < POSENET_STD_HIST_HALF) { - old_mean += x; - } else { - new_mean += x; - } - i++; - } - old_mean /= POSENET_STD_HIST_HALF; - new_mean /= POSENET_STD_HIST_HALF; - // experimentally found these values, no false positives in 20k minutes of driving - bool std_spike = (new_mean / old_mean > 4.0 && new_mean > 7.0); - - fix.setPosenetOK(!(std_spike && this->car_speed > 5.0)); - fix.setDeviceStable(!this->device_fell); - fix.setExcessiveResets(this->reset_tracker > MAX_RESET_TRACKER); - fix.setTimeToFirstFix(std::isnan(this->ttff) ? -1. 
: this->ttff); - this->device_fell = false; - - //fix.setGpsWeek(this->time.week); - //fix.setGpsTimeOfWeek(this->time.tow); - fix.setUnixTimestampMillis(this->unix_timestamp_millis); - - double time_since_reset = this->kf->get_filter_time() - this->last_reset_time; - fix.setTimeSinceReset(time_since_reset); - if (fix_ecef_std.norm() < VALID_POS_STD && this->calibrated && time_since_reset > VALID_TIME_SINCE_RESET) { - fix.setStatus(cereal::LiveLocationKalman::Status::VALID); - } else if (fix_ecef_std.norm() < VALID_POS_STD && time_since_reset > VALID_TIME_SINCE_RESET) { - fix.setStatus(cereal::LiveLocationKalman::Status::UNCALIBRATED); - } else { - fix.setStatus(cereal::LiveLocationKalman::Status::UNINITIALIZED); - } -} - -VectorXd Localizer::get_position_geodetic() { - VectorXd fix_ecef = this->kf->get_x().segment(STATE_ECEF_POS_START); - ECEF fix_ecef_ecef = { .x = fix_ecef(0), .y = fix_ecef(1), .z = fix_ecef(2) }; - Geodetic fix_pos_geo = ecef2geodetic(fix_ecef_ecef); - return Vector3d(fix_pos_geo.lat, fix_pos_geo.lon, fix_pos_geo.alt); -} - -VectorXd Localizer::get_state() { - return this->kf->get_x(); -} - -VectorXd Localizer::get_stdev() { - return this->kf->get_P().diagonal().array().sqrt(); -} - -bool Localizer::are_inputs_ok() { - return this->critical_services_valid(this->observation_values_invalid) && !this->observation_timings_invalid; -} - -void Localizer::observation_timings_invalid_reset(){ - this->observation_timings_invalid = false; -} - -void Localizer::handle_sensor(double current_time, const cereal::SensorEventData::Reader& log) { - // TODO does not yet account for double sensor readings in the log - - // Ignore empty readings (e.g. in case the magnetometer had no data ready) - if (log.getTimestamp() == 0) { - return; - } - - double sensor_time = 1e-9 * log.getTimestamp(); - - // sensor time and log time should be close - if (std::abs(current_time - sensor_time) > 0.1) { - LOGE("Sensor reading ignored, sensor timestamp more than 100ms off from log time"); - this->observation_timings_invalid = true; - return; - } else if (!this->is_timestamp_valid(sensor_time)) { - this->observation_timings_invalid = true; - return; - } - - // TODO: handle messages from two IMUs at the same time - if (log.getSource() == cereal::SensorEventData::SensorSource::BMX055) { - return; - } - - // Gyro Uncalibrated - if (log.getSensor() == SENSOR_GYRO_UNCALIBRATED && log.getType() == SENSOR_TYPE_GYROSCOPE_UNCALIBRATED) { - auto v = log.getGyroUncalibrated().getV(); - auto meas = Vector3d(-v[2], -v[1], -v[0]); - - VectorXd gyro_bias = this->kf->get_x().segment(STATE_GYRO_BIAS_START); - float gyro_camodo_yawrate_err = std::abs((meas[2] - gyro_bias[2]) - this->camodo_yawrate_distribution[0]); - float gyro_camodo_yawrate_err_threshold = YAWRATE_CROSS_ERR_CHECK_FACTOR * this->camodo_yawrate_distribution[1]; - bool gyro_valid = gyro_camodo_yawrate_err < gyro_camodo_yawrate_err_threshold; - - if ((meas.norm() < ROTATION_SANITY_CHECK) && gyro_valid) { - this->kf->predict_and_observe(sensor_time, OBSERVATION_PHONE_GYRO, { meas }); - this->observation_values_invalid["gyroscope"] *= DECAY; - } else { - this->observation_values_invalid["gyroscope"] += 1.0; - } - } - - // Accelerometer - if (log.getSensor() == SENSOR_ACCELEROMETER && log.getType() == SENSOR_TYPE_ACCELEROMETER) { - auto v = log.getAcceleration().getV(); - - // TODO: reduce false positives and re-enable this check - // check if device fell, estimate 10 for g - // 40m/s**2 is a good filter for falling detection, no false positives in 20k minutes 
of driving - // this->device_fell |= (floatlist2vector(v) - Vector3d(10.0, 0.0, 0.0)).norm() > 40.0; - - auto meas = Vector3d(-v[2], -v[1], -v[0]); - if (meas.norm() < ACCEL_SANITY_CHECK) { - this->kf->predict_and_observe(sensor_time, OBSERVATION_PHONE_ACCEL, { meas }); - this->observation_values_invalid["accelerometer"] *= DECAY; - } else { - this->observation_values_invalid["accelerometer"] += 1.0; - } - } -} - -void Localizer::input_fake_gps_observations(double current_time) { - // This is done to make sure that the error estimate of the position does not blow up - // when the filter is in no-gps mode - // Steps : first predict -> observe current obs with reasonable STD - this->kf->predict(current_time); - - VectorXd current_x = this->kf->get_x(); - VectorXd ecef_pos = current_x.segment(STATE_ECEF_POS_START); - VectorXd ecef_vel = current_x.segment(STATE_ECEF_VELOCITY_START); - const MatrixXdr &ecef_pos_R = this->kf->get_fake_gps_pos_cov(); - const MatrixXdr &ecef_vel_R = this->kf->get_fake_gps_vel_cov(); - - this->kf->predict_and_observe(current_time, OBSERVATION_ECEF_POS, { ecef_pos }, { ecef_pos_R }); - this->kf->predict_and_observe(current_time, OBSERVATION_ECEF_VEL, { ecef_vel }, { ecef_vel_R }); -} - -void Localizer::handle_gps(double current_time, const cereal::GpsLocationData::Reader& log, const double sensor_time_offset) { - bool gps_unreasonable = (Vector2d(log.getHorizontalAccuracy(), log.getVerticalAccuracy()).norm() >= SANE_GPS_UNCERTAINTY); - bool gps_accuracy_insane = ((log.getVerticalAccuracy() <= 0) || (log.getSpeedAccuracy() <= 0) || (log.getBearingAccuracyDeg() <= 0)); - bool gps_lat_lng_alt_insane = ((std::abs(log.getLatitude()) > 90) || (std::abs(log.getLongitude()) > 180) || (std::abs(log.getAltitude()) > ALTITUDE_SANITY_CHECK)); - bool gps_vel_insane = (floatlist2vector(log.getVNED()).norm() > TRANS_SANITY_CHECK); - - if (!log.getHasFix() || gps_unreasonable || gps_accuracy_insane || gps_lat_lng_alt_insane || gps_vel_insane) { - //this->gps_valid = false; - this->determine_gps_mode(current_time); - return; - } - - double sensor_time = current_time - sensor_time_offset; - - // Process message - //this->gps_valid = true; - this->gps_mode = true; - Geodetic geodetic = { log.getLatitude(), log.getLongitude(), log.getAltitude() }; - this->converter = std::make_unique(geodetic); - - VectorXd ecef_pos = this->converter->ned2ecef({ 0.0, 0.0, 0.0 }).to_vector(); - VectorXd ecef_vel = this->converter->ned2ecef({ log.getVNED()[0], log.getVNED()[1], log.getVNED()[2] }).to_vector() - ecef_pos; - float ecef_pos_std = std::sqrt(this->gps_variance_factor * std::pow(log.getHorizontalAccuracy(), 2) + this->gps_vertical_variance_factor * std::pow(log.getVerticalAccuracy(), 2)); - MatrixXdr ecef_pos_R = Vector3d::Constant(std::pow(this->gps_std_factor * ecef_pos_std, 2)).asDiagonal(); - MatrixXdr ecef_vel_R = Vector3d::Constant(std::pow(this->gps_std_factor * log.getSpeedAccuracy(), 2)).asDiagonal(); - - this->unix_timestamp_millis = log.getUnixTimestampMillis(); - double gps_est_error = (this->kf->get_x().segment(STATE_ECEF_POS_START) - ecef_pos).norm(); - - VectorXd orientation_ecef = quat2euler(vector2quat(this->kf->get_x().segment(STATE_ECEF_ORIENTATION_START))); - VectorXd orientation_ned = ned_euler_from_ecef({ ecef_pos(0), ecef_pos(1), ecef_pos(2) }, orientation_ecef); - VectorXd orientation_ned_gps = Vector3d(0.0, 0.0, DEG2RAD(log.getBearingDeg())); - VectorXd orientation_error = (orientation_ned - orientation_ned_gps).array() - M_PI; - for (int i = 0; i < 
orientation_error.size(); i++) { - orientation_error(i) = std::fmod(orientation_error(i), 2.0 * M_PI); - if (orientation_error(i) < 0.0) { - orientation_error(i) += 2.0 * M_PI; - } - orientation_error(i) -= M_PI; - } - VectorXd initial_pose_ecef_quat = quat2vector(euler2quat(ecef_euler_from_ned({ ecef_pos(0), ecef_pos(1), ecef_pos(2) }, orientation_ned_gps))); - - if (ecef_vel.norm() > 5.0 && orientation_error.norm() > 1.0) { - LOGE("Locationd vs ubloxLocation orientation difference too large, kalman reset"); - this->reset_kalman(NAN, initial_pose_ecef_quat, ecef_pos, ecef_vel, ecef_pos_R, ecef_vel_R); - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_ORIENTATION_FROM_GPS, { initial_pose_ecef_quat }); - } else if (gps_est_error > 100.0) { - LOGE("Locationd vs ubloxLocation position difference too large, kalman reset"); - this->reset_kalman(NAN, initial_pose_ecef_quat, ecef_pos, ecef_vel, ecef_pos_R, ecef_vel_R); - } - - this->last_gps_msg = sensor_time; - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_POS, { ecef_pos }, { ecef_pos_R }); - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_VEL, { ecef_vel }, { ecef_vel_R }); -} - -void Localizer::handle_gnss(double current_time, const cereal::GnssMeasurements::Reader& log) { - - if (!log.getPositionECEF().getValid() || !log.getVelocityECEF().getValid()) { - this->determine_gps_mode(current_time); - return; - } - - double sensor_time = log.getMeasTime() * 1e-9; - sensor_time -= this->gps_time_offset; - - auto ecef_pos_v = log.getPositionECEF().getValue(); - VectorXd ecef_pos = Vector3d(ecef_pos_v[0], ecef_pos_v[1], ecef_pos_v[2]); - - // indexed at 0 cause all std values are the same MAE - auto ecef_pos_std = log.getPositionECEF().getStd()[0]; - MatrixXdr ecef_pos_R = Vector3d::Constant(pow(this->gps_std_factor*ecef_pos_std, 2)).asDiagonal(); - - auto ecef_vel_v = log.getVelocityECEF().getValue(); - VectorXd ecef_vel = Vector3d(ecef_vel_v[0], ecef_vel_v[1], ecef_vel_v[2]); - - // indexed at 0 cause all std values are the same MAE - auto ecef_vel_std = log.getVelocityECEF().getStd()[0]; - MatrixXdr ecef_vel_R = Vector3d::Constant(pow(this->gps_std_factor*ecef_vel_std, 2)).asDiagonal(); - - double gps_est_error = (this->kf->get_x().segment(STATE_ECEF_POS_START) - ecef_pos).norm(); - - VectorXd orientation_ecef = quat2euler(vector2quat(this->kf->get_x().segment(STATE_ECEF_ORIENTATION_START))); - VectorXd orientation_ned = ned_euler_from_ecef({ ecef_pos[0], ecef_pos[1], ecef_pos[2] }, orientation_ecef); - - LocalCoord convs((ECEF){ .x = ecef_pos[0], .y = ecef_pos[1], .z = ecef_pos[2] }); - ECEF next_ecef = {.x = ecef_pos[0] + ecef_vel[0], .y = ecef_pos[1] + ecef_vel[1], .z = ecef_pos[2] + ecef_vel[2]}; - VectorXd ned_vel = convs.ecef2ned(next_ecef).to_vector(); - double bearing_rad = atan2(ned_vel[1], ned_vel[0]); - - VectorXd orientation_ned_gps = Vector3d(0.0, 0.0, bearing_rad); - VectorXd orientation_error = (orientation_ned - orientation_ned_gps).array() - M_PI; - for (int i = 0; i < orientation_error.size(); i++) { - orientation_error(i) = std::fmod(orientation_error(i), 2.0 * M_PI); - if (orientation_error(i) < 0.0) { - orientation_error(i) += 2.0 * M_PI; - } - orientation_error(i) -= M_PI; - } - VectorXd initial_pose_ecef_quat = quat2vector(euler2quat(ecef_euler_from_ned({ ecef_pos(0), ecef_pos(1), ecef_pos(2) }, orientation_ned_gps))); - - if (ecef_pos_std > GPS_POS_STD_THRESHOLD || ecef_vel_std > GPS_VEL_STD_THRESHOLD) { - this->determine_gps_mode(current_time); - return; - } - - // prevent jumping gnss 
measurements (covered lots, standstill...) - bool orientation_reset = ecef_vel_std < GPS_VEL_STD_RESET_THRESHOLD; - orientation_reset &= orientation_error.norm() > GPS_ORIENTATION_ERROR_RESET_THRESHOLD; - orientation_reset &= !this->standstill; - if (orientation_reset) { - this->orientation_reset_count++; - } else { - this->orientation_reset_count = 0; - } - - if ((gps_est_error > GPS_POS_ERROR_RESET_THRESHOLD && ecef_pos_std < GPS_POS_STD_RESET_THRESHOLD) || this->last_gps_msg == 0) { - // always reset on first gps message and if the location is off but the accuracy is high - LOGE("Locationd vs gnssMeasurement position difference too large, kalman reset"); - this->reset_kalman(NAN, initial_pose_ecef_quat, ecef_pos, ecef_vel, ecef_pos_R, ecef_vel_R); - } else if (orientation_reset_count > GPS_ORIENTATION_ERROR_RESET_CNT) { - LOGE("Locationd vs gnssMeasurement orientation difference too large, kalman reset"); - this->reset_kalman(NAN, initial_pose_ecef_quat, ecef_pos, ecef_vel, ecef_pos_R, ecef_vel_R); - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_ORIENTATION_FROM_GPS, { initial_pose_ecef_quat }); - this->orientation_reset_count = 0; - } - - this->gps_mode = true; - this->last_gps_msg = sensor_time; - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_POS, { ecef_pos }, { ecef_pos_R }); - this->kf->predict_and_observe(sensor_time, OBSERVATION_ECEF_VEL, { ecef_vel }, { ecef_vel_R }); -} - -void Localizer::handle_car_state(double current_time, const cereal::CarState::Reader& log) { - this->car_speed = std::abs(log.getVEgo()); - this->standstill = log.getStandstill(); - if (this->standstill) { - this->kf->predict_and_observe(current_time, OBSERVATION_NO_ROT, { Vector3d(0.0, 0.0, 0.0) }); - this->kf->predict_and_observe(current_time, OBSERVATION_NO_ACCEL, { Vector3d(0.0, 0.0, 0.0) }); - } -} - -void Localizer::handle_cam_odo(double current_time, const cereal::CameraOdometry::Reader& log) { - VectorXd rot_device = this->device_from_calib * floatlist2vector(log.getRot()); - VectorXd trans_device = this->device_from_calib * floatlist2vector(log.getTrans()); - - if (!this->is_timestamp_valid(current_time)) { - this->observation_timings_invalid = true; - return; - } - - if ((rot_device.norm() > ROTATION_SANITY_CHECK) || (trans_device.norm() > TRANS_SANITY_CHECK)) { - this->observation_values_invalid["cameraOdometry"] += 1.0; - return; - } - - VectorXd rot_calib_std = floatlist2vector(log.getRotStd()); - VectorXd trans_calib_std = floatlist2vector(log.getTransStd()); - - if ((rot_calib_std.minCoeff() <= MIN_STD_SANITY_CHECK) || (trans_calib_std.minCoeff() <= MIN_STD_SANITY_CHECK)) { - this->observation_values_invalid["cameraOdometry"] += 1.0; - return; - } - - if ((rot_calib_std.norm() > 10 * ROTATION_SANITY_CHECK) || (trans_calib_std.norm() > 10 * TRANS_SANITY_CHECK)) { - this->observation_values_invalid["cameraOdometry"] += 1.0; - return; - } - - this->posenet_stds.pop_front(); - this->posenet_stds.push_back(trans_calib_std[0]); - - // Multiply by 10 to avoid to high certainty in kalman filter because of temporally correlated noise - trans_calib_std *= 10.0; - rot_calib_std *= 10.0; - MatrixXdr rot_device_cov = rotate_std(this->device_from_calib, rot_calib_std).array().square().matrix().asDiagonal(); - MatrixXdr trans_device_cov = rotate_std(this->device_from_calib, trans_calib_std).array().square().matrix().asDiagonal(); - this->kf->predict_and_observe(current_time, OBSERVATION_CAMERA_ODO_ROTATION, - { rot_device }, { rot_device_cov }); - 
this->kf->predict_and_observe(current_time, OBSERVATION_CAMERA_ODO_TRANSLATION, - { trans_device }, { trans_device_cov }); - this->observation_values_invalid["cameraOdometry"] *= DECAY; - this->camodo_yawrate_distribution = Vector2d(rot_device[2], rotate_std(this->device_from_calib, rot_calib_std)[2]); -} - -void Localizer::handle_live_calib(double current_time, const cereal::LiveCalibrationData::Reader& log) { - if (!this->is_timestamp_valid(current_time)) { - this->observation_timings_invalid = true; - return; - } - - if (log.getRpyCalib().size() > 0) { - auto live_calib = floatlist2vector(log.getRpyCalib()); - if ((live_calib.minCoeff() < -CALIB_RPY_SANITY_CHECK) || (live_calib.maxCoeff() > CALIB_RPY_SANITY_CHECK)) { - this->observation_values_invalid["liveCalibration"] += 1.0; - return; - } - - this->calib = live_calib; - this->device_from_calib = euler2rot(this->calib); - this->calib_from_device = this->device_from_calib.transpose(); - this->calibrated = log.getCalStatus() == cereal::LiveCalibrationData::Status::CALIBRATED; - this->observation_values_invalid["liveCalibration"] *= DECAY; - } -} - -void Localizer::reset_kalman(double current_time) { - const VectorXd &init_x = this->kf->get_initial_x(); - const MatrixXdr &init_P = this->kf->get_initial_P(); - this->reset_kalman(current_time, init_x, init_P); -} - -void Localizer::finite_check(double current_time) { - bool all_finite = this->kf->get_x().array().isFinite().all() or this->kf->get_P().array().isFinite().all(); - if (!all_finite) { - LOGE("Non-finite values detected, kalman reset"); - this->reset_kalman(current_time); - } -} - -void Localizer::time_check(double current_time) { - if (std::isnan(this->last_reset_time)) { - this->last_reset_time = current_time; - } - if (std::isnan(this->first_valid_log_time)) { - this->first_valid_log_time = current_time; - } - double filter_time = this->kf->get_filter_time(); - bool big_time_gap = !std::isnan(filter_time) && (current_time - filter_time > 10); - if (big_time_gap) { - LOGE("Time gap of over 10s detected, kalman reset"); - this->reset_kalman(current_time); - } -} - -void Localizer::update_reset_tracker() { - // reset tracker is tuned to trigger when over 1reset/10s over 2min period - if (this->is_gps_ok()) { - this->reset_tracker *= RESET_TRACKER_DECAY; - } else { - this->reset_tracker = 0.0; - } -} - -void Localizer::reset_kalman(double current_time, const VectorXd &init_orient, const VectorXd &init_pos, const VectorXd &init_vel, const MatrixXdr &init_pos_R, const MatrixXdr &init_vel_R) { - // too nonlinear to init on completely wrong - VectorXd current_x = this->kf->get_x(); - MatrixXdr current_P = this->kf->get_P(); - MatrixXdr init_P = this->kf->get_initial_P(); - const MatrixXdr &reset_orientation_P = this->kf->get_reset_orientation_P(); - int non_ecef_state_err_len = init_P.rows() - (STATE_ECEF_POS_ERR_LEN + STATE_ECEF_ORIENTATION_ERR_LEN + STATE_ECEF_VELOCITY_ERR_LEN); - - current_x.segment(STATE_ECEF_ORIENTATION_START) = init_orient; - current_x.segment(STATE_ECEF_VELOCITY_START) = init_vel; - current_x.segment(STATE_ECEF_POS_START) = init_pos; - - init_P.block(STATE_ECEF_POS_ERR_START, STATE_ECEF_POS_ERR_START).diagonal() = init_pos_R.diagonal(); - init_P.block(STATE_ECEF_ORIENTATION_ERR_START, STATE_ECEF_ORIENTATION_ERR_START).diagonal() = reset_orientation_P.diagonal(); - init_P.block(STATE_ECEF_VELOCITY_ERR_START, STATE_ECEF_VELOCITY_ERR_START).diagonal() = init_vel_R.diagonal(); - init_P.block(STATE_ANGULAR_VELOCITY_ERR_START, STATE_ANGULAR_VELOCITY_ERR_START, 
non_ecef_state_err_len, non_ecef_state_err_len).diagonal() = current_P.block(STATE_ANGULAR_VELOCITY_ERR_START, - STATE_ANGULAR_VELOCITY_ERR_START, non_ecef_state_err_len, non_ecef_state_err_len).diagonal(); - - this->reset_kalman(current_time, current_x, init_P); -} - -void Localizer::reset_kalman(double current_time, const VectorXd &init_x, const MatrixXdr &init_P) { - this->kf->init_state(init_x, init_P, current_time); - this->last_reset_time = current_time; - this->reset_tracker += 1.0; -} - -void Localizer::handle_msg_bytes(const char *data, const size_t size) { - AlignedBuffer aligned_buf; - - capnp::FlatArrayMessageReader cmsg(aligned_buf.align(data, size)); - cereal::Event::Reader event = cmsg.getRoot(); - - this->handle_msg(event); -} - -void Localizer::handle_msg(const cereal::Event::Reader& log) { - double t = log.getLogMonoTime() * 1e-9; - this->time_check(t); - if (log.isAccelerometer()) { - this->handle_sensor(t, log.getAccelerometer()); - } else if (log.isGyroscope()) { - this->handle_sensor(t, log.getGyroscope()); - } else if (log.isGpsLocation()) { - this->handle_gps(t, log.getGpsLocation(), GPS_QUECTEL_SENSOR_TIME_OFFSET); - } else if (log.isGpsLocationExternal()) { - this->handle_gps(t, log.getGpsLocationExternal(), GPS_UBLOX_SENSOR_TIME_OFFSET); - //} else if (log.isGnssMeasurements()) { - // this->handle_gnss(t, log.getGnssMeasurements()); - } else if (log.isCarState()) { - this->handle_car_state(t, log.getCarState()); - } else if (log.isCameraOdometry()) { - this->handle_cam_odo(t, log.getCameraOdometry()); - } else if (log.isLiveCalibration()) { - this->handle_live_calib(t, log.getLiveCalibration()); - } - this->finite_check(); - this->update_reset_tracker(); -} - -kj::ArrayPtr Localizer::get_message_bytes(MessageBuilder& msg_builder, bool inputsOK, - bool sensorsOK, bool gpsOK, bool msgValid) { - cereal::Event::Builder evt = msg_builder.initEvent(); - evt.setValid(msgValid); - cereal::LiveLocationKalman::Builder liveLoc = evt.initLiveLocationKalman(); - this->build_live_location(liveLoc); - liveLoc.setSensorsOK(sensorsOK); - liveLoc.setGpsOK(gpsOK); - liveLoc.setInputsOK(inputsOK); - return msg_builder.toBytes(); -} - -bool Localizer::is_gps_ok() { - return (this->kf->get_filter_time() - this->last_gps_msg) < 2.0; -} - -bool Localizer::critical_services_valid(const std::map &critical_services) { - for (auto &kv : critical_services){ - if (kv.second >= INPUT_INVALID_THRESHOLD){ - return false; - } - } - return true; -} - -bool Localizer::is_timestamp_valid(double current_time) { - double filter_time = this->kf->get_filter_time(); - if (!std::isnan(filter_time) && ((filter_time - current_time) > MAX_FILTER_REWIND_TIME)) { - LOGE("Observation timestamp is older than the max rewind threshold of the filter"); - return false; - } - return true; -} - -void Localizer::determine_gps_mode(double current_time) { - // 1. If the pos_std is greater than what's not acceptable and localizer is in gps-mode, reset to no-gps-mode - // 2. If the pos_std is greater than what's not acceptable and localizer is in no-gps-mode, fake obs - // 3. 
If the pos_std is smaller than what's not acceptable, let gps-mode be whatever it is - VectorXd current_pos_std = this->kf->get_P().block(STATE_ECEF_POS_ERR_START, STATE_ECEF_POS_ERR_START).diagonal().array().sqrt(); - if (current_pos_std.norm() > SANE_GPS_UNCERTAINTY){ - if (this->gps_mode){ - this->gps_mode = false; - this->reset_kalman(current_time); - } else { - this->input_fake_gps_observations(current_time); - } - } -} - -void Localizer::configure_gnss_source(const LocalizerGnssSource &source) { - this->gnss_source = source; - if (source == LocalizerGnssSource::UBLOX) { - this->gps_std_factor = 10.0; - this->gps_variance_factor = 1.0; - this->gps_vertical_variance_factor = 1.0; - this->gps_time_offset = GPS_UBLOX_SENSOR_TIME_OFFSET; - } else { - this->gps_std_factor = 2.0; - this->gps_variance_factor = 0.0; - this->gps_vertical_variance_factor = 3.0; - this->gps_time_offset = GPS_QUECTEL_SENSOR_TIME_OFFSET; - } -} - -int Localizer::locationd_thread() { - Params params; - LocalizerGnssSource source; - const char* gps_location_socket; - if (params.getBool("UbloxAvailable")) { - source = LocalizerGnssSource::UBLOX; - gps_location_socket = "gpsLocationExternal"; - } else { - source = LocalizerGnssSource::QCOM; - gps_location_socket = "gpsLocation"; - } - - this->configure_gnss_source(source); - const std::initializer_list service_list = {gps_location_socket, "cameraOdometry", "liveCalibration", - "carState", "accelerometer", "gyroscope"}; - - SubMaster sm(service_list, {}, nullptr, {gps_location_socket}); - PubMaster pm({"liveLocationKalman"}); - - uint64_t cnt = 0; - bool filterInitialized = false; - const std::vector critical_input_services = {"cameraOdometry", "liveCalibration", "accelerometer", "gyroscope"}; - for (std::string service : critical_input_services) { - this->observation_values_invalid.insert({service, 0.0}); - } - - bool ignore_gps = true; - while (!do_exit) { - sm.update(); - if (filterInitialized){ - this->observation_timings_invalid_reset(); - for (const char* service : service_list) { - if (sm.updated(service) && sm.valid(service)){ - const cereal::Event::Reader log = sm[service]; - this->handle_msg(log); - } - } - } else { - //filterInitialized = sm.allAliveAndValid(); - bool allValid = true; - for (const char* service : service_list) { - if (service != gps_location_socket && !sm.valid(service)) { - allValid = false; - break; - } - } - filterInitialized = allValid; - } - - const char* trigger_msg = "cameraOdometry"; - if (sm.updated(trigger_msg)) { - bool inputsOK = sm.allValid() && this->are_inputs_ok(); - if (ignore_gps) { - inputsOK = this->are_inputs_ok(); - } - bool gpsOK = this->is_gps_ok(); - bool sensorsOK = sm.allAliveAndValid({"accelerometer", "gyroscope"}); - - /* - if (!sm.allValid()) { - for (const char* service : service_list) { - if (!sm.valid(service)) { - printf("Service %s is INVALID! 
(Alive: %d)\n", service, sm.alive(service)); - } - } - } - printf("InputsOK: %d, SensorsOK: %d, GPSOK: %d, FilterInitialized: %d\n", inputsOK, sensorsOK, gpsOK, filterInitialized); - */ - - // Log time to first fix - if (gpsOK && std::isnan(this->ttff) && !std::isnan(this->first_valid_log_time)) { - this->ttff = std::max(1e-3, (sm[trigger_msg].getLogMonoTime() * 1e-9) - this->first_valid_log_time); - } - - MessageBuilder msg_builder; - kj::ArrayPtr bytes = this->get_message_bytes(msg_builder, inputsOK, sensorsOK, gpsOK, filterInitialized); - pm.send("liveLocationKalman", bytes.begin(), bytes.size()); - - if (cnt % 1200 == 0 && gpsOK) { // once a minute - //ignore_gps = false; - VectorXd posGeo = this->get_position_geodetic(); - std::string lastGPSPosJSON = util::string_format( - "{\"latitude\": %.15f, \"longitude\": %.15f, \"altitude\": %.15f}", posGeo(0), posGeo(1), posGeo(2)); - params.putNonBlocking("LastGPSPosition", lastGPSPosJSON); - } - cnt++; - } - } - return 0; -} - -int main() { - util::set_realtime_priority(5); - - Localizer localizer; - return localizer.locationd_thread(); -} diff --git a/selfdrive/locationd/locationd.h b/selfdrive/locationd/locationd.h deleted file mode 100644 index 47c8bf56..00000000 --- a/selfdrive/locationd/locationd.h +++ /dev/null @@ -1,100 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "cereal/messaging/messaging.h" -#include "common/transformations/coordinates.hpp" -#include "common/transformations/orientation.hpp" -#include "common/params.h" -#include "common/swaglog.h" -#include "common/timing.h" -#include "common/util.h" - -#include "system/sensord/sensors/constants.h" -#define VISION_DECIMATION 2 -#define SENSOR_DECIMATION 10 -#include "selfdrive/locationd/models/live_kf.h" - -#define POSENET_STD_HIST_HALF 20 - -enum LocalizerGnssSource { - UBLOX, QCOM -}; - -class Localizer { -public: - Localizer(LocalizerGnssSource gnss_source = LocalizerGnssSource::UBLOX); - - int locationd_thread(); - - void reset_kalman(double current_time = NAN); - void reset_kalman(double current_time, const Eigen::VectorXd &init_orient, const Eigen::VectorXd &init_pos, const Eigen::VectorXd &init_vel, const MatrixXdr &init_pos_R, const MatrixXdr &init_vel_R); - void reset_kalman(double current_time, const Eigen::VectorXd &init_x, const MatrixXdr &init_P); - void finite_check(double current_time = NAN); - void time_check(double current_time = NAN); - void update_reset_tracker(); - bool is_gps_ok(); - bool critical_services_valid(const std::map &critical_services); - bool is_timestamp_valid(double current_time); - void determine_gps_mode(double current_time); - bool are_inputs_ok(); - void observation_timings_invalid_reset(); - - kj::ArrayPtr get_message_bytes(MessageBuilder& msg_builder, - bool inputsOK, bool sensorsOK, bool gpsOK, bool msgValid); - void build_live_location(cereal::LiveLocationKalman::Builder& fix); - - Eigen::VectorXd get_position_geodetic(); - Eigen::VectorXd get_state(); - Eigen::VectorXd get_stdev(); - - void handle_msg_bytes(const char *data, const size_t size); - void handle_msg(const cereal::Event::Reader& log); - void handle_sensor(double current_time, const cereal::SensorEventData::Reader& log); - void handle_gps(double current_time, const cereal::GpsLocationData::Reader& log, const double sensor_time_offset); - void handle_gnss(double current_time, const cereal::GnssMeasurements::Reader& log); - void handle_car_state(double current_time, const cereal::CarState::Reader& log); - void 
handle_cam_odo(double current_time, const cereal::CameraOdometry::Reader& log); - void handle_live_calib(double current_time, const cereal::LiveCalibrationData::Reader& log); - - void input_fake_gps_observations(double current_time); - -private: - std::unique_ptr kf; - - Eigen::VectorXd calib; - MatrixXdr device_from_calib; - MatrixXdr calib_from_device; - bool calibrated = false; - - double car_speed = 0.0; - double last_reset_time = NAN; - std::deque posenet_stds; - - std::unique_ptr converter; - - int64_t unix_timestamp_millis = 0; - double reset_tracker = 0.0; - bool device_fell = false; - bool gps_mode = false; - double first_valid_log_time = NAN; - double ttff = NAN; - double last_gps_msg = 0; - LocalizerGnssSource gnss_source; - bool observation_timings_invalid = false; - std::map observation_values_invalid; - bool standstill = true; - int32_t orientation_reset_count = 0; - float gps_std_factor; - float gps_variance_factor; - float gps_vertical_variance_factor; - double gps_time_offset; - Eigen::VectorXd camodo_yawrate_distribution = Eigen::Vector2d(0.0, 10.0); // mean, std - - void configure_gnss_source(const LocalizerGnssSource &source); -}; diff --git a/selfdrive/locationd/models/car_kf.py b/selfdrive/locationd/models/car_kf.py index 6db749b9..27cc4ef9 100755 --- a/selfdrive/locationd/models/car_kf.py +++ b/selfdrive/locationd/models/car_kf.py @@ -5,7 +5,7 @@ from typing import Any import numpy as np -from opendbc.car.vehicle_model import ACCELERATION_DUE_TO_GRAVITY +from openpilot.common.constants import ACCELERATION_DUE_TO_GRAVITY from openpilot.selfdrive.locationd.models.constants import ObservationKind from openpilot.common.swaglog import cloudlog diff --git a/selfdrive/locationd/models/live_kf.cc b/selfdrive/locationd/models/live_kf.cc deleted file mode 100644 index fc3bfb72..00000000 --- a/selfdrive/locationd/models/live_kf.cc +++ /dev/null @@ -1,122 +0,0 @@ -#include "selfdrive/locationd/models/live_kf.h" - -using namespace EKFS; -using namespace Eigen; - -Eigen::Map get_mapvec(const Eigen::VectorXd &vec) { - return Eigen::Map((double*)vec.data(), vec.rows(), vec.cols()); -} - -Eigen::Map get_mapmat(const MatrixXdr &mat) { - return Eigen::Map((double*)mat.data(), mat.rows(), mat.cols()); -} - -std::vector> get_vec_mapvec(const std::vector &vec_vec) { - std::vector> res; - for (const Eigen::VectorXd &vec : vec_vec) { - res.push_back(get_mapvec(vec)); - } - return res; -} - -std::vector> get_vec_mapmat(const std::vector &mat_vec) { - std::vector> res; - for (const MatrixXdr &mat : mat_vec) { - res.push_back(get_mapmat(mat)); - } - return res; -} - -LiveKalman::LiveKalman() { - this->dim_state = live_initial_x.rows(); - this->dim_state_err = live_initial_P_diag.rows(); - - this->initial_x = live_initial_x; - this->initial_P = live_initial_P_diag.asDiagonal(); - this->fake_gps_pos_cov = live_fake_gps_pos_cov_diag.asDiagonal(); - this->fake_gps_vel_cov = live_fake_gps_vel_cov_diag.asDiagonal(); - this->reset_orientation_P = live_reset_orientation_diag.asDiagonal(); - this->Q = live_Q_diag.asDiagonal(); - for (auto& pair : live_obs_noise_diag) { - this->obs_noise[pair.first] = pair.second.asDiagonal(); - } - - // init filter - this->filter = std::make_shared(this->name, get_mapmat(this->Q), get_mapvec(this->initial_x), - get_mapmat(initial_P), this->dim_state, this->dim_state_err, 0, 0, 0, std::vector(), - std::vector{3}, std::vector(), 0.8); -} - -void LiveKalman::init_state(const VectorXd &state, const VectorXd &covs_diag, double filter_time) { - MatrixXdr covs = 
covs_diag.asDiagonal(); - this->filter->init_state(get_mapvec(state), get_mapmat(covs), filter_time); -} - -void LiveKalman::init_state(const VectorXd &state, const MatrixXdr &covs, double filter_time) { - this->filter->init_state(get_mapvec(state), get_mapmat(covs), filter_time); -} - -void LiveKalman::init_state(const VectorXd &state, double filter_time) { - MatrixXdr covs = this->filter->covs(); - this->filter->init_state(get_mapvec(state), get_mapmat(covs), filter_time); -} - -VectorXd LiveKalman::get_x() { - return this->filter->state(); -} - -MatrixXdr LiveKalman::get_P() { - return this->filter->covs(); -} - -double LiveKalman::get_filter_time() { - return this->filter->get_filter_time(); -} - -std::vector LiveKalman::get_R(int kind, int n) { - std::vector R; - for (int i = 0; i < n; i++) { - R.push_back(this->obs_noise[kind]); - } - return R; -} - -std::optional LiveKalman::predict_and_observe(double t, int kind, const std::vector &meas, std::vector R) { - std::optional r; - if (R.size() == 0) { - R = this->get_R(kind, meas.size()); - } - r = this->filter->predict_and_update_batch(t, kind, get_vec_mapvec(meas), get_vec_mapmat(R)); - return r; -} - -void LiveKalman::predict(double t) { - this->filter->predict(t); -} - -const Eigen::VectorXd &LiveKalman::get_initial_x() { - return this->initial_x; -} - -const MatrixXdr &LiveKalman::get_initial_P() { - return this->initial_P; -} - -const MatrixXdr &LiveKalman::get_fake_gps_pos_cov() { - return this->fake_gps_pos_cov; -} - -const MatrixXdr &LiveKalman::get_fake_gps_vel_cov() { - return this->fake_gps_vel_cov; -} - -const MatrixXdr &LiveKalman::get_reset_orientation_P() { - return this->reset_orientation_P; -} - -MatrixXdr LiveKalman::H(const VectorXd &in) { - assert(in.size() == 6); - Matrix res; - this->filter->get_extra_routine("H")((double*)in.data(), res.data()); - return res; -} diff --git a/selfdrive/locationd/models/live_kf.h b/selfdrive/locationd/models/live_kf.h deleted file mode 100644 index e4b3e326..00000000 --- a/selfdrive/locationd/models/live_kf.h +++ /dev/null @@ -1,66 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include - -#include "generated/live_kf_constants.h" -#include "rednose/helpers/ekf_sym.h" - -#define EARTH_GM 3.986005e14 // m^3/s^2 (gravitational constant * mass of earth) - -using namespace EKFS; - -Eigen::Map get_mapvec(const Eigen::VectorXd &vec); -Eigen::Map get_mapmat(const MatrixXdr &mat); -std::vector> get_vec_mapvec(const std::vector &vec_vec); -std::vector> get_vec_mapmat(const std::vector &mat_vec); - -class LiveKalman { -public: - LiveKalman(); - - void init_state(const Eigen::VectorXd &state, const Eigen::VectorXd &covs_diag, double filter_time); - void init_state(const Eigen::VectorXd &state, const MatrixXdr &covs, double filter_time); - void init_state(const Eigen::VectorXd &state, double filter_time); - - Eigen::VectorXd get_x(); - MatrixXdr get_P(); - double get_filter_time(); - std::vector get_R(int kind, int n); - - std::optional predict_and_observe(double t, int kind, const std::vector &meas, std::vector R = {}); - std::optional predict_and_update_odo_speed(std::vector speed, double t, int kind); - std::optional predict_and_update_odo_trans(std::vector trans, double t, int kind); - std::optional predict_and_update_odo_rot(std::vector rot, double t, int kind); - void predict(double t); - - const Eigen::VectorXd &get_initial_x(); - const MatrixXdr &get_initial_P(); - const MatrixXdr &get_fake_gps_pos_cov(); - const MatrixXdr &get_fake_gps_vel_cov(); 
- const MatrixXdr &get_reset_orientation_P(); - - MatrixXdr H(const Eigen::VectorXd &in); - -private: - std::string name = "live"; - - std::shared_ptr filter; - - int dim_state; - int dim_state_err; - - Eigen::VectorXd initial_x; - MatrixXdr initial_P; - MatrixXdr fake_gps_pos_cov; - MatrixXdr fake_gps_vel_cov; - MatrixXdr reset_orientation_P; - MatrixXdr Q; // process noise - std::unordered_map obs_noise; -}; diff --git a/selfdrive/locationd/models/live_kf.py b/selfdrive/locationd/models/live_kf.py deleted file mode 100755 index 0cc3eca6..00000000 --- a/selfdrive/locationd/models/live_kf.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import os -import numpy as np - -from openpilot.selfdrive.locationd.models.constants import ObservationKind - -import sympy as sp -import inspect -from rednose.helpers.sympy_helpers import euler_rotate, quat_matrix_r, quat_rotate -from rednose.helpers.ekf_sym import gen_code - -EARTH_GM = 3.986005e14 # m^3/s^2 (gravitational constant * mass of earth) - - -def numpy2eigenstring(arr): - assert(len(arr.shape) == 1) - arr_str = np.array2string(arr, precision=20, separator=',')[1:-1].replace(' ', '').replace('\n', '') - return f"(Eigen::VectorXd({len(arr)}) << {arr_str}).finished()" - - -class States: - ECEF_POS = slice(0, 3) # x, y and z in ECEF in meters - ECEF_ORIENTATION = slice(3, 7) # quat for pose of phone in ecef - ECEF_VELOCITY = slice(7, 10) # ecef velocity in m/s - ANGULAR_VELOCITY = slice(10, 13) # roll, pitch and yaw rates in device frame in radians/s - GYRO_BIAS = slice(13, 16) # roll, pitch and yaw biases - ACCELERATION = slice(16, 19) # Acceleration in device frame in m/s**2 - ACC_BIAS = slice(19, 22) # Acceletometer bias in m/s**2 - - # Error-state has different slices because it is an ESKF - ECEF_POS_ERR = slice(0, 3) - ECEF_ORIENTATION_ERR = slice(3, 6) # euler angles for orientation error - ECEF_VELOCITY_ERR = slice(6, 9) - ANGULAR_VELOCITY_ERR = slice(9, 12) - GYRO_BIAS_ERR = slice(12, 15) - ACCELERATION_ERR = slice(15, 18) - ACC_BIAS_ERR = slice(18, 21) - - -class LiveKalman: - name = 'live' - - initial_x = np.array([3.88e6, -3.37e6, 3.76e6, - 0.42254641, -0.31238054, -0.83602975, -0.15788347, # NED [0,0,0] -> ECEF Quat - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - 0, 0, 0]) - - # state covariance - initial_P_diag = np.array([10**2, 10**2, 10**2, - 0.01**2, 0.01**2, 0.01**2, - 10**2, 10**2, 10**2, - 1**2, 1**2, 1**2, - 1**2, 1**2, 1**2, - 100**2, 100**2, 100**2, - 0.01**2, 0.01**2, 0.01**2]) - - # state covariance when resetting midway in a segment - reset_orientation_diag = np.array([1**2, 1**2, 1**2]) - - # fake observation covariance, to ensure the uncertainty estimate of the filter is under control - fake_gps_pos_cov_diag = np.array([1000**2, 1000**2, 1000**2]) - fake_gps_vel_cov_diag = np.array([10**2, 10**2, 10**2]) - - # process noise - Q_diag = np.array([0.03**2, 0.03**2, 0.03**2, - 0.001**2, 0.001**2, 0.001**2, - 0.01**2, 0.01**2, 0.01**2, - 0.1**2, 0.1**2, 0.1**2, - (0.005 / 100)**2, (0.005 / 100)**2, (0.005 / 100)**2, - 3**2, 3**2, 3**2, - 0.005**2, 0.005**2, 0.005**2]) - - obs_noise_diag = {ObservationKind.PHONE_GYRO: np.array([0.025**2, 0.025**2, 0.025**2]), - ObservationKind.PHONE_ACCEL: np.array([.5**2, .5**2, .5**2]), - ObservationKind.CAMERA_ODO_ROTATION: np.array([0.05**2, 0.05**2, 0.05**2]), - ObservationKind.NO_ROT: np.array([0.005**2, 0.005**2, 0.005**2]), - ObservationKind.NO_ACCEL: np.array([0.05**2, 0.05**2, 0.05**2]), - ObservationKind.ECEF_POS: np.array([5**2, 5**2, 5**2]), - 
ObservationKind.ECEF_VEL: np.array([.5**2, .5**2, .5**2]), - ObservationKind.ECEF_ORIENTATION_FROM_GPS: np.array([.2**2, .2**2, .2**2, .2**2])} - - @staticmethod - def generate_code(generated_dir): - name = LiveKalman.name - dim_state = LiveKalman.initial_x.shape[0] - dim_state_err = LiveKalman.initial_P_diag.shape[0] - - state_sym = sp.MatrixSymbol('state', dim_state, 1) - state = sp.Matrix(state_sym) - x, y, z = state[States.ECEF_POS, :] - q = state[States.ECEF_ORIENTATION, :] - v = state[States.ECEF_VELOCITY, :] - vx, vy, vz = v - omega = state[States.ANGULAR_VELOCITY, :] - vroll, vpitch, vyaw = omega - roll_bias, pitch_bias, yaw_bias = state[States.GYRO_BIAS, :] - acceleration = state[States.ACCELERATION, :] - acc_bias = state[States.ACC_BIAS, :] - - dt = sp.Symbol('dt') - - # calibration and attitude rotation matrices - quat_rot = quat_rotate(*q) - - # Got the quat predict equations from here - # A New Quaternion-Based Kalman Filter for - # Real-Time Attitude Estimation Using the Two-Step - # Geometrically-Intuitive Correction Algorithm - A = 0.5 * sp.Matrix([[0, -vroll, -vpitch, -vyaw], - [vroll, 0, vyaw, -vpitch], - [vpitch, -vyaw, 0, vroll], - [vyaw, vpitch, -vroll, 0]]) - q_dot = A * q - - # Time derivative of the state as a function of state - state_dot = sp.Matrix(np.zeros((dim_state, 1))) - state_dot[States.ECEF_POS, :] = v - state_dot[States.ECEF_ORIENTATION, :] = q_dot - state_dot[States.ECEF_VELOCITY, 0] = quat_rot * acceleration - - # Basic descretization, 1st order intergrator - # Can be pretty bad if dt is big - f_sym = state + dt * state_dot - - state_err_sym = sp.MatrixSymbol('state_err', dim_state_err, 1) - state_err = sp.Matrix(state_err_sym) - quat_err = state_err[States.ECEF_ORIENTATION_ERR, :] - v_err = state_err[States.ECEF_VELOCITY_ERR, :] - omega_err = state_err[States.ANGULAR_VELOCITY_ERR, :] - acceleration_err = state_err[States.ACCELERATION_ERR, :] - - # Time derivative of the state error as a function of state error and state - quat_err_matrix = euler_rotate(quat_err[0], quat_err[1], quat_err[2]) - q_err_dot = quat_err_matrix * quat_rot * (omega + omega_err) - state_err_dot = sp.Matrix(np.zeros((dim_state_err, 1))) - state_err_dot[States.ECEF_POS_ERR, :] = v_err - state_err_dot[States.ECEF_ORIENTATION_ERR, :] = q_err_dot - state_err_dot[States.ECEF_VELOCITY_ERR, :] = quat_err_matrix * quat_rot * (acceleration + acceleration_err) - f_err_sym = state_err + dt * state_err_dot - - # Observation matrix modifier - H_mod_sym = sp.Matrix(np.zeros((dim_state, dim_state_err))) - H_mod_sym[States.ECEF_POS, States.ECEF_POS_ERR] = np.eye(States.ECEF_POS.stop - States.ECEF_POS.start) - H_mod_sym[States.ECEF_ORIENTATION, States.ECEF_ORIENTATION_ERR] = 0.5 * quat_matrix_r(state[3:7])[:, 1:] - H_mod_sym[States.ECEF_ORIENTATION.stop:, States.ECEF_ORIENTATION_ERR.stop:] = np.eye(dim_state - States.ECEF_ORIENTATION.stop) - - # these error functions are defined so that say there - # is a nominal x and true x: - # true x = err_function(nominal x, delta x) - # delta x = inv_err_function(nominal x, true x) - nom_x = sp.MatrixSymbol('nom_x', dim_state, 1) - true_x = sp.MatrixSymbol('true_x', dim_state, 1) - delta_x = sp.MatrixSymbol('delta_x', dim_state_err, 1) - - err_function_sym = sp.Matrix(np.zeros((dim_state, 1))) - delta_quat = sp.Matrix(np.ones(4)) - delta_quat[1:, :] = sp.Matrix(0.5 * delta_x[States.ECEF_ORIENTATION_ERR, :]) - err_function_sym[States.ECEF_POS, :] = sp.Matrix(nom_x[States.ECEF_POS, :] + delta_x[States.ECEF_POS_ERR, :]) - 
err_function_sym[States.ECEF_ORIENTATION, 0] = quat_matrix_r(nom_x[States.ECEF_ORIENTATION, 0]) * delta_quat - err_function_sym[States.ECEF_ORIENTATION.stop:, :] = sp.Matrix(nom_x[States.ECEF_ORIENTATION.stop:, :] + delta_x[States.ECEF_ORIENTATION_ERR.stop:, :]) - - inv_err_function_sym = sp.Matrix(np.zeros((dim_state_err, 1))) - inv_err_function_sym[States.ECEF_POS_ERR, 0] = sp.Matrix(-nom_x[States.ECEF_POS, 0] + true_x[States.ECEF_POS, 0]) - delta_quat = quat_matrix_r(nom_x[States.ECEF_ORIENTATION, 0]).T * true_x[States.ECEF_ORIENTATION, 0] - inv_err_function_sym[States.ECEF_ORIENTATION_ERR, 0] = sp.Matrix(2 * delta_quat[1:]) - inv_err_function_sym[States.ECEF_ORIENTATION_ERR.stop:, 0] = sp.Matrix(-nom_x[States.ECEF_ORIENTATION.stop:, 0] + true_x[States.ECEF_ORIENTATION.stop:, 0]) - - eskf_params = [[err_function_sym, nom_x, delta_x], - [inv_err_function_sym, nom_x, true_x], - H_mod_sym, f_err_sym, state_err_sym] - # - # Observation functions - # - h_gyro_sym = sp.Matrix([ - vroll + roll_bias, - vpitch + pitch_bias, - vyaw + yaw_bias]) - - pos = sp.Matrix([x, y, z]) - gravity = quat_rot.T * ((EARTH_GM / ((x**2 + y**2 + z**2)**(3.0 / 2.0))) * pos) - h_acc_sym = (gravity + acceleration + acc_bias) - h_acc_stationary_sym = acceleration - h_phone_rot_sym = sp.Matrix([vroll, vpitch, vyaw]) - h_pos_sym = sp.Matrix([x, y, z]) - h_vel_sym = sp.Matrix([vx, vy, vz]) - h_orientation_sym = q - h_relative_motion = sp.Matrix(quat_rot.T * v) - - obs_eqs = [[h_gyro_sym, ObservationKind.PHONE_GYRO, None], - [h_phone_rot_sym, ObservationKind.NO_ROT, None], - [h_acc_sym, ObservationKind.PHONE_ACCEL, None], - [h_pos_sym, ObservationKind.ECEF_POS, None], - [h_vel_sym, ObservationKind.ECEF_VEL, None], - [h_orientation_sym, ObservationKind.ECEF_ORIENTATION_FROM_GPS, None], - [h_relative_motion, ObservationKind.CAMERA_ODO_TRANSLATION, None], - [h_phone_rot_sym, ObservationKind.CAMERA_ODO_ROTATION, None], - [h_acc_stationary_sym, ObservationKind.NO_ACCEL, None]] - - # this returns a sympy routine for the jacobian of the observation function of the local vel - in_vec = sp.MatrixSymbol('in_vec', 6, 1) # roll, pitch, yaw, vx, vy, vz - h = euler_rotate(in_vec[0], in_vec[1], in_vec[2]).T * (sp.Matrix([in_vec[3], in_vec[4], in_vec[5]])) - extra_routines = [('H', h.jacobian(in_vec), [in_vec])] - - gen_code(generated_dir, name, f_sym, dt, state_sym, obs_eqs, dim_state, dim_state_err, eskf_params, extra_routines=extra_routines) - - # write constants to extra header file for use in cpp - live_kf_header = "#pragma once\n\n" - live_kf_header += "#include \n" - live_kf_header += "#include \n\n" - for state, slc in inspect.getmembers(States, lambda x: isinstance(x, slice)): - assert(slc.step is None) # unsupported - live_kf_header += f'#define STATE_{state}_START {slc.start}\n' - live_kf_header += f'#define STATE_{state}_END {slc.stop}\n' - live_kf_header += f'#define STATE_{state}_LEN {slc.stop - slc.start}\n' - live_kf_header += "\n" - - for kind, val in inspect.getmembers(ObservationKind, lambda x: isinstance(x, int)): - live_kf_header += f'#define OBSERVATION_{kind} {val}\n' - live_kf_header += "\n" - - live_kf_header += f"static const Eigen::VectorXd live_initial_x = {numpy2eigenstring(LiveKalman.initial_x)};\n" - live_kf_header += f"static const Eigen::VectorXd live_initial_P_diag = {numpy2eigenstring(LiveKalman.initial_P_diag)};\n" - live_kf_header += f"static const Eigen::VectorXd live_fake_gps_pos_cov_diag = {numpy2eigenstring(LiveKalman.fake_gps_pos_cov_diag)};\n" - live_kf_header += f"static const Eigen::VectorXd 
live_fake_gps_vel_cov_diag = {numpy2eigenstring(LiveKalman.fake_gps_vel_cov_diag)};\n"
-    live_kf_header += f"static const Eigen::VectorXd live_reset_orientation_diag = {numpy2eigenstring(LiveKalman.reset_orientation_diag)};\n"
-    live_kf_header += f"static const Eigen::VectorXd live_Q_diag = {numpy2eigenstring(LiveKalman.Q_diag)};\n"
-    live_kf_header += "static const std::unordered_map> live_obs_noise_diag = {\n"
-    for kind, noise in LiveKalman.obs_noise_diag.items():
-      live_kf_header += f"  {{ {kind}, {numpy2eigenstring(noise)} }},\n"
-    live_kf_header += "};\n\n"
-
-    open(os.path.join(generated_dir, "live_kf_constants.h"), 'w').write(live_kf_header)
-
-
-if __name__ == "__main__":
-  generated_dir = sys.argv[2]
-  LiveKalman.generate_code(generated_dir)
diff --git a/selfdrive/locationd/paramsd.py b/selfdrive/locationd/paramsd.py
index b166daed..43f9caf0 100644
--- a/selfdrive/locationd/paramsd.py
+++ b/selfdrive/locationd/paramsd.py
@@ -15,6 +15,8 @@ from openpilot.selfdrive.locationd.models.constants import GENERATED_DIR
 from openpilot.selfdrive.locationd.helpers import PoseCalibrator, Pose
 from openpilot.common.swaglog import cloudlog

+from openpilot.common.gps import get_gps_location_service
+
 MAX_ANGLE_OFFSET_DELTA = 20 * DT_MDL  # Max 20 deg/s
 ROLL_MAX_DELTA = np.radians(20.0) * DT_MDL  # 20deg in 1 second is well within curvature limits
 ROLL_MIN, ROLL_MAX = np.radians(-10), np.radians(10)
@@ -245,7 +247,9 @@ def retrieve_initial_vehicle_params(params: Params, CP: car.CarParams, replay: b
     if debug and len(initial_filter_std) != 0:
       p_initial = np.diag(initial_filter_std)
-    steer_ratio, stiffness_factor, angle_offset_deg = lp.steerRatio, lp.stiffnessFactor, lp.angleOffsetAverageDeg
+    #steer_ratio, stiffness_factor, angle_offset_deg = lp.steerRatio, lp.stiffnessFactor, lp.angleOffsetAverageDeg
+    #steer_ratio, stiffness_factor, angle_offset_deg = lp.steerRatio, lp.stiffnessFactor, lp.angleOffsetDeg
+    steer_ratio, stiffness_factor = lp.steerRatio, lp.stiffnessFactor
     retrieve_success = True
   except Exception as e:
     cloudlog.error(f"Failed to retrieve initial values: {e}")
@@ -269,7 +273,8 @@ def main():
   REPLAY = bool(int(os.getenv("REPLAY", "0")))

   pm = messaging.PubMaster(['liveParameters'])
-  sm = messaging.SubMaster(['livePose', 'liveCalibration', 'carState', 'liveLocationKalman'], poll='livePose')
+  gps_location_service = get_gps_location_service(Params())
+  sm = messaging.SubMaster(['livePose', 'liveCalibration', 'carState', gps_location_service], poll='livePose', ignore_alive=[gps_location_service], ignore_valid=[gps_location_service])

   params = Params()
   CP = messaging.log_from_bytes(params.get("CarParams", block=True), car.CarParams)
@@ -289,12 +294,12 @@ def main():
         t = sm.logMonoTime[which] * 1e-9
         learner.handle_log(t, which, sm[which])

-    if sm.updated['liveLocationKalman']:
-      location = sm['liveLocationKalman']
-      if (location.status == log.LiveLocationKalman.Status.valid) and location.positionGeodetic.valid and location.gpsOK:
-        bearing = math.degrees(location.calibratedOrientationNED.value[2])
-        lat = location.positionGeodetic.value[0]
-        lon = location.positionGeodetic.value[1]
+    if sm.updated[gps_location_service]:
+      gps = sm[gps_location_service]
+      if gps.hasFix:
+        bearing = gps.bearingDeg
+        lat = gps.latitude
+        lon = gps.longitude

         params_memory.put("LastGPSPosition", json.dumps({"latitude": lat, "longitude": lon, "bearing": bearing}))
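paramsd now reads bearing and position straight from the GPS service instead of liveLocationKalman, selecting that service once at startup. A plausible reading of get_gps_location_service, assuming it mirrors the UbloxAvailable branch the deleted locationd.cc used to pick its input socket (the real helper in openpilot/common/gps.py may differ):

from openpilot.common.params import Params

def get_gps_location_service(params: Params) -> str:
  # ublox-equipped devices publish gpsLocationExternal; others publish gpsLocation
  return "gpsLocationExternal" if params.getBool("UbloxAvailable") else "gpsLocation"

Since the chosen service is listed in ignore_alive and ignore_valid, paramsd keeps running without GPS and only consumes messages once gps.hasFix is set.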
b/selfdrive/locationd/torqued.py @@ -5,7 +5,7 @@ from collections import deque, defaultdict import cereal.messaging as messaging from cereal import car, log -from opendbc.car.vehicle_model import ACCELERATION_DUE_TO_GRAVITY +from openpilot.common.constants import ACCELERATION_DUE_TO_GRAVITY from openpilot.common.params import Params from openpilot.common.realtime import config_realtime_process, DT_MDL from openpilot.common.filter_simple import FirstOrderFilter diff --git a/selfdrive/modeld/SConscript b/selfdrive/modeld/SConscript index 6802b45e..8b33a457 100644 --- a/selfdrive/modeld/SConscript +++ b/selfdrive/modeld/SConscript @@ -1,11 +1,11 @@ import os import glob -Import('env', 'envCython', 'arch', 'cereal', 'messaging', 'common', 'gpucommon', 'visionipc', 'transformations') +Import('env', 'envCython', 'arch', 'cereal', 'messaging', 'common', 'visionipc', 'transformations') lenv = env.Clone() lenvCython = envCython.Clone() -libs = [cereal, messaging, visionipc, gpucommon, common, 'capnp', 'kj', 'pthread'] +libs = [cereal, messaging, visionipc, common, 'capnp', 'kj', 'pthread'] frameworks = [] common_src = [ @@ -32,7 +32,7 @@ lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LI tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath) if 'pycache' not in x] # Get model metadata -for model_name in ['driving_vision', 'driving_policy']: +for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']: fn = File(f"models/{model_name}").abspath script_files = [File(Dir("#selfdrive/modeld").File("get_model_metadata.py").abspath)] cmd = f'python3 {Dir("#selfdrive/modeld").abspath}/get_model_metadata.py {fn}.onnx' @@ -50,9 +50,9 @@ def tg_compile(flags, model_name): # Compile small models for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']: flags = { - 'larch64': 'DEV=QCOM', - 'Darwin': 'DEV=CPU IMAGE=0', - }.get(arch, 'DEV=LLVM IMAGE=0') + 'larch64': 'DEV=QCOM FLOAT16=1 NOLOCALS=1 IMAGE=2 JIT_BATCH_SIZE=0', + 'Darwin': f'DEV=CPU HOME={os.path.expanduser("~")}', # tinygrad calls brew which needs a $HOME in the env + }.get(arch, 'DEV=CPU CPU_LLVM=1') tg_compile(flags, model_name) # Compile BIG model if USB GPU is available diff --git a/selfdrive/modeld/dmonitoringmodeld.py b/selfdrive/modeld/dmonitoringmodeld.py index 2a3df3f6..e0b0ba52 100755 --- a/selfdrive/modeld/dmonitoringmodeld.py +++ b/selfdrive/modeld/dmonitoringmodeld.py @@ -1,14 +1,9 @@ #!/usr/bin/env python3 import os from openpilot.system.hardware import TICI +os.environ['DEV'] = 'QCOM' if TICI else 'CPU' from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes -if TICI: - from openpilot.selfdrive.modeld.runners.tinygrad_helpers import qcom_tensor_from_opencl_address - os.environ['QCOM'] = '1' -else: - os.environ['LLVM'] = '1' -import math import time import pickle import ctypes @@ -21,48 +16,16 @@ from cereal.messaging import PubMaster, SubMaster from msgq.visionipc import VisionIpcClient, VisionStreamType, VisionBuf from openpilot.common.swaglog import cloudlog from openpilot.common.realtime import config_realtime_process -from openpilot.common.transformations.model import dmonitoringmodel_intrinsics, DM_INPUT_SIZE +from openpilot.common.transformations.model import dmonitoringmodel_intrinsics from openpilot.common.transformations.camera import _ar_ox_fisheye, _os_fisheye from openpilot.selfdrive.modeld.models.commonmodel_pyx import CLContext, MonitoringModelFrame -from 
openpilot.selfdrive.modeld.parse_model_outputs import sigmoid -from openpilot.system import sentry - -MODEL_WIDTH, MODEL_HEIGHT = DM_INPUT_SIZE -CALIB_LEN = 3 -FEATURE_LEN = 512 -OUTPUT_SIZE = 84 + FEATURE_LEN +from openpilot.selfdrive.modeld.parse_model_outputs import sigmoid, safe_exp +from openpilot.selfdrive.modeld.runners.tinygrad_helpers import qcom_tensor_from_opencl_address PROCESS_NAME = "selfdrive.modeld.dmonitoringmodeld" SEND_RAW_PRED = os.getenv('SEND_RAW_PRED') MODEL_PKL_PATH = Path(__file__).parent / 'models/dmonitoring_model_tinygrad.pkl' - - -class DriverStateResult(ctypes.Structure): - _fields_ = [ - ("face_orientation", ctypes.c_float*3), - ("face_position", ctypes.c_float*3), - ("face_orientation_std", ctypes.c_float*3), - ("face_position_std", ctypes.c_float*3), - ("face_prob", ctypes.c_float), - ("_unused_a", ctypes.c_float*8), - ("left_eye_prob", ctypes.c_float), - ("_unused_b", ctypes.c_float*8), - ("right_eye_prob", ctypes.c_float), - ("left_blink_prob", ctypes.c_float), - ("right_blink_prob", ctypes.c_float), - ("sunglasses_prob", ctypes.c_float), - ("occluded_prob", ctypes.c_float), - ("ready_prob", ctypes.c_float*4), - ("not_ready_prob", ctypes.c_float*2)] - - -class DMonitoringModelResult(ctypes.Structure): - _fields_ = [ - ("driver_state_lhd", DriverStateResult), - ("driver_state_rhd", DriverStateResult), - ("poor_vision_prob", ctypes.c_float), - ("wheel_on_right_prob", ctypes.c_float), - ("features", ctypes.c_float*FEATURE_LEN)] +METADATA_PATH = Path(__file__).parent / 'models/dmonitoring_model_metadata.pkl' class ModelState: @@ -70,11 +33,14 @@ class ModelState: output: np.ndarray def __init__(self, cl_ctx): - assert ctypes.sizeof(DMonitoringModelResult) == OUTPUT_SIZE * ctypes.sizeof(ctypes.c_float) + with open(METADATA_PATH, 'rb') as f: + model_metadata = pickle.load(f) + self.input_shapes = model_metadata['input_shapes'] + self.output_slices = model_metadata['output_slices'] self.frame = MonitoringModelFrame(cl_ctx) self.numpy_inputs = { - 'calib': np.zeros((1, CALIB_LEN), dtype=np.float32), + 'calib': np.zeros(self.input_shapes['calib'], dtype=np.float32), } self.tensor_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()} @@ -90,45 +56,53 @@ class ModelState: if TICI: # The imgs tensors are backed by opencl memory, only need init once if 'input_img' not in self.tensor_inputs: - self.tensor_inputs['input_img'] = qcom_tensor_from_opencl_address(input_img_cl.mem_address, (1, MODEL_WIDTH*MODEL_HEIGHT), dtype=dtypes.uint8) + self.tensor_inputs['input_img'] = qcom_tensor_from_opencl_address(input_img_cl.mem_address, self.input_shapes['input_img'], dtype=dtypes.uint8) else: - self.tensor_inputs['input_img'] = Tensor(self.frame.buffer_from_cl(input_img_cl).reshape((1, MODEL_WIDTH*MODEL_HEIGHT)), dtype=dtypes.uint8).realize() + self.tensor_inputs['input_img'] = Tensor(self.frame.buffer_from_cl(input_img_cl).reshape(self.input_shapes['input_img']), dtype=dtypes.uint8).realize() - output = self.model_run(**self.tensor_inputs).numpy().flatten() + output = self.model_run(**self.tensor_inputs).contiguous().realize().uop.base.buffer.numpy() t2 = time.perf_counter() return output, t2 - t1 +def slice_outputs(model_outputs, output_slices): + return {k: model_outputs[np.newaxis, v] for k,v in output_slices.items()} -def fill_driver_state(msg, ds_result: DriverStateResult): - msg.faceOrientation = list(ds_result.face_orientation) - msg.faceOrientationStd = [math.exp(x) for x in ds_result.face_orientation_std] - msg.facePosition = 
list(ds_result.face_position[:2]) - msg.facePositionStd = [math.exp(x) for x in ds_result.face_position_std[:2]] - msg.faceProb = float(sigmoid(ds_result.face_prob)) - msg.leftEyeProb = float(sigmoid(ds_result.left_eye_prob)) - msg.rightEyeProb = float(sigmoid(ds_result.right_eye_prob)) - msg.leftBlinkProb = float(sigmoid(ds_result.left_blink_prob)) - msg.rightBlinkProb = float(sigmoid(ds_result.right_blink_prob)) - msg.sunglassesProb = float(sigmoid(ds_result.sunglasses_prob)) - msg.occludedProb = float(sigmoid(ds_result.occluded_prob)) - msg.readyProb = [float(sigmoid(x)) for x in ds_result.ready_prob] - msg.notReadyProb = [float(sigmoid(x)) for x in ds_result.not_ready_prob] +def parse_model_output(model_output): + parsed = {} + parsed['wheel_on_right'] = sigmoid(model_output['wheel_on_right']) + for ds_suffix in ['lhd', 'rhd']: + face_descs = model_output[f'face_descs_{ds_suffix}'] + parsed[f'face_descs_{ds_suffix}'] = face_descs[:, :-6] + parsed[f'face_descs_{ds_suffix}_std'] = safe_exp(face_descs[:, -6:]) + for key in ['face_prob', 'left_eye_prob', 'right_eye_prob','left_blink_prob', 'right_blink_prob', 'sunglasses_prob', 'using_phone_prob']: + parsed[f'{key}_{ds_suffix}'] = sigmoid(model_output[f'{key}_{ds_suffix}']) + return parsed +def fill_driver_data(msg, model_output, ds_suffix): + msg.faceOrientation = model_output[f'face_descs_{ds_suffix}'][0, :3].tolist() + msg.faceOrientationStd = model_output[f'face_descs_{ds_suffix}_std'][0, :3].tolist() + msg.facePosition = model_output[f'face_descs_{ds_suffix}'][0, 3:5].tolist() + msg.facePositionStd = model_output[f'face_descs_{ds_suffix}_std'][0, 3:5].tolist() + msg.faceProb = model_output[f'face_prob_{ds_suffix}'][0, 0].item() + msg.leftEyeProb = model_output[f'left_eye_prob_{ds_suffix}'][0, 0].item() + msg.rightEyeProb = model_output[f'right_eye_prob_{ds_suffix}'][0, 0].item() + msg.leftBlinkProb = model_output[f'left_blink_prob_{ds_suffix}'][0, 0].item() + msg.rightBlinkProb = model_output[f'right_blink_prob_{ds_suffix}'][0, 0].item() + msg.sunglassesProb = model_output[f'sunglasses_prob_{ds_suffix}'][0, 0].item() + msg.phoneProb = model_output[f'using_phone_prob_{ds_suffix}'][0, 0].item() -def get_driverstate_packet(model_output: np.ndarray, frame_id: int, location_ts: int, execution_time: float, gpu_execution_time: float): - model_result = ctypes.cast(model_output.ctypes.data, ctypes.POINTER(DMonitoringModelResult)).contents +def get_driverstate_packet(model_output, frame_id: int, location_ts: int, exec_time: float, gpu_exec_time: float): msg = messaging.new_message('driverStateV2', valid=True) ds = msg.driverStateV2 ds.frameId = frame_id - ds.modelExecutionTime = execution_time - ds.gpuExecutionTime = gpu_execution_time - ds.poorVisionProb = float(sigmoid(model_result.poor_vision_prob)) - ds.wheelOnRightProb = float(sigmoid(model_result.wheel_on_right_prob)) - ds.rawPredictions = model_output.tobytes() if SEND_RAW_PRED else b'' - fill_driver_state(ds.leftDriverData, model_result.driver_state_lhd) - fill_driver_state(ds.rightDriverData, model_result.driver_state_rhd) + ds.modelExecutionTime = exec_time + ds.gpuExecutionTime = gpu_exec_time + ds.rawPredictions = model_output['raw_pred'] + ds.wheelOnRightProb = model_output['wheel_on_right'][0, 0].item() + fill_driver_data(ds.leftDriverData, model_output, 'lhd') + fill_driver_data(ds.rightDriverData, model_output, 'rhd') return msg @@ -153,7 +127,7 @@ def main(): sm = SubMaster(["liveCalibration"]) pm = PubMaster(["driverStateV2"]) - calib = np.zeros(CALIB_LEN, dtype=np.float32) + 
calib = np.zeros(model.numpy_inputs['calib'].size, dtype=np.float32) model_transform = None while True: @@ -172,8 +146,12 @@ def main(): t1 = time.perf_counter() model_output, gpu_execution_time = model.run(buf, calib, model_transform) t2 = time.perf_counter() - - pm.send("driverStateV2", get_driverstate_packet(model_output, vipc_client.frame_id, vipc_client.timestamp_sof, t2 - t1, gpu_execution_time)) + raw_pred = model_output.tobytes() if SEND_RAW_PRED else b'' + model_output = slice_outputs(model_output, model.output_slices) + model_output = parse_model_output(model_output) + model_output['raw_pred'] = raw_pred + msg = get_driverstate_packet(model_output, vipc_client.frame_id, vipc_client.timestamp_sof, t2 - t1, gpu_execution_time) + pm.send("driverStateV2", msg) if __name__ == "__main__": diff --git a/selfdrive/modeld/fill_model_msg.py b/selfdrive/modeld/fill_model_msg.py index a5cb02ab..12e5e500 100644 --- a/selfdrive/modeld/fill_model_msg.py +++ b/selfdrive/modeld/fill_model_msg.py @@ -1,6 +1,7 @@ import os import capnp import numpy as np +import math from cereal import log from openpilot.selfdrive.modeld.constants import ModelConstants, Plan, Meta @@ -102,21 +103,42 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D LINE_T_IDXS = [np.nan] * ModelConstants.IDX_N LINE_T_IDXS[0] = 0.0 plan_x = net_output_data['plan'][0, :, Plan.POSITION][:, 0].tolist() + Tmax = ModelConstants.T_IDXS[ModelConstants.IDX_N - 1] for xidx in range(1, ModelConstants.IDX_N): tidx = 0 # increment tidx until we find an element that's further away than the current xidx while tidx < ModelConstants.IDX_N - 1 and plan_x[tidx + 1] < ModelConstants.X_IDXS[xidx]: tidx += 1 if tidx == ModelConstants.IDX_N - 1: - # if the Plan doesn't extend far enough, set plan_t to the max value (10s), then break - LINE_T_IDXS[xidx] = ModelConstants.T_IDXS[ModelConstants.IDX_N - 1] - break + for k in range(xidx, ModelConstants.IDX_N): + LINE_T_IDXS[k] = Tmax + break # interpolate to find `t` for the current xidx current_x_val = plan_x[tidx] next_x_val = plan_x[tidx + 1] - p = (ModelConstants.X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val) if abs( - next_x_val - current_x_val) > 1e-9 else float('nan') - LINE_T_IDXS[xidx] = p * ModelConstants.T_IDXS[tidx + 1] + (1 - p) * ModelConstants.T_IDXS[tidx] + + dx = next_x_val - current_x_val + if dx <= 1e-9: + LINE_T_IDXS[xidx] = ModelConstants.T_IDXS[tidx] + else: + p = (ModelConstants.X_IDXS[xidx] - current_x_val) / dx + if p < 0.0: p = 0.0 + elif p > 1.0: p = 1.0 + LINE_T_IDXS[xidx] = p * ModelConstants.T_IDXS[tidx + 1] + (1.0 - p) * ModelConstants.T_IDXS[tidx] + + #p = (ModelConstants.X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val) if abs( + # next_x_val - current_x_val) > 1e-9 else float('nan') + #LINE_T_IDXS[xidx] = p * ModelConstants.T_IDXS[tidx + 1] + (1 - p) * ModelConstants.T_IDXS[tidx] + + LINE_T_IDXS = [float(Tmax if math.isnan(float(v)) else float(v)) for v in LINE_T_IDXS] + + # 비내림(monotonic non-decreasing) 보정 (순수 파이썬, numpy 불사용) + running = LINE_T_IDXS[0] + for i in range(1, len(LINE_T_IDXS)): + if LINE_T_IDXS[i] < running: + LINE_T_IDXS[i] = running + else: + running = LINE_T_IDXS[i] # lane lines modelV2.init('laneLines', 4) diff --git a/selfdrive/modeld/models/README.md b/selfdrive/modeld/models/README.md index 255f28d8..04b69c61 100644 --- a/selfdrive/modeld/models/README.md +++ b/selfdrive/modeld/models/README.md @@ -62,6 +62,5 @@ Refer to **slice_outputs** and **parse_vision_outputs/parse_policy_outputs** in * 
(deprecated) distracted probabilities: 2 * using phone probability: 1 * distracted probability: 1 - * common outputs 2 - * poor camera vision probability: 1 + * common outputs 1 * left hand drive probability: 1 diff --git a/selfdrive/modeld/models/dmonitoring_model.current b/selfdrive/modeld/models/dmonitoring_model.current deleted file mode 100644 index 121871ef..00000000 --- a/selfdrive/modeld/models/dmonitoring_model.current +++ /dev/null @@ -1,2 +0,0 @@ -fa69be01-b430-4504-9d72-7dcb058eb6dd -d9fb22d1c4fa3ca3d201dbc8edf1d0f0918e53e6 diff --git a/selfdrive/modeld/models/dmonitoring_model.onnx b/selfdrive/modeld/models/dmonitoring_model.onnx index da60e9a9..51d18498 100644 Binary files a/selfdrive/modeld/models/dmonitoring_model.onnx and b/selfdrive/modeld/models/dmonitoring_model.onnx differ diff --git a/selfdrive/modeld/models/driving_policy.onnx b/selfdrive/modeld/models/driving_policy.onnx index 56fe9d10..28227de1 100644 Binary files a/selfdrive/modeld/models/driving_policy.onnx and b/selfdrive/modeld/models/driving_policy.onnx differ diff --git a/selfdrive/modeld/models/driving_vision.onnx b/selfdrive/modeld/models/driving_vision.onnx index bf076788..3f25f8fa 100644 Binary files a/selfdrive/modeld/models/driving_vision.onnx and b/selfdrive/modeld/models/driving_vision.onnx differ diff --git a/selfdrive/ui/carrot.cc b/selfdrive/ui/carrot.cc index 4dc742e4..0468adef 100644 --- a/selfdrive/ui/carrot.cc +++ b/selfdrive/ui/carrot.cc @@ -2400,9 +2400,8 @@ public: if (strcmp(driving_mode_str, driving_mode_str_last)) ui_draw_text_a(s, dx, dy, driving_mode_str, 30, COLOR_WHITE, BOLD); strcpy(driving_mode_str_last, driving_mode_str); - auto locationd = sm["liveLocationKalman"].getLiveLocationKalman(); - bool is_gps_valid = sm.valid("liveLocationKalman") && locationd.getGpsOK(); - if (is_gps_valid) { + auto gps = (s->ublox_avaliable) ? 
sm["gpsLocationExternal"].getGpsLocationExternal() : sm["gpsLocation"].getGpsLocation(); + if (gps.getHasFix()) { ui_draw_text(s, dx, dy - 45, "GPS", 30, COLOR_GREEN, BOLD); } diff --git a/selfdrive/ui/qt/maps/map.cc b/selfdrive/ui/qt/maps/map.cc index 8abfce5e..b6a0aeca 100644 --- a/selfdrive/ui/qt/maps/map.cc +++ b/selfdrive/ui/qt/maps/map.cc @@ -4,6 +4,9 @@ #include #include +#include +#include +#include #include "selfdrive/ui/qt/maps/map_helpers.h" #include "selfdrive/ui/qt/util.h" @@ -13,7 +16,7 @@ const int INTERACTION_TIMEOUT = 100; //const float MAX_ZOOM = 20;// 17; -const float MIN_ZOOM = 14; +const float MIN_ZOOM = 15; // 14; const float MAX_PITCH = 50; const float MIN_PITCH = 0; const float MAP_SCALE = 2; @@ -147,6 +150,40 @@ void MapWindow::initLayers() { 20, 0 }; + if (!m_map->sourceExists("carrotSpeedSource")) { + qDebug() << "Initializing carrotSpeedSource"; + + // FeatureCollection GeoJSON (QVariantMap ) + QVariantMap fc; + fc["type"] = "FeatureCollection"; + fc["features"] = QVariantList{}; // Ʈ + + QVariantMap src; + src["type"] = "geojson"; + src["data"] = fc; + m_map->addSource("carrotSpeedSource", src); + } + + if (!m_map->layerExists("carrotSpeedLayer")) { + qDebug() << "Initializing carrotSpeedLayer"; + QVariantMap layer; + layer["type"] = "symbol"; + layer["source"] = "carrotSpeedSource"; + m_map->addLayer("carrotSpeedLayer", layer); + + // properties.speed ؽƮ ǥ (ū ) + // "{speed}" properties.speed ڿ ־ + m_map->setLayoutProperty("carrotSpeedLayer", "text-field", "{speed}"); + m_map->setLayoutProperty("carrotSpeedLayer", "text-size", 16.0); + m_map->setLayoutProperty("carrotSpeedLayer", "text-offset", QVariantList{ 0.0, -1.5 }); + m_map->setLayoutProperty("carrotSpeedLayer", "text-anchor", "top"); + m_map->setLayoutProperty("carrotSpeedLayer", "icon-allow-overlap", true); + m_map->setPaintProperty("carrotSpeedLayer", "text-color", QColor("white")); + m_map->setPaintProperty("carrotSpeedLayer", "text-halo-color", QColor("black")); + m_map->setPaintProperty("carrotSpeedLayer", "text-halo-width", 1.0); + m_map->setLayoutProperty("carrotSpeedLayer", "text-allow-overlap", true); + + } m_map->setPaintProperty("buildingsLayer", "fill-extrusion-color", QColor("grey")); m_map->setPaintProperty("buildingsLayer", "fill-extrusion-opacity", fillExtrusionOpacity); m_map->setPaintProperty("buildingsLayer", "fill-extrusion-height", fillExtrusionHight); @@ -228,6 +265,74 @@ void MapWindow::updateState(const UIState &s) { initLayers(); + { + std::string raw = params_memory.get("CarrotSpeedViz"); + if (!raw.empty()) { + QString qraw = QString::fromStdString(raw); + //printf("%s\n", qraw.toStdString().c_str()); + if (qraw != last_viz_raw) { + last_viz_raw = qraw; + + QJsonParseError err; + QJsonDocument doc = QJsonDocument::fromJson(qraw.toUtf8(), &err); + if (err.error == QJsonParseError::NoError && doc.isObject()) { + QJsonObject obj = doc.object(); + QJsonArray pts = obj["pts"].toArray(); + + // GeoJSON FeatureCollection QVariantMap Ʈ + QVariantList features; // Feature Ʈ + + for (const QJsonValue& v : pts) { + QJsonArray arr = v.toArray(); + if (arr.size() < 3) continue; + + double plat = arr[0].toDouble(); + double plon = arr[1].toDouble(); + double spd = arr[2].toDouble(); + + // geometry: Point (GeoJSON: [lon, lat] ) + QVariantList coords; + coords.append(plon); + coords.append(plat); + + QVariantMap geom; + geom["type"] = "Point"; + geom["coordinates"] = coords; + + // properties: speed + QVariantMap props; + props["speed"] = static_cast(std::round(spd)); + + // Feature + 
QVariantMap feature; + feature["type"] = "Feature"; + feature["geometry"] = geom; + feature["properties"] = props; + + features.append(feature); + } + + QVariantMap fc; + fc["type"] = "FeatureCollection"; + fc["features"] = features; + + QJsonDocument fc_doc = QJsonDocument::fromVariant(fc); + QByteArray fc_bytes = fc_doc.toJson(QJsonDocument::Compact); + + QVariantMap src; + src["type"] = "geojson"; + src["data"] = fc_bytes; + m_map->updateSource("carrotSpeedSource", src); + m_map->setLayoutProperty("carrotSpeedLayer", "visibility", "visible"); + } + } + } + else { + // ʿϸ + // m_map->setLayoutProperty("carrotSpeedLayer", "visibility", "none"); + } + } + if (!locationd_valid) { setError(tr("Waiting for GPS(APN)")); } else if (routing_problem) { @@ -295,17 +400,17 @@ void MapWindow::updateState(const UIState &s) { updateDestinationMarker(); } if (loaded_once && (sm.rcv_frame("modelV2") != model_rcv_frame)) { - auto locationd_location = sm["liveLocationKalman"].getLiveLocationKalman(); - if (locationd_location.getGpsOK()) { - //auto carrot_man = sm["carrotMan"].getCarrotMan(); + /* + gps = (ublox_avaliable)? sm[gps_service].getGpsLocationExternal() : sm[gps_service].getGpsLocation(); + if (gps.getHasFix()) { auto model_path = model_to_collection(locationd_location.getCalibratedOrientationECEF(), locationd_location.getPositionECEF(), sm["modelV2"].getModelV2().getPosition(), carrotMan.getXPosLat(), carrotMan.getXPosLon()); - //auto model_path = model_to_collection(sm["modelV2"].getModelV2().getPosition(), carrotMan.getXPosLat(), carrotMan.getXPosLon(), carrotMan.getXPosAngle()); QMapLibre::Feature model_path_feature(QMapLibre::Feature::LineStringType, model_path, {}, {}); QVariantMap modelV2Path; modelV2Path["type"] = "geojson"; modelV2Path["data"] = QVariant::fromValue(model_path_feature); m_map->updateSource("modelPathSource", modelV2Path); } + */ model_rcv_frame = sm.rcv_frame("modelV2"); } } diff --git a/selfdrive/ui/qt/maps/map.h b/selfdrive/ui/qt/maps/map.h index f29bd114..89522aaa 100644 --- a/selfdrive/ui/qt/maps/map.h +++ b/selfdrive/ui/qt/maps/map.h @@ -69,6 +69,7 @@ private: MapInstructions* map_instructions; MapETA* map_eta; + // Blue with normal nav, green when nav is input into the model QColor getNavPathColor(bool nav_enabled) { return nav_enabled ? 
QColor("#31ee73") : QColor("#31a1ee"); @@ -78,10 +79,12 @@ private: void updateDestinationMarker(); uint64_t route_rcv_frame = 0; - // FrogPilot variables Params params; uint64_t model_rcv_frame = 0; + Params params_memory{ "/dev/shm/params" }; + QString last_viz_raw; + float MAX_ZOOM = 17; // carrot private slots: void updateState(const UIState &s); diff --git a/selfdrive/ui/qt/offroad/settings.cc b/selfdrive/ui/qt/offroad/settings.cc index b87c3e72..c8a7cc3f 100644 --- a/selfdrive/ui/qt/offroad/settings.cc +++ b/selfdrive/ui/qt/offroad/settings.cc @@ -215,19 +215,24 @@ DevicePanel::DevicePanel(SettingsWindow *parent) : ListWidget(parent) { //QObject::connect(init_btn, &QPushButton::clicked, this, &DevicePanel::reboot); QObject::connect(init_btn, &QPushButton::clicked, [&]() { if (ConfirmationDialog::confirm(tr("Git pull & Reboot?"), tr("Yes"), this)) { - QString cmd = - "bash -c 'cd /data/openpilot && " - "git fetch && " - "if git status -uno | grep -q \"Your branch is behind\"; then " - "git pull && reboot; " + QString pullscript = "cd /data/openpilot && " + "git fetch origin && " + "LOCAL=$(git rev-parse HEAD) && " + "BRANCH=$(git branch --show-current) && " + "REMOTE=$(git rev-parse origin/$BRANCH) && " + "if [ $LOCAL != $REMOTE ]; then " + "echo 'Local is behind. Pulling updates...' && " + "git pull --ff-only && " + "sudo reboot; " "else " - "echo \"Already up to date.\"; " + "echo 'Already up to date.'; " "fi'"; - if (!QProcess::startDetached(cmd)) { + bool success = QProcess::startDetached("/bin/sh", QStringList() << "-c" << pullscript); + + if (!success) { ConfirmationDialog::alert(tr("Failed to start update process."), this); - } - else { + } else { ConfirmationDialog::alert(tr("Update process started. Device will reboot if updates are applied."), this); } } @@ -858,6 +863,7 @@ CarrotPanel::CarrotPanel(QWidget* parent) : QWidget(parent) { speedToggles->addItem(new CValueControl("AutoNaviSpeedBumpSpeed", tr("SpeedBumpSpeed(35Km/h)"), "", 10, 100, 5)); speedToggles->addItem(new CValueControl("AutoNaviCountDownMode", tr("NaviCountDown mode(2)"), tr("0: off, 1:tbt+camera, 2:tbt+camera+bump"), 0, 2, 1)); speedToggles->addItem(new CValueControl("TurnSpeedControlMode", tr("Turn Speed control mode(1)"), tr("0: off, 1:vision, 2:vision+route, 3: route"), 0, 3, 1)); + speedToggles->addItem(new CValueControl("CarrotSmartSpeedControl", tr("Smart Speed Control(0)"), tr("0: off, 1:accel, 2:decel, 3: all"), 0, 3, 1)); speedToggles->addItem(new CValueControl("MapTurnSpeedFactor", tr("Map TurnSpeed Factor(100)"), "", 50, 300, 5)); speedToggles->addItem(new CValueControl("ModelTurnSpeedFactor", tr("Model TurnSpeed Factor(0)"), "", 0, 80, 10)); speedToggles->addItem(new CValueControl("AutoTurnControl", tr("ATC: Auto turn control(0)"), tr("0:None, 1: lane change, 2: lane change + speed, 3: speed"), 0, 3, 1)); diff --git a/selfdrive/ui/ui.cc b/selfdrive/ui/ui.cc index 34913a5f..305736fd 100644 --- a/selfdrive/ui/ui.cc +++ b/selfdrive/ui/ui.cc @@ -98,13 +98,15 @@ void UIState::updateStatus() { } UIState::UIState(QObject *parent) : QObject(parent) { + ublox_avaliable = Params().getBool("UbloxAvailable"); + auto gps_service = (ublox_avaliable) ? 
"gpsLocationExternal" : "gpsLocation"; sm = std::make_unique(std::vector{ "modelV2", "controlsState", "liveCalibration", "radarState", "deviceState", "pandaStates", "carParams", "driverMonitoringState", "carState", "driverStateV2", "wideRoadCameraState", "managerState", "selfdriveState", "longitudinalPlan", "longitudinalPlan", "carControl", "carrotMan", "liveTorqueParameters", "lateralPlan", "liveParameters", - "navRoute", "navInstruction", "navInstructionCarrot", "liveLocationKalman", "liveDelay", + "navRoute", "navInstruction", "navInstructionCarrot", gps_service, "liveDelay", "peripheralState", }); prime_state = new PrimeState(this); diff --git a/selfdrive/ui/ui.h b/selfdrive/ui/ui.h index 884e6bd1..1bcd2d59 100644 --- a/selfdrive/ui/ui.h +++ b/selfdrive/ui/ui.h @@ -98,6 +98,8 @@ public: float show_brightness_ratio = 1.0; int show_brightness_timer = 20; + bool ublox_avaliable = true; + signals: void uiUpdate(const UIState &s); void offroadTransition(bool offroad); diff --git a/system/camerad/SConscript b/system/camerad/SConscript index fe5cf87b..74819818 100644 --- a/system/camerad/SConscript +++ b/system/camerad/SConscript @@ -1,6 +1,6 @@ -Import('env', 'arch', 'messaging', 'common', 'gpucommon', 'visionipc') +Import('env', 'arch', 'messaging', 'common', 'visionipc') -libs = [common, 'OpenCL', messaging, visionipc, gpucommon] +libs = [common, 'OpenCL', messaging, visionipc] if arch != "Darwin": camera_obj = env.Object(['cameras/camera_qcom2.cc', 'cameras/camera_common.cc', 'cameras/spectra.cc', diff --git a/system/manager/manager.py b/system/manager/manager.py index adf7d9ca..f4da0f71 100755 --- a/system/manager/manager.py +++ b/system/manager/manager.py @@ -86,6 +86,7 @@ def get_default_params(): ("AutoRoadSpeedLimitOffset", "-1"), ("AutoNaviCountDownMode", "2"), ("TurnSpeedControlMode", "1"), + ("CarrotSmartSpeedControl", "0"), ("MapTurnSpeedFactor", "90"), ("ModelTurnSpeedFactor", "0"), ("StoppingAccel", "0"), diff --git a/system/manager/process_config.py b/system/manager/process_config.py index 58e62744..8ea27c42 100644 --- a/system/manager/process_config.py +++ b/system/manager/process_config.py @@ -93,17 +93,14 @@ procs = [ PythonProcess("micd", "system.micd", iscar), PythonProcess("timed", "system.timed", always_run, enabled=not PC), - # TODO: Make python process once TG allows opening QCOM from child pro - # https://github.com/tinygrad/tinygrad/blob/ac9c96dae1656dc220ee4acc39cef4dd449aa850/tinygrad/device.py#L26 - NativeProcess("modeld", "selfdrive/modeld", ["./modeld.py"], only_onroad), - NativeProcess("dmonitoringmodeld", "selfdrive/modeld", ["./dmonitoringmodeld.py"], enable_dm, enabled=(WEBCAM or not PC)), + PythonProcess("modeld", "selfdrive.modeld.modeld", only_onroad), + PythonProcess("dmonitoringmodeld", "selfdrive.modeld.dmonitoringmodeld", enable_dm, enabled=(WEBCAM or not PC)), #NativeProcess("mapsd", "selfdrive/navd", ["./mapsd"], only_onroad), #NativeProcess("mapsd", "selfdrive/navd", ["./mapsd"], always_run), #PythonProcess("navmodeld", "selfdrive.modeld.navmodeld", only_onroad), - NativeProcess("sensord", "system/sensord", ["./sensord"], only_onroad, enabled=not PC), + PythonProcess("sensord", "system.sensord.sensord", only_onroad, enabled=not PC), NativeProcess("ui", "selfdrive/ui", ["./ui"], always_run, watchdog_max_dt=(5 if not PC else None)), PythonProcess("soundd", "selfdrive.ui.soundd", only_onroad), - NativeProcess("locationd2", "selfdrive/locationd", ["./locationd"], only_onroad), PythonProcess("locationd", "selfdrive.locationd.locationd", only_onroad), 
NativeProcess("_pandad", "selfdrive/pandad", ["./pandad"], always_run, enabled=False), PythonProcess("calibrationd", "selfdrive.locationd.calibrationd", only_onroad), @@ -119,7 +116,7 @@ procs = [ PythonProcess("pandad", "selfdrive.pandad.pandad", always_run), PythonProcess("paramsd", "selfdrive.locationd.paramsd", only_onroad), PythonProcess("lagd", "selfdrive.locationd.lagd", only_onroad), - NativeProcess("ubloxd", "system/ubloxd", ["./ubloxd"], ublox, enabled=TICI), + PythonProcess("ubloxd", "system.ubloxd.ubloxd", ublox, enabled=TICI), PythonProcess("pigeond", "system.ubloxd.pigeond", ublox, enabled=TICI), PythonProcess("plannerd", "selfdrive.controls.plannerd", not_long_maneuver), PythonProcess("maneuversd", "tools.longitudinal_maneuvers.maneuversd", long_maneuver), diff --git a/system/qcomgpsd/qcomgpsd.py b/system/qcomgpsd/qcomgpsd.py index 9a6fae38..6037857e 100755 --- a/system/qcomgpsd/qcomgpsd.py +++ b/system/qcomgpsd/qcomgpsd.py @@ -16,7 +16,7 @@ from struct import unpack_from, calcsize, pack from cereal import log import cereal.messaging as messaging from openpilot.common.gpio import gpio_init, gpio_set -from openpilot.common.retry import retry +from openpilot.common.utils import retry from openpilot.common.time_helpers import system_time_valid from openpilot.system.hardware.tici.pins import GPIO from openpilot.common.swaglog import cloudlog diff --git a/system/sensord/.gitignore b/system/sensord/.gitignore deleted file mode 100644 index e17675e2..00000000 --- a/system/sensord/.gitignore +++ /dev/null @@ -1 +0,0 @@ -sensord diff --git a/system/sensord/SConscript b/system/sensord/SConscript deleted file mode 100644 index e2dfb522..00000000 --- a/system/sensord/SConscript +++ /dev/null @@ -1,17 +0,0 @@ -Import('env', 'arch', 'common', 'messaging') - -sensors = [ - 'sensors/i2c_sensor.cc', - 'sensors/bmx055_accel.cc', - 'sensors/bmx055_gyro.cc', - 'sensors/bmx055_magn.cc', - 'sensors/bmx055_temp.cc', - 'sensors/lsm6ds3_accel.cc', - 'sensors/lsm6ds3_gyro.cc', - 'sensors/lsm6ds3_temp.cc', - 'sensors/mmc5603nj_magn.cc', -] -libs = [common, messaging, 'pthread'] -if arch == "larch64": - libs.append('i2c') -env.Program('sensord', ['sensors_qcom2.cc'] + sensors, LIBS=libs) diff --git a/system/sensord/sensord.py b/system/sensord/sensord.py new file mode 100644 index 00000000..cc036688 --- /dev/null +++ b/system/sensord/sensord.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +import os +import time +import ctypes +import select +import threading + +import cereal.messaging as messaging +from cereal.services import SERVICE_LIST +from openpilot.common.util import sudo_write +from openpilot.common.realtime import config_realtime_process, Ratekeeper +from openpilot.common.swaglog import cloudlog +from openpilot.common.gpio import gpiochip_get_ro_value_fd, gpioevent_data +from openpilot.system.hardware import HARDWARE + +from openpilot.system.sensord.sensors.i2c_sensor import Sensor +from openpilot.system.sensord.sensors.lsm6ds3_accel import LSM6DS3_Accel +from openpilot.system.sensord.sensors.lsm6ds3_gyro import LSM6DS3_Gyro +from openpilot.system.sensord.sensors.lsm6ds3_temp import LSM6DS3_Temp +from openpilot.system.sensord.sensors.mmc5603nj_magn import MMC5603NJ_Magn + +I2C_BUS_IMU = 1 + +def interrupt_loop(sensors: list[tuple[Sensor, str, bool]], event) -> None: + pm = messaging.PubMaster([service for sensor, service, interrupt in sensors if interrupt]) + + # Requesting both edges as the data ready pulse from the lsm6ds sensor is + # very short (75us) and is mostly detected as falling edge 
instead of rising. + # So if it is detected as rising the following falling edge is skipped. + fd = gpiochip_get_ro_value_fd("sensord", 0, 84) + + # Configure IRQ affinity + irq_path = "/proc/irq/336/smp_affinity_list" + if not os.path.exists(irq_path): + irq_path = "/proc/irq/335/smp_affinity_list" + if os.path.exists(irq_path): + sudo_write('1\n', irq_path) + + offset = time.time_ns() - time.monotonic_ns() + + poller = select.poll() + poller.register(fd, select.POLLIN | select.POLLPRI) + while not event.is_set(): + events = poller.poll(100) + if not events: + cloudlog.error("poll timed out") + continue + if not (events[0][1] & (select.POLLIN | select.POLLPRI)): + cloudlog.error("no poll events set") + continue + + dat = os.read(fd, ctypes.sizeof(gpioevent_data)*16) + evd = gpioevent_data.from_buffer_copy(dat) + + cur_offset = time.time_ns() - time.monotonic_ns() + if abs(cur_offset - offset) > 10 * 1e6: # ms + cloudlog.warning(f"time jumped: {cur_offset} {offset}") + offset = cur_offset + continue + + ts = evd.timestamp - cur_offset + for sensor, service, interrupt in sensors: + if interrupt: + try: + evt = sensor.get_event(ts) + if not sensor.is_data_valid(): + continue + msg = messaging.new_message(service, valid=True) + setattr(msg, service, evt) + pm.send(service, msg) + except Sensor.DataNotReady: + pass + except Exception: + cloudlog.exception(f"Error processing {service}") + + +def polling_loop(sensor: Sensor, service: str, event: threading.Event) -> None: + pm = messaging.PubMaster([service]) + rk = Ratekeeper(SERVICE_LIST[service].frequency, print_delay_threshold=None) + while not event.is_set(): + try: + evt = sensor.get_event() + if not sensor.is_data_valid(): + continue + msg = messaging.new_message(service, valid=True) + setattr(msg, service, evt) + pm.send(service, msg) + except Exception: + cloudlog.exception(f"Error in {service} polling loop") + rk.keep_time() + +def main() -> None: + config_realtime_process([1, ], 1) + + sensors_cfg = [ + (LSM6DS3_Accel(I2C_BUS_IMU), "accelerometer", True), + (LSM6DS3_Gyro(I2C_BUS_IMU), "gyroscope", True), + (LSM6DS3_Temp(I2C_BUS_IMU), "temperatureSensor", False), + ] + if HARDWARE.get_device_type() == "tizi": + sensors_cfg.append( + (MMC5603NJ_Magn(I2C_BUS_IMU), "magnetometer", False), + ) + + # Reset sensors + for sensor, _, _ in sensors_cfg: + try: + sensor.reset() + except Exception: + cloudlog.exception(f"Error initializing {sensor} sensor") + + # Initialize sensors + exit_event = threading.Event() + threads = [ + threading.Thread(target=interrupt_loop, args=(sensors_cfg, exit_event), daemon=True) + ] + for sensor, service, interrupt in sensors_cfg: + try: + sensor.init() + if not interrupt: + # Start polling thread for sensors without interrupts + threads.append(threading.Thread( + target=polling_loop, + args=(sensor, service, exit_event), + daemon=True + )) + except Exception: + cloudlog.exception(f"Error initializing {service} sensor") + + try: + for t in threads: + t.start() + while any(t.is_alive() for t in threads): + time.sleep(1) + except KeyboardInterrupt: + pass + finally: + exit_event.set() + for t in threads: + if t.is_alive(): + t.join() + + for sensor, _, _ in sensors_cfg: + try: + sensor.shutdown() + except Exception: + cloudlog.exception("Error shutting down sensor") + +if __name__ == "__main__": + main() diff --git a/system/sensord/sensors/__init__.py b/system/sensord/sensors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/system/sensord/sensors/bmx055_accel.cc 
b/system/sensord/sensors/bmx055_accel.cc deleted file mode 100644 index bcc31e1d..00000000 --- a/system/sensord/sensors/bmx055_accel.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "system/sensord/sensors/bmx055_accel.h" - -#include - -#include "common/swaglog.h" -#include "common/timing.h" -#include "common/util.h" - -BMX055_Accel::BMX055_Accel(I2CBus *bus) : I2CSensor(bus) {} - -int BMX055_Accel::init() { - int ret = verify_chip_id(BMX055_ACCEL_I2C_REG_ID, {BMX055_ACCEL_CHIP_ID}); - if (ret == -1) { - goto fail; - } - - ret = set_register(BMX055_ACCEL_I2C_REG_PMU, BMX055_ACCEL_NORMAL_MODE); - if (ret < 0) { - goto fail; - } - - // bmx055 accel has a 1.3ms wakeup time from deep suspend mode - util::sleep_for(10); - - // High bandwidth - // ret = set_register(BMX055_ACCEL_I2C_REG_HBW, BMX055_ACCEL_HBW_ENABLE); - // if (ret < 0) { - // goto fail; - // } - - // Low bandwidth - ret = set_register(BMX055_ACCEL_I2C_REG_HBW, BMX055_ACCEL_HBW_DISABLE); - if (ret < 0) { - goto fail; - } - - ret = set_register(BMX055_ACCEL_I2C_REG_BW, BMX055_ACCEL_BW_125HZ); - if (ret < 0) { - goto fail; - } - - enabled = true; - -fail: - return ret; -} - -int BMX055_Accel::shutdown() { - if (!enabled) return 0; - - // enter deep suspend mode (lowest power mode) - int ret = set_register(BMX055_ACCEL_I2C_REG_PMU, BMX055_ACCEL_DEEP_SUSPEND); - if (ret < 0) { - LOGE("Could not move BMX055 ACCEL in deep suspend mode!"); - } - - return ret; -} - -bool BMX055_Accel::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - uint8_t buffer[6]; - int len = read_register(BMX055_ACCEL_I2C_REG_X_LSB, buffer, sizeof(buffer)); - assert(len == 6); - - // 12 bit = +-2g - float scale = 9.81 * 2.0f / (1 << 11); - float x = -read_12_bit(buffer[0], buffer[1]) * scale; - float y = -read_12_bit(buffer[2], buffer[3]) * scale; - float z = read_12_bit(buffer[4], buffer[5]) * scale; - - auto event = msg.initEvent().initAccelerometer2(); - event.setSource(cereal::SensorEventData::SensorSource::BMX055); - event.setVersion(1); - event.setSensor(SENSOR_ACCELEROMETER); - event.setType(SENSOR_TYPE_ACCELEROMETER); - event.setTimestamp(start_time); - - float xyz[] = {x, y, z}; - auto svec = event.initAcceleration(); - svec.setV(xyz); - svec.setStatus(true); - - return true; -} diff --git a/system/sensord/sensors/bmx055_accel.h b/system/sensord/sensors/bmx055_accel.h deleted file mode 100644 index 2cc316e9..00000000 --- a/system/sensord/sensors/bmx055_accel.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define BMX055_ACCEL_I2C_ADDR 0x18 - -// Registers of the chip -#define BMX055_ACCEL_I2C_REG_ID 0x00 -#define BMX055_ACCEL_I2C_REG_X_LSB 0x02 -#define BMX055_ACCEL_I2C_REG_TEMP 0x08 -#define BMX055_ACCEL_I2C_REG_BW 0x10 -#define BMX055_ACCEL_I2C_REG_PMU 0x11 -#define BMX055_ACCEL_I2C_REG_HBW 0x13 -#define BMX055_ACCEL_I2C_REG_FIFO 0x3F - -// Constants -#define BMX055_ACCEL_CHIP_ID 0xFA - -#define BMX055_ACCEL_HBW_ENABLE 0b10000000 -#define BMX055_ACCEL_HBW_DISABLE 0b00000000 -#define BMX055_ACCEL_DEEP_SUSPEND 0b00100000 -#define BMX055_ACCEL_NORMAL_MODE 0b00000000 - -#define BMX055_ACCEL_BW_7_81HZ 0b01000 -#define BMX055_ACCEL_BW_15_63HZ 0b01001 -#define BMX055_ACCEL_BW_31_25HZ 0b01010 -#define BMX055_ACCEL_BW_62_5HZ 0b01011 -#define BMX055_ACCEL_BW_125HZ 0b01100 -#define BMX055_ACCEL_BW_250HZ 0b01101 -#define BMX055_ACCEL_BW_500HZ 0b01110 -#define BMX055_ACCEL_BW_1000HZ 0b01111 - -class BMX055_Accel : public I2CSensor { - uint8_t 
get_device_address() {return BMX055_ACCEL_I2C_ADDR;} -public: - BMX055_Accel(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/bmx055_gyro.cc b/system/sensord/sensors/bmx055_gyro.cc deleted file mode 100644 index 0cc405f6..00000000 --- a/system/sensord/sensors/bmx055_gyro.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include "system/sensord/sensors/bmx055_gyro.h" - -#include -#include - -#include "common/swaglog.h" -#include "common/util.h" - -#define DEG2RAD(x) ((x) * M_PI / 180.0) - - -BMX055_Gyro::BMX055_Gyro(I2CBus *bus) : I2CSensor(bus) {} - -int BMX055_Gyro::init() { - int ret = verify_chip_id(BMX055_GYRO_I2C_REG_ID, {BMX055_GYRO_CHIP_ID}); - if (ret == -1) return -1; - - ret = set_register(BMX055_GYRO_I2C_REG_LPM1, BMX055_GYRO_NORMAL_MODE); - if (ret < 0) { - goto fail; - } - // bmx055 gyro has a 30ms wakeup time from deep suspend mode - util::sleep_for(50); - - // High bandwidth - // ret = set_register(BMX055_GYRO_I2C_REG_HBW, BMX055_GYRO_HBW_ENABLE); - // if (ret < 0) { - // goto fail; - // } - - // Low bandwidth - ret = set_register(BMX055_GYRO_I2C_REG_HBW, BMX055_GYRO_HBW_DISABLE); - if (ret < 0) { - goto fail; - } - - // 116 Hz filter - ret = set_register(BMX055_GYRO_I2C_REG_BW, BMX055_GYRO_BW_116HZ); - if (ret < 0) { - goto fail; - } - - // +- 125 deg/s range - ret = set_register(BMX055_GYRO_I2C_REG_RANGE, BMX055_GYRO_RANGE_125); - if (ret < 0) { - goto fail; - } - - enabled = true; - -fail: - return ret; -} - -int BMX055_Gyro::shutdown() { - if (!enabled) return 0; - - // enter deep suspend mode (lowest power mode) - int ret = set_register(BMX055_GYRO_I2C_REG_LPM1, BMX055_GYRO_DEEP_SUSPEND); - if (ret < 0) { - LOGE("Could not move BMX055 GYRO in deep suspend mode!"); - } - - return ret; -} - -bool BMX055_Gyro::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - uint8_t buffer[6]; - int len = read_register(BMX055_GYRO_I2C_REG_RATE_X_LSB, buffer, sizeof(buffer)); - assert(len == 6); - - // 16 bit = +- 125 deg/s - float scale = 125.0f / (1 << 15); - float x = -DEG2RAD(read_16_bit(buffer[0], buffer[1]) * scale); - float y = -DEG2RAD(read_16_bit(buffer[2], buffer[3]) * scale); - float z = DEG2RAD(read_16_bit(buffer[4], buffer[5]) * scale); - - auto event = msg.initEvent().initGyroscope2(); - event.setSource(cereal::SensorEventData::SensorSource::BMX055); - event.setVersion(1); - event.setSensor(SENSOR_GYRO_UNCALIBRATED); - event.setType(SENSOR_TYPE_GYROSCOPE_UNCALIBRATED); - event.setTimestamp(start_time); - - float xyz[] = {x, y, z}; - auto svec = event.initGyroUncalibrated(); - svec.setV(xyz); - svec.setStatus(true); - - return true; -} diff --git a/system/sensord/sensors/bmx055_gyro.h b/system/sensord/sensors/bmx055_gyro.h deleted file mode 100644 index 7be3e565..00000000 --- a/system/sensord/sensors/bmx055_gyro.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define BMX055_GYRO_I2C_ADDR 0x68 - -// Registers of the chip -#define BMX055_GYRO_I2C_REG_ID 0x00 -#define BMX055_GYRO_I2C_REG_RATE_X_LSB 0x02 -#define BMX055_GYRO_I2C_REG_RANGE 0x0F -#define BMX055_GYRO_I2C_REG_BW 0x10 -#define BMX055_GYRO_I2C_REG_LPM1 0x11 -#define BMX055_GYRO_I2C_REG_HBW 0x13 -#define BMX055_GYRO_I2C_REG_FIFO 0x3F - -// Constants -#define BMX055_GYRO_CHIP_ID 0x0F - -#define BMX055_GYRO_HBW_ENABLE 0b10000000 -#define BMX055_GYRO_HBW_DISABLE 0b00000000 -#define BMX055_GYRO_DEEP_SUSPEND 0b00100000 
-#define BMX055_GYRO_NORMAL_MODE 0b00000000 - -#define BMX055_GYRO_RANGE_2000 0b000 -#define BMX055_GYRO_RANGE_1000 0b001 -#define BMX055_GYRO_RANGE_500 0b010 -#define BMX055_GYRO_RANGE_250 0b011 -#define BMX055_GYRO_RANGE_125 0b100 - -#define BMX055_GYRO_BW_116HZ 0b0010 - - -class BMX055_Gyro : public I2CSensor { - uint8_t get_device_address() {return BMX055_GYRO_I2C_ADDR;} -public: - BMX055_Gyro(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/bmx055_magn.cc b/system/sensord/sensors/bmx055_magn.cc deleted file mode 100644 index b498c5fe..00000000 --- a/system/sensord/sensors/bmx055_magn.cc +++ /dev/null @@ -1,258 +0,0 @@ -#include "system/sensord/sensors/bmx055_magn.h" - -#include - -#include -#include -#include - -#include "common/swaglog.h" -#include "common/util.h" - -static int16_t compensate_x(trim_data_t trim_data, int16_t mag_data_x, uint16_t data_rhall) { - uint16_t process_comp_x0 = data_rhall; - int32_t process_comp_x1 = ((int32_t)trim_data.dig_xyz1) * 16384; - uint16_t process_comp_x2 = ((uint16_t)(process_comp_x1 / process_comp_x0)) - ((uint16_t)0x4000); - int16_t retval = ((int16_t)process_comp_x2); - int32_t process_comp_x3 = (((int32_t)retval) * ((int32_t)retval)); - int32_t process_comp_x4 = (((int32_t)trim_data.dig_xy2) * (process_comp_x3 / 128)); - int32_t process_comp_x5 = (int32_t)(((int16_t)trim_data.dig_xy1) * 128); - int32_t process_comp_x6 = ((int32_t)retval) * process_comp_x5; - int32_t process_comp_x7 = (((process_comp_x4 + process_comp_x6) / 512) + ((int32_t)0x100000)); - int32_t process_comp_x8 = ((int32_t)(((int16_t)trim_data.dig_x2) + ((int16_t)0xA0))); - int32_t process_comp_x9 = ((process_comp_x7 * process_comp_x8) / 4096); - int32_t process_comp_x10 = ((int32_t)mag_data_x) * process_comp_x9; - retval = ((int16_t)(process_comp_x10 / 8192)); - retval = (retval + (((int16_t)trim_data.dig_x1) * 8)) / 16; - - return retval; -} - -static int16_t compensate_y(trim_data_t trim_data, int16_t mag_data_y, uint16_t data_rhall) { - uint16_t process_comp_y0 = trim_data.dig_xyz1; - int32_t process_comp_y1 = (((int32_t)trim_data.dig_xyz1) * 16384) / process_comp_y0; - uint16_t process_comp_y2 = ((uint16_t)process_comp_y1) - ((uint16_t)0x4000); - int16_t retval = ((int16_t)process_comp_y2); - int32_t process_comp_y3 = ((int32_t) retval) * ((int32_t)retval); - int32_t process_comp_y4 = ((int32_t)trim_data.dig_xy2) * (process_comp_y3 / 128); - int32_t process_comp_y5 = ((int32_t)(((int16_t)trim_data.dig_xy1) * 128)); - int32_t process_comp_y6 = ((process_comp_y4 + (((int32_t)retval) * process_comp_y5)) / 512); - int32_t process_comp_y7 = ((int32_t)(((int16_t)trim_data.dig_y2) + ((int16_t)0xA0))); - int32_t process_comp_y8 = (((process_comp_y6 + ((int32_t)0x100000)) * process_comp_y7) / 4096); - int32_t process_comp_y9 = (((int32_t)mag_data_y) * process_comp_y8); - retval = (int16_t)(process_comp_y9 / 8192); - retval = (retval + (((int16_t)trim_data.dig_y1) * 8)) / 16; - - return retval; -} - -static int16_t compensate_z(trim_data_t trim_data, int16_t mag_data_z, uint16_t data_rhall) { - int16_t process_comp_z0 = ((int16_t)data_rhall) - ((int16_t) trim_data.dig_xyz1); - int32_t process_comp_z1 = (((int32_t)trim_data.dig_z3) * ((int32_t)(process_comp_z0))) / 4; - int32_t process_comp_z2 = (((int32_t)(mag_data_z - trim_data.dig_z4)) * 32768); - int32_t process_comp_z3 = ((int32_t)trim_data.dig_z1) * (((int16_t)data_rhall) * 2); - int16_t process_comp_z4 = (int16_t)((process_comp_z3 + 
(32768)) / 65536); - int32_t retval = ((process_comp_z2 - process_comp_z1) / (trim_data.dig_z2 + process_comp_z4)); - - /* saturate result to +/- 2 micro-tesla */ - retval = std::clamp(retval, -32767, 32767); - - /* Conversion of LSB to micro-tesla*/ - retval = retval / 16; - - return (int16_t)retval; -} - -BMX055_Magn::BMX055_Magn(I2CBus *bus) : I2CSensor(bus) {} - -int BMX055_Magn::init() { - uint8_t trim_x1y1[2] = {0}; - uint8_t trim_x2y2[2] = {0}; - uint8_t trim_xy1xy2[2] = {0}; - uint8_t trim_z1[2] = {0}; - uint8_t trim_z2[2] = {0}; - uint8_t trim_z3[2] = {0}; - uint8_t trim_z4[2] = {0}; - uint8_t trim_xyz1[2] = {0}; - - // suspend -> sleep - int ret = set_register(BMX055_MAGN_I2C_REG_PWR_0, 0x01); - if (ret < 0) { - LOGD("Enabling power failed: %d", ret); - goto fail; - } - util::sleep_for(5); // wait until the chip is powered on - - ret = verify_chip_id(BMX055_MAGN_I2C_REG_ID, {BMX055_MAGN_CHIP_ID}); - if (ret == -1) { - goto fail; - } - - // Load magnetometer trim - ret = read_register(BMX055_MAGN_I2C_REG_DIG_X1, trim_x1y1, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_X2, trim_x2y2, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_XY2, trim_xy1xy2, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_Z1_LSB, trim_z1, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_Z2_LSB, trim_z2, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_Z3_LSB, trim_z3, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_Z4_LSB, trim_z4, 2); - if (ret < 0) goto fail; - ret = read_register(BMX055_MAGN_I2C_REG_DIG_XYZ1_LSB, trim_xyz1, 2); - if (ret < 0) goto fail; - - // Read trim data - trim_data.dig_x1 = trim_x1y1[0]; - trim_data.dig_y1 = trim_x1y1[1]; - - trim_data.dig_x2 = trim_x2y2[0]; - trim_data.dig_y2 = trim_x2y2[1]; - - trim_data.dig_xy1 = trim_xy1xy2[1]; // NB: MSB/LSB swapped - trim_data.dig_xy2 = trim_xy1xy2[0]; - - trim_data.dig_z1 = read_16_bit(trim_z1[0], trim_z1[1]); - trim_data.dig_z2 = read_16_bit(trim_z2[0], trim_z2[1]); - trim_data.dig_z3 = read_16_bit(trim_z3[0], trim_z3[1]); - trim_data.dig_z4 = read_16_bit(trim_z4[0], trim_z4[1]); - - trim_data.dig_xyz1 = read_16_bit(trim_xyz1[0], trim_xyz1[1] & 0x7f); - assert(trim_data.dig_xyz1 != 0); - - perform_self_test(); - - // f_max = 1 / (145us * nXY + 500us * NZ + 980us) - // Chose NXY = 7, NZ = 12, which gives 125 Hz, - // and has the same ratio as the high accuracy preset - ret = set_register(BMX055_MAGN_I2C_REG_REPXY, (7 - 1) / 2); - if (ret < 0) { - goto fail; - } - - ret = set_register(BMX055_MAGN_I2C_REG_REPZ, 12 - 1); - if (ret < 0) { - goto fail; - } - - enabled = true; - return 0; - - fail: - return ret; -} - -int BMX055_Magn::shutdown() { - if (!enabled) return 0; - - // move to suspend mode - int ret = set_register(BMX055_MAGN_I2C_REG_PWR_0, 0); - if (ret < 0) { - LOGE("Could not move BMX055 MAGN in suspend mode!"); - } - - return ret; -} - -bool BMX055_Magn::perform_self_test() { - uint8_t buffer[8]; - int16_t x, y; - int16_t neg_z, pos_z; - - // Increase z reps for less false positives (~30 Hz ODR) - set_register(BMX055_MAGN_I2C_REG_REPXY, 1); - set_register(BMX055_MAGN_I2C_REG_REPZ, 64 - 1); - - // Clean existing measurement - read_register(BMX055_MAGN_I2C_REG_DATAX_LSB, buffer, sizeof(buffer)); - - uint8_t forced = BMX055_MAGN_FORCED; - - // Negative current - set_register(BMX055_MAGN_I2C_REG_MAG, forced | (uint8_t(0b10) << 6)); - util::sleep_for(100); - - 
read_register(BMX055_MAGN_I2C_REG_DATAX_LSB, buffer, sizeof(buffer)); - parse_xyz(buffer, &x, &y, &neg_z); - - // Positive current - set_register(BMX055_MAGN_I2C_REG_MAG, forced | (uint8_t(0b11) << 6)); - util::sleep_for(100); - - read_register(BMX055_MAGN_I2C_REG_DATAX_LSB, buffer, sizeof(buffer)); - parse_xyz(buffer, &x, &y, &pos_z); - - // Put back in normal mode - set_register(BMX055_MAGN_I2C_REG_MAG, 0); - - int16_t diff = pos_z - neg_z; - bool passed = (diff > 180) && (diff < 240); - - if (!passed) { - LOGE("self test failed: neg %d pos %d diff %d", neg_z, pos_z, diff); - } - - return passed; -} - -bool BMX055_Magn::parse_xyz(uint8_t buffer[8], int16_t *x, int16_t *y, int16_t *z) { - bool ready = buffer[6] & 0x1; - if (ready) { - int16_t mdata_x = (int16_t) (((int16_t)buffer[1] << 8) | buffer[0]) >> 3; - int16_t mdata_y = (int16_t) (((int16_t)buffer[3] << 8) | buffer[2]) >> 3; - int16_t mdata_z = (int16_t) (((int16_t)buffer[5] << 8) | buffer[4]) >> 1; - uint16_t data_r = (uint16_t) (((uint16_t)buffer[7] << 8) | buffer[6]) >> 2; - assert(data_r != 0); - - *x = compensate_x(trim_data, mdata_x, data_r); - *y = compensate_y(trim_data, mdata_y, data_r); - *z = compensate_z(trim_data, mdata_z, data_r); - } - return ready; -} - - -bool BMX055_Magn::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - uint8_t buffer[8]; - int16_t _x, _y, x, y, z; - - int len = read_register(BMX055_MAGN_I2C_REG_DATAX_LSB, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - bool parsed = parse_xyz(buffer, &_x, &_y, &z); - if (parsed) { - - auto event = msg.initEvent().initMagnetometer(); - event.setSource(cereal::SensorEventData::SensorSource::BMX055); - event.setVersion(2); - event.setSensor(SENSOR_MAGNETOMETER_UNCALIBRATED); - event.setType(SENSOR_TYPE_MAGNETIC_FIELD_UNCALIBRATED); - event.setTimestamp(start_time); - - // Move magnetometer into same reference frame as accel/gryo - x = -_y; - y = _x; - - // Axis convention - x = -x; - y = -y; - - float xyz[] = {(float)x, (float)y, (float)z}; - auto svec = event.initMagneticUncalibrated(); - svec.setV(xyz); - svec.setStatus(true); - } - - // The BMX055 Magnetometer has no FIFO mode. Self running mode only goes - // up to 30 Hz. Therefore we put in forced mode, and request measurements - // at a 100 Hz. When reading the registers we have to check the ready bit - // To verify the measurement was completed this cycle. 
- set_register(BMX055_MAGN_I2C_REG_MAG, BMX055_MAGN_FORCED); - - return parsed; -} diff --git a/system/sensord/sensors/bmx055_magn.h b/system/sensord/sensors/bmx055_magn.h deleted file mode 100644 index 15c4e734..00000000 --- a/system/sensord/sensors/bmx055_magn.h +++ /dev/null @@ -1,64 +0,0 @@ -#pragma once -#include - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define BMX055_MAGN_I2C_ADDR 0x10 - -// Registers of the chip -#define BMX055_MAGN_I2C_REG_ID 0x40 -#define BMX055_MAGN_I2C_REG_PWR_0 0x4B -#define BMX055_MAGN_I2C_REG_MAG 0x4C -#define BMX055_MAGN_I2C_REG_DATAX_LSB 0x42 -#define BMX055_MAGN_I2C_REG_RHALL_LSB 0x48 -#define BMX055_MAGN_I2C_REG_REPXY 0x51 -#define BMX055_MAGN_I2C_REG_REPZ 0x52 - -#define BMX055_MAGN_I2C_REG_DIG_X1 0x5D -#define BMX055_MAGN_I2C_REG_DIG_Y1 0x5E -#define BMX055_MAGN_I2C_REG_DIG_Z4_LSB 0x62 -#define BMX055_MAGN_I2C_REG_DIG_Z4_MSB 0x63 -#define BMX055_MAGN_I2C_REG_DIG_X2 0x64 -#define BMX055_MAGN_I2C_REG_DIG_Y2 0x65 -#define BMX055_MAGN_I2C_REG_DIG_Z2_LSB 0x68 -#define BMX055_MAGN_I2C_REG_DIG_Z2_MSB 0x69 -#define BMX055_MAGN_I2C_REG_DIG_Z1_LSB 0x6A -#define BMX055_MAGN_I2C_REG_DIG_Z1_MSB 0x6B -#define BMX055_MAGN_I2C_REG_DIG_XYZ1_LSB 0x6C -#define BMX055_MAGN_I2C_REG_DIG_XYZ1_MSB 0x6D -#define BMX055_MAGN_I2C_REG_DIG_Z3_LSB 0x6E -#define BMX055_MAGN_I2C_REG_DIG_Z3_MSB 0x6F -#define BMX055_MAGN_I2C_REG_DIG_XY2 0x70 -#define BMX055_MAGN_I2C_REG_DIG_XY1 0x71 - -// Constants -#define BMX055_MAGN_CHIP_ID 0x32 -#define BMX055_MAGN_FORCED (0b01 << 1) - -struct trim_data_t { - int8_t dig_x1; - int8_t dig_y1; - int8_t dig_x2; - int8_t dig_y2; - uint16_t dig_z1; - int16_t dig_z2; - int16_t dig_z3; - int16_t dig_z4; - uint8_t dig_xy1; - int8_t dig_xy2; - uint16_t dig_xyz1; -}; - - -class BMX055_Magn : public I2CSensor{ - uint8_t get_device_address() {return BMX055_MAGN_I2C_ADDR;} - trim_data_t trim_data = {0}; - bool perform_self_test(); - bool parse_xyz(uint8_t buffer[8], int16_t *x, int16_t *y, int16_t *z); -public: - BMX055_Magn(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/bmx055_temp.cc b/system/sensord/sensors/bmx055_temp.cc deleted file mode 100644 index da7b8647..00000000 --- a/system/sensord/sensors/bmx055_temp.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "system/sensord/sensors/bmx055_temp.h" - -#include - -#include "system/sensord/sensors/bmx055_accel.h" -#include "common/swaglog.h" -#include "common/timing.h" - -BMX055_Temp::BMX055_Temp(I2CBus *bus) : I2CSensor(bus) {} - -int BMX055_Temp::init() { - return verify_chip_id(BMX055_ACCEL_I2C_REG_ID, {BMX055_ACCEL_CHIP_ID}) == -1 ? 
-1 : 0; -} - -bool BMX055_Temp::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - uint8_t buffer[1]; - int len = read_register(BMX055_ACCEL_I2C_REG_TEMP, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - float temp = 23.0f + int8_t(buffer[0]) / 2.0f; - - auto event = msg.initEvent().initTemperatureSensor(); - event.setSource(cereal::SensorEventData::SensorSource::BMX055); - event.setVersion(1); - event.setType(SENSOR_TYPE_AMBIENT_TEMPERATURE); - event.setTimestamp(start_time); - event.setTemperature(temp); - - return true; -} diff --git a/system/sensord/sensors/bmx055_temp.h b/system/sensord/sensors/bmx055_temp.h deleted file mode 100644 index a2eabae3..00000000 --- a/system/sensord/sensors/bmx055_temp.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/bmx055_accel.h" -#include "system/sensord/sensors/i2c_sensor.h" - -class BMX055_Temp : public I2CSensor { - uint8_t get_device_address() {return BMX055_ACCEL_I2C_ADDR;} -public: - BMX055_Temp(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown() { return 0; } -}; diff --git a/system/sensord/sensors/constants.h b/system/sensord/sensors/constants.h deleted file mode 100644 index c216f838..00000000 --- a/system/sensord/sensors/constants.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - - -#define SENSOR_ACCELEROMETER 1 -#define SENSOR_MAGNETOMETER 2 -#define SENSOR_MAGNETOMETER_UNCALIBRATED 3 -#define SENSOR_GYRO 4 -#define SENSOR_GYRO_UNCALIBRATED 5 -#define SENSOR_LIGHT 7 - -#define SENSOR_TYPE_ACCELEROMETER 1 -#define SENSOR_TYPE_GEOMAGNETIC_FIELD 2 -#define SENSOR_TYPE_GYROSCOPE 4 -#define SENSOR_TYPE_LIGHT 5 -#define SENSOR_TYPE_AMBIENT_TEMPERATURE 13 -#define SENSOR_TYPE_MAGNETIC_FIELD_UNCALIBRATED 14 -#define SENSOR_TYPE_MAGNETIC_FIELD SENSOR_TYPE_GEOMAGNETIC_FIELD -#define SENSOR_TYPE_GYROSCOPE_UNCALIBRATED 16 diff --git a/system/sensord/sensors/i2c_sensor.cc b/system/sensord/sensors/i2c_sensor.cc deleted file mode 100644 index 90220f55..00000000 --- a/system/sensord/sensors/i2c_sensor.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "system/sensord/sensors/i2c_sensor.h" - -int16_t read_12_bit(uint8_t lsb, uint8_t msb) { - uint16_t combined = (uint16_t(msb) << 8) | uint16_t(lsb & 0xF0); - return int16_t(combined) / (1 << 4); -} - -int16_t read_16_bit(uint8_t lsb, uint8_t msb) { - uint16_t combined = (uint16_t(msb) << 8) | uint16_t(lsb); - return int16_t(combined); -} - -int32_t read_20_bit(uint8_t b2, uint8_t b1, uint8_t b0) { - uint32_t combined = (uint32_t(b0) << 16) | (uint32_t(b1) << 8) | uint32_t(b2); - return int32_t(combined) / (1 << 4); -} - -I2CSensor::I2CSensor(I2CBus *bus, int gpio_nr, bool shared_gpio) : - bus(bus), gpio_nr(gpio_nr), shared_gpio(shared_gpio) {} - -I2CSensor::~I2CSensor() { - if (gpio_fd != -1) { - close(gpio_fd); - } -} - -int I2CSensor::read_register(uint register_address, uint8_t *buffer, uint8_t len) { - return bus->read_register(get_device_address(), register_address, buffer, len); -} - -int I2CSensor::set_register(uint register_address, uint8_t data) { - return bus->set_register(get_device_address(), register_address, data); -} - -int I2CSensor::init_gpio() { - if (shared_gpio || gpio_nr == 0) { - return 0; - } - - gpio_fd = gpiochip_get_ro_value_fd("sensord", GPIOCHIP_INT, gpio_nr); - if (gpio_fd < 0) { - return -1; - } - - return 0; -} - -bool I2CSensor::has_interrupt_enabled() { - return gpio_nr != 0; -} diff --git a/system/sensord/sensors/i2c_sensor.h 
b/system/sensord/sensors/i2c_sensor.h deleted file mode 100644 index e6d328ce..00000000 --- a/system/sensord/sensors/i2c_sensor.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include -#include -#include -#include "cereal/gen/cpp/log.capnp.h" - -#include "common/i2c.h" -#include "common/gpio.h" - -#include "common/swaglog.h" -#include "system/sensord/sensors/constants.h" -#include "system/sensord/sensors/sensor.h" - -int16_t read_12_bit(uint8_t lsb, uint8_t msb); -int16_t read_16_bit(uint8_t lsb, uint8_t msb); -int32_t read_20_bit(uint8_t b2, uint8_t b1, uint8_t b0); - - -class I2CSensor : public Sensor { -private: - I2CBus *bus; - int gpio_nr; - bool shared_gpio; - virtual uint8_t get_device_address() = 0; - -public: - I2CSensor(I2CBus *bus, int gpio_nr = 0, bool shared_gpio = false); - ~I2CSensor(); - int read_register(uint register_address, uint8_t *buffer, uint8_t len); - int set_register(uint register_address, uint8_t data); - int init_gpio(); - bool has_interrupt_enabled(); - virtual int init() = 0; - virtual bool get_event(MessageBuilder &msg, uint64_t ts = 0) = 0; - virtual int shutdown() = 0; - - int verify_chip_id(uint8_t address, const std::vector &expected_ids) { - uint8_t chip_id = 0; - int ret = read_register(address, &chip_id, 1); - if (ret < 0) { - LOGD("Reading chip ID failed: %d", ret); - return -1; - } - for (int i = 0; i < expected_ids.size(); ++i) { - if (chip_id == expected_ids[i]) return chip_id; - } - LOGE("Chip ID wrong. Got: %d, Expected %d", chip_id, expected_ids[0]); - return -1; - } -}; diff --git a/system/sensord/sensors/i2c_sensor.py b/system/sensord/sensors/i2c_sensor.py new file mode 100644 index 00000000..336ebb1f --- /dev/null +++ b/system/sensord/sensors/i2c_sensor.py @@ -0,0 +1,77 @@ +import time +import smbus2 +import ctypes +from collections.abc import Iterable + +from cereal import log + +class Sensor: + class SensorException(Exception): + pass + + class DataNotReady(SensorException): + pass + + def __init__(self, bus: int) -> None: + self.bus = smbus2.SMBus(bus) + self.source = log.SensorEventData.SensorSource.velodyne # unknown + self.start_ts = 0. + + def __del__(self): + self.bus.close() + + def read(self, addr: int, length: int) -> bytes: + return bytes(self.bus.read_i2c_block_data(self.device_address, addr, length)) + + def write(self, addr: int, data: int) -> None: + self.bus.write_byte_data(self.device_address, addr, data) + + def writes(self, writes: Iterable[tuple[int, int]]) -> None: + for addr, data in writes: + self.write(addr, data) + + def verify_chip_id(self, address: int, expected_ids: list[int]) -> int: + chip_id = self.read(address, 1)[0] + assert chip_id in expected_ids + return chip_id + + # Abstract methods that must be implemented by subclasses + @property + def device_address(self) -> int: + raise NotImplementedError + + def reset(self) -> None: + # optional. + # not part of init due to shared registers + pass + + def init(self) -> None: + raise NotImplementedError + + def get_event(self, ts: int | None = None) -> log.SensorEventData: + raise NotImplementedError + + def shutdown(self) -> None: + raise NotImplementedError + + def is_data_valid(self) -> bool: + if self.start_ts == 0: + self.start_ts = time.monotonic() + + # unclear whether we need this... 
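    # mirrors the C++ Sensor::is_data_valid(), which gated publishing behind a
    # 500 ms init_delay; here the warm-up window is the 0.5 s check just below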
+ return (time.monotonic() - self.start_ts) > 0.5 + + # *** helpers *** + @staticmethod + def wait(): + # a standard small sleep + time.sleep(0.005) + + @staticmethod + def parse_16bit(lsb: int, msb: int) -> int: + return ctypes.c_int16((msb << 8) | lsb).value + + @staticmethod + def parse_20bit(b2: int, b1: int, b0: int) -> int: + combined = ctypes.c_uint32((b0 << 16) | (b1 << 8) | b2).value + return ctypes.c_int32(combined).value // (1 << 4) diff --git a/system/sensord/sensors/lsm6ds3_accel.cc b/system/sensord/sensors/lsm6ds3_accel.cc deleted file mode 100644 index 03533e06..00000000 --- a/system/sensord/sensors/lsm6ds3_accel.cc +++ /dev/null @@ -1,250 +0,0 @@ -#include "system/sensord/sensors/lsm6ds3_accel.h" - -#include -#include -#include - -#include "common/swaglog.h" -#include "common/timing.h" -#include "common/util.h" - -LSM6DS3_Accel::LSM6DS3_Accel(I2CBus *bus, int gpio_nr, bool shared_gpio) : - I2CSensor(bus, gpio_nr, shared_gpio) {} - -void LSM6DS3_Accel::wait_for_data_ready() { - uint8_t drdy = 0; - uint8_t buffer[6]; - - do { - read_register(LSM6DS3_ACCEL_I2C_REG_STAT_REG, &drdy, sizeof(drdy)); - drdy &= LSM6DS3_ACCEL_DRDY_XLDA; - } while (drdy == 0); - - read_register(LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL, buffer, sizeof(buffer)); -} - -void LSM6DS3_Accel::read_and_avg_data(float* out_buf) { - uint8_t drdy = 0; - uint8_t buffer[6]; - - float scaling = 0.061f; - if (source == cereal::SensorEventData::SensorSource::LSM6DS3TRC) { - scaling = 0.122f; - } - - for (int i = 0; i < 5; i++) { - do { - read_register(LSM6DS3_ACCEL_I2C_REG_STAT_REG, &drdy, sizeof(drdy)); - drdy &= LSM6DS3_ACCEL_DRDY_XLDA; - } while (drdy == 0); - - int len = read_register(LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - for (int j = 0; j < 3; j++) { - out_buf[j] += (float)read_16_bit(buffer[j*2], buffer[j*2+1]) * scaling; - } - } - - for (int i = 0; i < 3; i++) { - out_buf[i] /= 5.0f; - } -} - -int LSM6DS3_Accel::self_test(int test_type) { - float val_st_off[3] = {0}; - float val_st_on[3] = {0}; - float test_val[3] = {0}; - uint8_t ODR_FS_MO = LSM6DS3_ACCEL_ODR_52HZ; // full scale: +-2g, ODR: 52Hz - - // prepare sensor for self-test - - // enable block data update and automatic increment - int ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL3_C, LSM6DS3_ACCEL_IF_INC_BDU); - if (ret < 0) { - return ret; - } - - if (source == cereal::SensorEventData::SensorSource::LSM6DS3TRC) { - ODR_FS_MO = LSM6DS3_ACCEL_FS_4G | LSM6DS3_ACCEL_ODR_52HZ; - } - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, ODR_FS_MO); - if (ret < 0) { - return ret; - } - - // wait for stable output, and discard first values - util::sleep_for(100); - wait_for_data_ready(); - read_and_avg_data(val_st_off); - - // enable Self Test positive (or negative) - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL5_C, test_type); - if (ret < 0) { - return ret; - } - - // wait for stable output, and discard first values - util::sleep_for(100); - wait_for_data_ready(); - read_and_avg_data(val_st_on); - - // disable sensor - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, 0); - if (ret < 0) { - return ret; - } - - // disable self test - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL5_C, 0); - if (ret < 0) { - return ret; - } - - // calculate the mg values for self test - for (int i = 0; i < 3; i++) { - test_val[i] = fabs(val_st_on[i] - val_st_off[i]); - } - - // verify test result - for (int i = 0; i < 3; i++) { - if ((LSM6DS3_ACCEL_MIN_ST_LIMIT_mg > test_val[i]) || - (test_val[i] > LSM6DS3_ACCEL_MAX_ST_LIMIT_mg)) { 
- return -1; - } - } - - return ret; -} - -int LSM6DS3_Accel::init() { - uint8_t value = 0; - bool do_self_test = false; - - const char* env_lsm_selftest = std::getenv("LSM_SELF_TEST"); - if (env_lsm_selftest != nullptr && strncmp(env_lsm_selftest, "1", 1) == 0) { - do_self_test = true; - } - - int ret = verify_chip_id(LSM6DS3_ACCEL_I2C_REG_ID, {LSM6DS3_ACCEL_CHIP_ID, LSM6DS3TRC_ACCEL_CHIP_ID}); - if (ret == -1) return -1; - - if (ret == LSM6DS3TRC_ACCEL_CHIP_ID) { - source = cereal::SensorEventData::SensorSource::LSM6DS3TRC; - } - - ret = self_test(LSM6DS3_ACCEL_POSITIVE_TEST); - if (ret < 0) { - LOGE("LSM6DS3 accel positive self-test failed!"); - if (do_self_test) goto fail; - } - - ret = self_test(LSM6DS3_ACCEL_NEGATIVE_TEST); - if (ret < 0) { - LOGE("LSM6DS3 accel negative self-test failed!"); - if (do_self_test) goto fail; - } - - ret = init_gpio(); - if (ret < 0) { - goto fail; - } - - // enable continuous update, and automatic increase - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL3_C, LSM6DS3_ACCEL_IF_INC); - if (ret < 0) { - goto fail; - } - - // TODO: set scale and bandwidth. Default is +- 2G, 50 Hz - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, LSM6DS3_ACCEL_ODR_104HZ); - if (ret < 0) { - goto fail; - } - - ret = set_register(LSM6DS3_ACCEL_I2C_REG_DRDY_CFG, LSM6DS3_ACCEL_DRDY_PULSE_MODE); - if (ret < 0) { - goto fail; - } - - // enable data ready interrupt for accel on INT1 - // (without resetting existing interrupts) - ret = read_register(LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, &value, 1); - if (ret < 0) { - goto fail; - } - - value |= LSM6DS3_ACCEL_INT1_DRDY_XL; - ret = set_register(LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, value); - -fail: - return ret; -} - -int LSM6DS3_Accel::shutdown() { - int ret = 0; - - // disable data ready interrupt for accel on INT1 - uint8_t value = 0; - ret = read_register(LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, &value, 1); - if (ret < 0) { - goto fail; - } - - value &= ~(LSM6DS3_ACCEL_INT1_DRDY_XL); - ret = set_register(LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, value); - if (ret < 0) { - LOGE("Could not disable lsm6ds3 acceleration interrupt!"); - goto fail; - } - - // enable power-down mode - value = 0; - ret = read_register(LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, &value, 1); - if (ret < 0) { - goto fail; - } - - value &= 0x0F; - ret = set_register(LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, value); - if (ret < 0) { - LOGE("Could not power-down lsm6ds3 accelerometer!"); - goto fail; - } - -fail: - return ret; -} - -bool LSM6DS3_Accel::get_event(MessageBuilder &msg, uint64_t ts) { - - // INT1 shared with gyro, check STATUS_REG who triggered - uint8_t status_reg = 0; - read_register(LSM6DS3_ACCEL_I2C_REG_STAT_REG, &status_reg, sizeof(status_reg)); - if ((status_reg & LSM6DS3_ACCEL_DRDY_XLDA) == 0) { - return false; - } - - uint8_t buffer[6]; - int len = read_register(LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - float scale = 9.81 * 2.0f / (1 << 15); - float x = read_16_bit(buffer[0], buffer[1]) * scale; - float y = read_16_bit(buffer[2], buffer[3]) * scale; - float z = read_16_bit(buffer[4], buffer[5]) * scale; - - auto event = msg.initEvent().initAccelerometer(); - event.setSource(source); - event.setVersion(1); - event.setSensor(SENSOR_ACCELEROMETER); - event.setType(SENSOR_TYPE_ACCELEROMETER); - event.setTimestamp(ts); - - float xyz[] = {y, -x, z}; - auto svec = event.initAcceleration(); - svec.setV(xyz); - svec.setStatus(true); - - return true; -} diff --git a/system/sensord/sensors/lsm6ds3_accel.h b/system/sensord/sensors/lsm6ds3_accel.h 
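The scaling in the deleted get_event() above leans on read_16_bit() to sign-extend raw register byte pairs; the new Python port does the same job with ctypes in Sensor.parse_16bit(). A quick standalone check of that helper (hand-picked byte values, no hardware needed):

import ctypes

def parse_16bit(lsb: int, msb: int) -> int:
    # same as Sensor.parse_16bit(): reinterpret the 16 raw bits as a signed int
    return ctypes.c_int16((msb << 8) | lsb).value

assert parse_16bit(0x00, 0x00) == 0
assert parse_16bit(0xFF, 0x7F) == 32767   # largest positive sample
assert parse_16bit(0x00, 0x80) == -32768  # sign bit alone -> most negative
assert parse_16bit(0xFF, 0xFF) == -1      # all bits set -> -1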
deleted file mode 100644 index 69667cb7..00000000 --- a/system/sensord/sensors/lsm6ds3_accel.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define LSM6DS3_ACCEL_I2C_ADDR 0x6A - -// Registers of the chip -#define LSM6DS3_ACCEL_I2C_REG_DRDY_CFG 0x0B -#define LSM6DS3_ACCEL_I2C_REG_ID 0x0F -#define LSM6DS3_ACCEL_I2C_REG_INT1_CTRL 0x0D -#define LSM6DS3_ACCEL_I2C_REG_CTRL1_XL 0x10 -#define LSM6DS3_ACCEL_I2C_REG_CTRL3_C 0x12 -#define LSM6DS3_ACCEL_I2C_REG_CTRL5_C 0x14 -#define LSM6DS3_ACCEL_I2C_REG_CTR9_XL 0x18 -#define LSM6DS3_ACCEL_I2C_REG_STAT_REG 0x1E -#define LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL 0x28 - -// Constants -#define LSM6DS3_ACCEL_CHIP_ID 0x69 -#define LSM6DS3TRC_ACCEL_CHIP_ID 0x6A -#define LSM6DS3_ACCEL_FS_4G (0b10 << 2) -#define LSM6DS3_ACCEL_ODR_52HZ (0b0011 << 4) -#define LSM6DS3_ACCEL_ODR_104HZ (0b0100 << 4) -#define LSM6DS3_ACCEL_INT1_DRDY_XL 0b1 -#define LSM6DS3_ACCEL_DRDY_XLDA 0b1 -#define LSM6DS3_ACCEL_DRDY_PULSE_MODE (1 << 7) -#define LSM6DS3_ACCEL_IF_INC 0b00000100 -#define LSM6DS3_ACCEL_IF_INC_BDU 0b01000100 -#define LSM6DS3_ACCEL_XYZ_DEN 0b11100000 -#define LSM6DS3_ACCEL_POSITIVE_TEST 0b01 -#define LSM6DS3_ACCEL_NEGATIVE_TEST 0b10 -#define LSM6DS3_ACCEL_MIN_ST_LIMIT_mg 90.0f -#define LSM6DS3_ACCEL_MAX_ST_LIMIT_mg 1700.0f - -class LSM6DS3_Accel : public I2CSensor { - uint8_t get_device_address() {return LSM6DS3_ACCEL_I2C_ADDR;} - cereal::SensorEventData::SensorSource source = cereal::SensorEventData::SensorSource::LSM6DS3; - - // self test functions - int self_test(int test_type); - void wait_for_data_ready(); - void read_and_avg_data(float* val_st_off); -public: - LSM6DS3_Accel(I2CBus *bus, int gpio_nr = 0, bool shared_gpio = false); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/lsm6ds3_accel.py b/system/sensord/sensors/lsm6ds3_accel.py new file mode 100644 index 00000000..43863daa --- /dev/null +++ b/system/sensord/sensors/lsm6ds3_accel.py @@ -0,0 +1,161 @@ +import os +import time + +from cereal import log +from openpilot.system.sensord.sensors.i2c_sensor import Sensor + +class LSM6DS3_Accel(Sensor): + LSM6DS3_ACCEL_I2C_REG_DRDY_CFG = 0x0B + LSM6DS3_ACCEL_I2C_REG_INT1_CTRL = 0x0D + LSM6DS3_ACCEL_I2C_REG_CTRL1_XL = 0x10 + LSM6DS3_ACCEL_I2C_REG_CTRL3_C = 0x12 + LSM6DS3_ACCEL_I2C_REG_CTRL5_C = 0x14 + LSM6DS3_ACCEL_I2C_REG_STAT_REG = 0x1E + LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL = 0x28 + + LSM6DS3_ACCEL_ODR_104HZ = (0b0100 << 4) + LSM6DS3_ACCEL_INT1_DRDY_XL = 0b1 + LSM6DS3_ACCEL_DRDY_XLDA = 0b1 + LSM6DS3_ACCEL_DRDY_PULSE_MODE = (1 << 7) + LSM6DS3_ACCEL_IF_INC = 0b00000100 + + LSM6DS3_ACCEL_ODR_52HZ = (0b0011 << 4) + LSM6DS3_ACCEL_FS_4G = (0b10 << 2) + LSM6DS3_ACCEL_IF_INC_BDU = 0b01000100 + LSM6DS3_ACCEL_POSITIVE_TEST = 0b01 + LSM6DS3_ACCEL_NEGATIVE_TEST = 0b10 + LSM6DS3_ACCEL_MIN_ST_LIMIT_mg = 90.0 + LSM6DS3_ACCEL_MAX_ST_LIMIT_mg = 1700.0 + + @property + def device_address(self) -> int: + return 0x6A + + def reset(self): + self.write(0x12, 0x1) + time.sleep(0.1) + + def init(self): + chip_id = self.verify_chip_id(0x0F, [0x69, 0x6A]) + if chip_id == 0x6A: + self.source = log.SensorEventData.SensorSource.lsm6ds3trc + else: + self.source = log.SensorEventData.SensorSource.lsm6ds3 + + # self-test + if os.getenv("LSM_SELF_TEST") == "1": + self.self_test(self.LSM6DS3_ACCEL_POSITIVE_TEST) + self.self_test(self.LSM6DS3_ACCEL_NEGATIVE_TEST) + + # actual init + int1 = self.read(self.LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, 1)[0] + 
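    # read-modify-write: OR the accel data-ready bit into INT1_CTRL so interrupt
    # sources already enabled on the shared INT1 line (e.g. the gyro's data-ready
    # bit) survive this init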
int1 |= self.LSM6DS3_ACCEL_INT1_DRDY_XL + self.writes(( + # Enable continuous update and automatic address increment + (self.LSM6DS3_ACCEL_I2C_REG_CTRL3_C, self.LSM6DS3_ACCEL_IF_INC), + # Set ODR to 104 Hz, FS to ±2g (default) + (self.LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, self.LSM6DS3_ACCEL_ODR_104HZ), + # Configure data ready signal to pulse mode + (self.LSM6DS3_ACCEL_I2C_REG_DRDY_CFG, self.LSM6DS3_ACCEL_DRDY_PULSE_MODE), + # Enable data ready interrupt on INT1 without resetting existing interrupts + (self.LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, int1), + )) + + def get_event(self, ts: int | None = None) -> log.SensorEventData: + assert ts is not None # must come from the IRQ event + + # Check if data is ready since IRQ is shared with gyro + status_reg = self.read(self.LSM6DS3_ACCEL_I2C_REG_STAT_REG, 1)[0] + if (status_reg & self.LSM6DS3_ACCEL_DRDY_XLDA) == 0: + raise self.DataNotReady + + scale = 9.81 * 2.0 / (1 << 15) + b = self.read(self.LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL, 6) + x = self.parse_16bit(b[0], b[1]) * scale + y = self.parse_16bit(b[2], b[3]) * scale + z = self.parse_16bit(b[4], b[5]) * scale + + event = log.SensorEventData.new_message() + event.timestamp = ts + event.version = 1 + event.sensor = 1 # SENSOR_ACCELEROMETER + event.type = 1 # SENSOR_TYPE_ACCELEROMETER + event.source = self.source + a = event.init('acceleration') + a.v = [y, -x, z] + a.status = 1 + return event + + def shutdown(self) -> None: + # Disable data ready interrupt on INT1 + value = self.read(self.LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, 1)[0] + value &= ~self.LSM6DS3_ACCEL_INT1_DRDY_XL + self.write(self.LSM6DS3_ACCEL_I2C_REG_INT1_CTRL, value) + + # Power down by clearing ODR bits + value = self.read(self.LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, 1)[0] + value &= 0x0F + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, value) + + # *** self-test stuff *** + def _wait_for_data_ready(self): + while True: + drdy = self.read(self.LSM6DS3_ACCEL_I2C_REG_STAT_REG, 1)[0] + if drdy & self.LSM6DS3_ACCEL_DRDY_XLDA: + break + + def _read_and_avg_data(self, scaling: float) -> list[float]: + out_buf = [0.0, 0.0, 0.0] + for _ in range(5): + self._wait_for_data_ready() + b = self.read(self.LSM6DS3_ACCEL_I2C_REG_OUTX_L_XL, 6) + for j in range(3): + val = self.parse_16bit(b[j*2], b[j*2+1]) * scaling + out_buf[j] += val + return [x / 5.0 for x in out_buf] + + def self_test(self, test_type: int) -> None: + # Prepare sensor for self-test + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL3_C, self.LSM6DS3_ACCEL_IF_INC_BDU) + + # Configure ODR and full scale based on sensor type + if self.source == log.SensorEventData.SensorSource.lsm6ds3trc: + odr_fs = self.LSM6DS3_ACCEL_FS_4G | self.LSM6DS3_ACCEL_ODR_52HZ + scaling = 0.122 # mg/LSB for ±4g + else: + odr_fs = self.LSM6DS3_ACCEL_ODR_52HZ + scaling = 0.061 # mg/LSB for ±2g + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, odr_fs) + + # Wait for stable output + time.sleep(0.1) + self._wait_for_data_ready() + val_st_off = self._read_and_avg_data(scaling) + + # Enable self-test + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL5_C, test_type) + + # Wait for stable output + time.sleep(0.1) + self._wait_for_data_ready() + val_st_on = self._read_and_avg_data(scaling) + + # Disable sensor and self-test + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL1_XL, 0) + self.write(self.LSM6DS3_ACCEL_I2C_REG_CTRL5_C, 0) + + # Calculate differences and check limits + test_val = [abs(on - off) for on, off in zip(val_st_on, val_st_off, strict=False)] + for val in test_val: + if val < self.LSM6DS3_ACCEL_MIN_ST_LIMIT_mg or val > 
self.LSM6DS3_ACCEL_MAX_ST_LIMIT_mg: + raise self.SensorException(f"Accelerometer self-test failed for test type {test_type}") + +if __name__ == "__main__": + import numpy as np + s = LSM6DS3_Accel(1) + s.init() + time.sleep(0.2) + e = s.get_event(0) + print(e) + print(np.linalg.norm(e.acceleration.v)) + s.shutdown() diff --git a/system/sensord/sensors/lsm6ds3_gyro.cc b/system/sensord/sensors/lsm6ds3_gyro.cc deleted file mode 100644 index bb560ede..00000000 --- a/system/sensord/sensors/lsm6ds3_gyro.cc +++ /dev/null @@ -1,233 +0,0 @@ -#include "system/sensord/sensors/lsm6ds3_gyro.h" - -#include -#include -#include - -#include "common/swaglog.h" -#include "common/timing.h" -#include "common/util.h" - -#define DEG2RAD(x) ((x) * M_PI / 180.0) - -LSM6DS3_Gyro::LSM6DS3_Gyro(I2CBus *bus, int gpio_nr, bool shared_gpio) : - I2CSensor(bus, gpio_nr, shared_gpio) {} - -void LSM6DS3_Gyro::wait_for_data_ready() { - uint8_t drdy = 0; - uint8_t buffer[6]; - - do { - read_register(LSM6DS3_GYRO_I2C_REG_STAT_REG, &drdy, sizeof(drdy)); - drdy &= LSM6DS3_GYRO_DRDY_GDA; - } while (drdy == 0); - - read_register(LSM6DS3_GYRO_I2C_REG_OUTX_L_G, buffer, sizeof(buffer)); -} - -void LSM6DS3_Gyro::read_and_avg_data(float* out_buf) { - uint8_t drdy = 0; - uint8_t buffer[6]; - - for (int i = 0; i < 5; i++) { - do { - read_register(LSM6DS3_GYRO_I2C_REG_STAT_REG, &drdy, sizeof(drdy)); - drdy &= LSM6DS3_GYRO_DRDY_GDA; - } while (drdy == 0); - - int len = read_register(LSM6DS3_GYRO_I2C_REG_OUTX_L_G, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - for (int j = 0; j < 3; j++) { - out_buf[j] += (float)read_16_bit(buffer[j*2], buffer[j*2+1]) * 70.0f; - } - } - - // calculate the mg average values - for (int i = 0; i < 3; i++) { - out_buf[i] /= 5.0f; - } -} - -int LSM6DS3_Gyro::self_test(int test_type) { - float val_st_off[3] = {0}; - float val_st_on[3] = {0}; - float test_val[3] = {0}; - - // prepare sensor for self-test - - // full scale: 2000dps, ODR: 208Hz - int ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL2_G, LSM6DS3_GYRO_ODR_208HZ | LSM6DS3_GYRO_FS_2000dps); - if (ret < 0) { - return ret; - } - - // wait for stable output, and discard first values - util::sleep_for(150); - wait_for_data_ready(); - read_and_avg_data(val_st_off); - - // enable Self Test positive (or negative) - ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL5_C, test_type); - if (ret < 0) { - return ret; - } - - // wait for stable output, and discard first values - util::sleep_for(50); - wait_for_data_ready(); - read_and_avg_data(val_st_on); - - // disable sensor - ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL2_G, 0); - if (ret < 0) { - return ret; - } - - // disable self test - ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL5_C, 0); - if (ret < 0) { - return ret; - } - - // calculate the mg values for self test - for (int i = 0; i < 3; i++) { - test_val[i] = fabs(val_st_on[i] - val_st_off[i]); - } - - // verify test result - for (int i = 0; i < 3; i++) { - if ((LSM6DS3_GYRO_MIN_ST_LIMIT_mdps > test_val[i]) || - (test_val[i] > LSM6DS3_GYRO_MAX_ST_LIMIT_mdps)) { - return -1; - } - } - - return ret; -} - -int LSM6DS3_Gyro::init() { - uint8_t value = 0; - bool do_self_test = false; - - const char* env_lsm_selftest = std::getenv("LSM_SELF_TEST"); - if (env_lsm_selftest != nullptr && strncmp(env_lsm_selftest, "1", 1) == 0) { - do_self_test = true; - } - - int ret = verify_chip_id(LSM6DS3_GYRO_I2C_REG_ID, {LSM6DS3_GYRO_CHIP_ID, LSM6DS3TRC_GYRO_CHIP_ID}); - if (ret == -1) return -1; - - if (ret == LSM6DS3TRC_GYRO_CHIP_ID) { - source = 
cereal::SensorEventData::SensorSource::LSM6DS3TRC; - } - - ret = init_gpio(); - if (ret < 0) { - goto fail; - } - - ret = self_test(LSM6DS3_GYRO_POSITIVE_TEST); - if (ret < 0) { - LOGE("LSM6DS3 gyro positive self-test failed!"); - if (do_self_test) goto fail; - } - - ret = self_test(LSM6DS3_GYRO_NEGATIVE_TEST); - if (ret < 0) { - LOGE("LSM6DS3 gyro negative self-test failed!"); - if (do_self_test) goto fail; - } - - // TODO: set scale. Default is +- 250 deg/s - ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL2_G, LSM6DS3_GYRO_ODR_104HZ); - if (ret < 0) { - goto fail; - } - - ret = set_register(LSM6DS3_GYRO_I2C_REG_DRDY_CFG, LSM6DS3_GYRO_DRDY_PULSE_MODE); - if (ret < 0) { - goto fail; - } - - // enable data ready interrupt for gyro on INT1 - // (without resetting existing interrupts) - ret = read_register(LSM6DS3_GYRO_I2C_REG_INT1_CTRL, &value, 1); - if (ret < 0) { - goto fail; - } - - value |= LSM6DS3_GYRO_INT1_DRDY_G; - ret = set_register(LSM6DS3_GYRO_I2C_REG_INT1_CTRL, value); - -fail: - return ret; -} - -int LSM6DS3_Gyro::shutdown() { - int ret = 0; - - // disable data ready interrupt for gyro on INT1 - uint8_t value = 0; - ret = read_register(LSM6DS3_GYRO_I2C_REG_INT1_CTRL, &value, 1); - if (ret < 0) { - goto fail; - } - - value &= ~(LSM6DS3_GYRO_INT1_DRDY_G); - ret = set_register(LSM6DS3_GYRO_I2C_REG_INT1_CTRL, value); - if (ret < 0) { - LOGE("Could not disable lsm6ds3 gyroscope interrupt!"); - goto fail; - } - - // enable power-down mode - value = 0; - ret = read_register(LSM6DS3_GYRO_I2C_REG_CTRL2_G, &value, 1); - if (ret < 0) { - goto fail; - } - - value &= 0x0F; - ret = set_register(LSM6DS3_GYRO_I2C_REG_CTRL2_G, value); - if (ret < 0) { - LOGE("Could not power-down lsm6ds3 gyroscope!"); - goto fail; - } - -fail: - return ret; -} - -bool LSM6DS3_Gyro::get_event(MessageBuilder &msg, uint64_t ts) { - - // INT1 shared with accel, check STATUS_REG who triggered - uint8_t status_reg = 0; - read_register(LSM6DS3_GYRO_I2C_REG_STAT_REG, &status_reg, sizeof(status_reg)); - if ((status_reg & LSM6DS3_GYRO_DRDY_GDA) == 0) { - return false; - } - - uint8_t buffer[6]; - int len = read_register(LSM6DS3_GYRO_I2C_REG_OUTX_L_G, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - float scale = 8.75 / 1000.0; - float x = DEG2RAD(read_16_bit(buffer[0], buffer[1]) * scale); - float y = DEG2RAD(read_16_bit(buffer[2], buffer[3]) * scale); - float z = DEG2RAD(read_16_bit(buffer[4], buffer[5]) * scale); - - auto event = msg.initEvent().initGyroscope(); - event.setSource(source); - event.setVersion(2); - event.setSensor(SENSOR_GYRO_UNCALIBRATED); - event.setType(SENSOR_TYPE_GYROSCOPE_UNCALIBRATED); - event.setTimestamp(ts); - - float xyz[] = {y, -x, z}; - auto svec = event.initGyroUncalibrated(); - svec.setV(xyz); - svec.setStatus(true); - - return true; -} diff --git a/system/sensord/sensors/lsm6ds3_gyro.h b/system/sensord/sensors/lsm6ds3_gyro.h deleted file mode 100644 index adaae62d..00000000 --- a/system/sensord/sensors/lsm6ds3_gyro.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define LSM6DS3_GYRO_I2C_ADDR 0x6A - -// Registers of the chip -#define LSM6DS3_GYRO_I2C_REG_DRDY_CFG 0x0B -#define LSM6DS3_GYRO_I2C_REG_ID 0x0F -#define LSM6DS3_GYRO_I2C_REG_INT1_CTRL 0x0D -#define LSM6DS3_GYRO_I2C_REG_CTRL2_G 0x11 -#define LSM6DS3_GYRO_I2C_REG_CTRL5_C 0x14 -#define LSM6DS3_GYRO_I2C_REG_STAT_REG 0x1E -#define LSM6DS3_GYRO_I2C_REG_OUTX_L_G 0x22 -#define LSM6DS3_GYRO_POSITIVE_TEST (0b01 << 2) -#define 
LSM6DS3_GYRO_NEGATIVE_TEST (0b11 << 2) - -// Constants -#define LSM6DS3_GYRO_CHIP_ID 0x69 -#define LSM6DS3TRC_GYRO_CHIP_ID 0x6A -#define LSM6DS3_GYRO_FS_2000dps (0b11 << 2) -#define LSM6DS3_GYRO_ODR_104HZ (0b0100 << 4) -#define LSM6DS3_GYRO_ODR_208HZ (0b0101 << 4) -#define LSM6DS3_GYRO_INT1_DRDY_G 0b10 -#define LSM6DS3_GYRO_DRDY_GDA 0b10 -#define LSM6DS3_GYRO_DRDY_PULSE_MODE (1 << 7) -#define LSM6DS3_GYRO_MIN_ST_LIMIT_mdps 150000.0f -#define LSM6DS3_GYRO_MAX_ST_LIMIT_mdps 700000.0f - - -class LSM6DS3_Gyro : public I2CSensor { - uint8_t get_device_address() {return LSM6DS3_GYRO_I2C_ADDR;} - cereal::SensorEventData::SensorSource source = cereal::SensorEventData::SensorSource::LSM6DS3; - - // self test functions - int self_test(int test_type); - void wait_for_data_ready(); - void read_and_avg_data(float* val_st_off); -public: - LSM6DS3_Gyro(I2CBus *bus, int gpio_nr = 0, bool shared_gpio = false); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/lsm6ds3_gyro.py b/system/sensord/sensors/lsm6ds3_gyro.py new file mode 100644 index 00000000..60de2bbe --- /dev/null +++ b/system/sensord/sensors/lsm6ds3_gyro.py @@ -0,0 +1,145 @@ +import os +import math +import time + +from cereal import log +from openpilot.system.sensord.sensors.i2c_sensor import Sensor + +class LSM6DS3_Gyro(Sensor): + LSM6DS3_GYRO_I2C_REG_DRDY_CFG = 0x0B + LSM6DS3_GYRO_I2C_REG_INT1_CTRL = 0x0D + LSM6DS3_GYRO_I2C_REG_CTRL2_G = 0x11 + LSM6DS3_GYRO_I2C_REG_CTRL5_C = 0x14 + LSM6DS3_GYRO_I2C_REG_STAT_REG = 0x1E + LSM6DS3_GYRO_I2C_REG_OUTX_L_G = 0x22 + + LSM6DS3_GYRO_ODR_104HZ = (0b0100 << 4) + LSM6DS3_GYRO_INT1_DRDY_G = 0b10 + LSM6DS3_GYRO_DRDY_GDA = 0b10 + LSM6DS3_GYRO_DRDY_PULSE_MODE = (1 << 7) + + LSM6DS3_GYRO_ODR_208HZ = (0b0101 << 4) + LSM6DS3_GYRO_FS_2000dps = (0b11 << 2) + LSM6DS3_GYRO_POSITIVE_TEST = (0b01 << 2) + LSM6DS3_GYRO_NEGATIVE_TEST = (0b11 << 2) + LSM6DS3_GYRO_MIN_ST_LIMIT_mdps = 150000.0 + LSM6DS3_GYRO_MAX_ST_LIMIT_mdps = 700000.0 + + @property + def device_address(self) -> int: + return 0x6A + + def reset(self): + self.write(0x12, 0x1) + time.sleep(0.1) + + def init(self): + chip_id = self.verify_chip_id(0x0F, [0x69, 0x6A]) + if chip_id == 0x6A: + self.source = log.SensorEventData.SensorSource.lsm6ds3trc + else: + self.source = log.SensorEventData.SensorSource.lsm6ds3 + + # self-test + if "LSM_SELF_TEST" in os.environ: + self.self_test(self.LSM6DS3_GYRO_POSITIVE_TEST) + self.self_test(self.LSM6DS3_GYRO_NEGATIVE_TEST) + + # actual init + self.writes(( + # TODO: set scale. 
Default is +- 250 deg/s + (self.LSM6DS3_GYRO_I2C_REG_CTRL2_G, self.LSM6DS3_GYRO_ODR_104HZ), + # Configure data ready signal to pulse mode + (self.LSM6DS3_GYRO_I2C_REG_DRDY_CFG, self.LSM6DS3_GYRO_DRDY_PULSE_MODE), + )) + value = self.read(self.LSM6DS3_GYRO_I2C_REG_INT1_CTRL, 1)[0] + value |= self.LSM6DS3_GYRO_INT1_DRDY_G + self.write(self.LSM6DS3_GYRO_I2C_REG_INT1_CTRL, value) + + def get_event(self, ts: int | None = None) -> log.SensorEventData: + assert ts is not None # must come from the IRQ event + + # Check if gyroscope data is ready, since it's shared with accelerometer + status_reg = self.read(self.LSM6DS3_GYRO_I2C_REG_STAT_REG, 1)[0] + if not (status_reg & self.LSM6DS3_GYRO_DRDY_GDA): + raise self.DataNotReady + + b = self.read(self.LSM6DS3_GYRO_I2C_REG_OUTX_L_G, 6) + x = self.parse_16bit(b[0], b[1]) + y = self.parse_16bit(b[2], b[3]) + z = self.parse_16bit(b[4], b[5]) + scale = (8.75 / 1000.0) * (math.pi / 180.0) + xyz = [y * scale, -x * scale, z * scale] + + event = log.SensorEventData.new_message() + event.timestamp = ts + event.version = 2 + event.sensor = 5 # SENSOR_GYRO_UNCALIBRATED + event.type = 16 # SENSOR_TYPE_GYROSCOPE_UNCALIBRATED + event.source = self.source + g = event.init('gyroUncalibrated') + g.v = xyz + g.status = 1 + return event + + def shutdown(self) -> None: + # Disable data ready interrupt on INT1 + value = self.read(self.LSM6DS3_GYRO_I2C_REG_INT1_CTRL, 1)[0] + value &= ~self.LSM6DS3_GYRO_INT1_DRDY_G + self.write(self.LSM6DS3_GYRO_I2C_REG_INT1_CTRL, value) + + # Power down by clearing ODR bits + value = self.read(self.LSM6DS3_GYRO_I2C_REG_CTRL2_G, 1)[0] + value &= 0x0F + self.write(self.LSM6DS3_GYRO_I2C_REG_CTRL2_G, value) + + # *** self-test stuff *** + def _wait_for_data_ready(self): + while True: + drdy = self.read(self.LSM6DS3_GYRO_I2C_REG_STAT_REG, 1)[0] + if drdy & self.LSM6DS3_GYRO_DRDY_GDA: + break + + def _read_and_avg_data(self) -> list[float]: + out_buf = [0.0, 0.0, 0.0] + for _ in range(5): + self._wait_for_data_ready() + b = self.read(self.LSM6DS3_GYRO_I2C_REG_OUTX_L_G, 6) + for j in range(3): + val = self.parse_16bit(b[j*2], b[j*2+1]) * 70.0 # mdps/LSB for 2000 dps + out_buf[j] += val + return [x / 5.0 for x in out_buf] + + def self_test(self, test_type: int): + # Set ODR to 208Hz, FS to 2000dps + self.write(self.LSM6DS3_GYRO_I2C_REG_CTRL2_G, self.LSM6DS3_GYRO_ODR_208HZ | self.LSM6DS3_GYRO_FS_2000dps) + + # Wait for stable output + time.sleep(0.15) + self._wait_for_data_ready() + val_st_off = self._read_and_avg_data() + + # Enable self-test + self.write(self.LSM6DS3_GYRO_I2C_REG_CTRL5_C, test_type) + + # Wait for stable output + time.sleep(0.05) + self._wait_for_data_ready() + val_st_on = self._read_and_avg_data() + + # Disable sensor and self-test + self.write(self.LSM6DS3_GYRO_I2C_REG_CTRL2_G, 0) + self.write(self.LSM6DS3_GYRO_I2C_REG_CTRL5_C, 0) + + # Calculate differences and check limits + test_val = [abs(on - off) for on, off in zip(val_st_on, val_st_off, strict=False)] + for val in test_val: + if val < self.LSM6DS3_GYRO_MIN_ST_LIMIT_mdps or val > self.LSM6DS3_GYRO_MAX_ST_LIMIT_mdps: + raise self.SensorException(f"Gyroscope self-test failed for test type {test_type}") + +if __name__ == "__main__": + s = LSM6DS3_Gyro(1) + s.init() + time.sleep(0.1) + print(s.get_event(0)) + s.shutdown() diff --git a/system/sensord/sensors/lsm6ds3_temp.cc b/system/sensord/sensors/lsm6ds3_temp.cc deleted file mode 100644 index f4816141..00000000 --- a/system/sensord/sensors/lsm6ds3_temp.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "system/sensord/sensors/lsm6ds3_temp.h" - 
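One number in the gyro get_event() above is worth unpacking: the scale folds the LSM6DS3's 8.75 mdps/LSB sensitivity at the default ±250 dps full scale (per the ST datasheet) together with the degrees-to-radians conversion. A standalone check of the constant, no hardware needed:

import math

scale = (8.75 / 1000.0) * (math.pi / 180.0)  # raw LSB -> rad/s
# a full-scale sample (32767) maps to ~286.7 dps (~5.0 rad/s), slightly past
# the nominal +-250 dps range -- typical sensitivity headroom for these parts
print(32767 * scale)  # ~5.004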
-#include - -#include "common/swaglog.h" -#include "common/timing.h" - -LSM6DS3_Temp::LSM6DS3_Temp(I2CBus *bus) : I2CSensor(bus) {} - -int LSM6DS3_Temp::init() { - int ret = verify_chip_id(LSM6DS3_TEMP_I2C_REG_ID, {LSM6DS3_TEMP_CHIP_ID, LSM6DS3TRC_TEMP_CHIP_ID}); - if (ret == -1) return -1; - - if (ret == LSM6DS3TRC_TEMP_CHIP_ID) { - source = cereal::SensorEventData::SensorSource::LSM6DS3TRC; - } - return 0; -} - -bool LSM6DS3_Temp::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - uint8_t buffer[2]; - int len = read_register(LSM6DS3_TEMP_I2C_REG_OUT_TEMP_L, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - - float scale = (source == cereal::SensorEventData::SensorSource::LSM6DS3TRC) ? 256.0f : 16.0f; - float temp = 25.0f + read_16_bit(buffer[0], buffer[1]) / scale; - - auto event = msg.initEvent().initTemperatureSensor(); - event.setSource(source); - event.setVersion(1); - event.setType(SENSOR_TYPE_AMBIENT_TEMPERATURE); - event.setTimestamp(start_time); - event.setTemperature(temp); - - return true; -} diff --git a/system/sensord/sensors/lsm6ds3_temp.h b/system/sensord/sensors/lsm6ds3_temp.h deleted file mode 100644 index 1b5b6218..00000000 --- a/system/sensord/sensors/lsm6ds3_temp.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define LSM6DS3_TEMP_I2C_ADDR 0x6A - -// Registers of the chip -#define LSM6DS3_TEMP_I2C_REG_ID 0x0F -#define LSM6DS3_TEMP_I2C_REG_OUT_TEMP_L 0x20 - -// Constants -#define LSM6DS3_TEMP_CHIP_ID 0x69 -#define LSM6DS3TRC_TEMP_CHIP_ID 0x6A - - -class LSM6DS3_Temp : public I2CSensor { - uint8_t get_device_address() {return LSM6DS3_TEMP_I2C_ADDR;} - cereal::SensorEventData::SensorSource source = cereal::SensorEventData::SensorSource::LSM6DS3; - -public: - LSM6DS3_Temp(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown() { return 0; } -}; diff --git a/system/sensord/sensors/lsm6ds3_temp.py b/system/sensord/sensors/lsm6ds3_temp.py new file mode 100644 index 00000000..b9bb9fe3 --- /dev/null +++ b/system/sensord/sensors/lsm6ds3_temp.py @@ -0,0 +1,33 @@ +import time + +from cereal import log +from openpilot.system.sensord.sensors.i2c_sensor import Sensor + +# https://content.arduino.cc/assets/st_imu_lsm6ds3_datasheet.pdf +class LSM6DS3_Temp(Sensor): + @property + def device_address(self) -> int: + return 0x6A + + def _read_temperature(self) -> float: + scale = 16.0 if self.source == log.SensorEventData.SensorSource.lsm6ds3 else 256.0 + data = self.read(0x20, 2) + return 25 + (self.parse_16bit(data[0], data[1]) / scale) + + def init(self): + chip_id = self.verify_chip_id(0x0F, [0x69, 0x6A]) + if chip_id == 0x6A: + self.source = log.SensorEventData.SensorSource.lsm6ds3trc + else: + self.source = log.SensorEventData.SensorSource.lsm6ds3 + + def get_event(self, ts: int | None = None) -> log.SensorEventData: + event = log.SensorEventData.new_message() + event.version = 1 + event.timestamp = int(time.monotonic() * 1e9) + event.source = self.source + event.temperature = self._read_temperature() + return event + + def shutdown(self) -> None: + pass diff --git a/system/sensord/sensors/mmc5603nj_magn.cc b/system/sensord/sensors/mmc5603nj_magn.cc deleted file mode 100644 index 0e8ba967..00000000 --- a/system/sensord/sensors/mmc5603nj_magn.cc +++ /dev/null @@ -1,108 +0,0 @@ -#include "system/sensord/sensors/mmc5603nj_magn.h" - -#include -#include -#include - -#include "common/swaglog.h" -#include 
"common/timing.h" -#include "common/util.h" - -MMC5603NJ_Magn::MMC5603NJ_Magn(I2CBus *bus) : I2CSensor(bus) {} - -int MMC5603NJ_Magn::init() { - int ret = verify_chip_id(MMC5603NJ_I2C_REG_ID, {MMC5603NJ_CHIP_ID}); - if (ret == -1) return -1; - - // Set ODR to 0 - ret = set_register(MMC5603NJ_I2C_REG_ODR, 0); - if (ret < 0) { - goto fail; - } - - // Set BW to 0b01 for 1-150 Hz operation - ret = set_register(MMC5603NJ_I2C_REG_INTERNAL_1, 0b01); - if (ret < 0) { - goto fail; - } - -fail: - return ret; -} - -int MMC5603NJ_Magn::shutdown() { - int ret = 0; - - // disable auto reset of measurements - uint8_t value = 0; - ret = read_register(MMC5603NJ_I2C_REG_INTERNAL_0, &value, 1); - if (ret < 0) { - goto fail; - } - - value &= ~(MMC5603NJ_CMM_FREQ_EN | MMC5603NJ_AUTO_SR_EN); - ret = set_register(MMC5603NJ_I2C_REG_INTERNAL_0, value); - if (ret < 0) { - goto fail; - } - - // set ODR to 0 to leave continuous mode - ret = set_register(MMC5603NJ_I2C_REG_ODR, 0); - if (ret < 0) { - goto fail; - } - return ret; - -fail: - LOGE("Could not disable mmc5603nj auto set reset"); - return ret; -} - -void MMC5603NJ_Magn::start_measurement() { - set_register(MMC5603NJ_I2C_REG_INTERNAL_0, 0b01); - util::sleep_for(5); -} - -std::vector MMC5603NJ_Magn::read_measurement() { - int len; - uint8_t buffer[9]; - len = read_register(MMC5603NJ_I2C_REG_XOUT0, buffer, sizeof(buffer)); - assert(len == sizeof(buffer)); - float scale = 1.0 / 16384.0; - float x = (read_20_bit(buffer[6], buffer[1], buffer[0]) * scale) - 32.0; - float y = (read_20_bit(buffer[7], buffer[3], buffer[2]) * scale) - 32.0; - float z = (read_20_bit(buffer[8], buffer[5], buffer[4]) * scale) - 32.0; - std::vector xyz = {x, y, z}; - return xyz; -} - -bool MMC5603NJ_Magn::get_event(MessageBuilder &msg, uint64_t ts) { - uint64_t start_time = nanos_since_boot(); - // SET - RESET cycle - set_register(MMC5603NJ_I2C_REG_INTERNAL_0, MMC5603NJ_SET); - util::sleep_for(5); - MMC5603NJ_Magn::start_measurement(); - std::vector xyz = MMC5603NJ_Magn::read_measurement(); - - set_register(MMC5603NJ_I2C_REG_INTERNAL_0, MMC5603NJ_RESET); - util::sleep_for(5); - MMC5603NJ_Magn::start_measurement(); - std::vector reset_xyz = MMC5603NJ_Magn::read_measurement(); - - auto event = msg.initEvent().initMagnetometer(); - event.setSource(cereal::SensorEventData::SensorSource::MMC5603NJ); - event.setVersion(1); - event.setSensor(SENSOR_MAGNETOMETER_UNCALIBRATED); - event.setType(SENSOR_TYPE_MAGNETIC_FIELD_UNCALIBRATED); - event.setTimestamp(start_time); - - float vals[] = {xyz[0], xyz[1], xyz[2], reset_xyz[0], reset_xyz[1], reset_xyz[2]}; - bool valid = true; - if (std::any_of(std::begin(vals), std::end(vals), [](float val) { return val == -32.0; })) { - valid = false; - } - auto svec = event.initMagneticUncalibrated(); - svec.setV(vals); - svec.setStatus(valid); - return true; -} diff --git a/system/sensord/sensors/mmc5603nj_magn.h b/system/sensord/sensors/mmc5603nj_magn.h deleted file mode 100644 index 9c0fbd25..00000000 --- a/system/sensord/sensors/mmc5603nj_magn.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include - -#include "system/sensord/sensors/i2c_sensor.h" - -// Address of the chip on the bus -#define MMC5603NJ_I2C_ADDR 0x30 - -// Registers of the chip -#define MMC5603NJ_I2C_REG_XOUT0 0x00 -#define MMC5603NJ_I2C_REG_ODR 0x1A -#define MMC5603NJ_I2C_REG_INTERNAL_0 0x1B -#define MMC5603NJ_I2C_REG_INTERNAL_1 0x1C -#define MMC5603NJ_I2C_REG_INTERNAL_2 0x1D -#define MMC5603NJ_I2C_REG_ID 0x39 - -// Constants -#define MMC5603NJ_CHIP_ID 0x10 -#define MMC5603NJ_CMM_FREQ_EN (1 
<< 7) -#define MMC5603NJ_AUTO_SR_EN (1 << 5) -#define MMC5603NJ_CMM_EN (1 << 4) -#define MMC5603NJ_EN_PRD_SET (1 << 3) -#define MMC5603NJ_SET (1 << 3) -#define MMC5603NJ_RESET (1 << 4) - -class MMC5603NJ_Magn : public I2CSensor { -private: - uint8_t get_device_address() {return MMC5603NJ_I2C_ADDR;} - void start_measurement(); - std::vector read_measurement(); -public: - MMC5603NJ_Magn(I2CBus *bus); - int init(); - bool get_event(MessageBuilder &msg, uint64_t ts = 0); - int shutdown(); -}; diff --git a/system/sensord/sensors/mmc5603nj_magn.py b/system/sensord/sensors/mmc5603nj_magn.py new file mode 100644 index 00000000..255e99eb --- /dev/null +++ b/system/sensord/sensors/mmc5603nj_magn.py @@ -0,0 +1,76 @@ +import time + +from cereal import log +from openpilot.system.sensord.sensors.i2c_sensor import Sensor + +# https://www.mouser.com/datasheet/2/821/Memsic_09102019_Datasheet_Rev.B-1635324.pdf + +# Register addresses +REG_ODR = 0x1A +REG_INTERNAL_0 = 0x1B +REG_INTERNAL_1 = 0x1C + +# Control register settings +CMM_FREQ_EN = (1 << 7) +AUTO_SR_EN = (1 << 5) +SET = (1 << 3) +RESET = (1 << 4) + +class MMC5603NJ_Magn(Sensor): + @property + def device_address(self) -> int: + return 0x30 + + def init(self): + self.verify_chip_id(0x39, [0x10, ]) + self.writes(( + (REG_ODR, 0), + + # Set BW to 0b01 for 1-150 Hz operation + (REG_INTERNAL_1, 0b01), + )) + + def _read_data(self, cycle) -> list[float]: + # start measurement + self.write(REG_INTERNAL_0, cycle) + self.wait() + + # read out XYZ + scale = 1.0 / 16384.0 + b = self.read(0x00, 9) + return [ + (self.parse_20bit(b[6], b[1], b[0]) * scale) - 32.0, + (self.parse_20bit(b[7], b[3], b[2]) * scale) - 32.0, + (self.parse_20bit(b[8], b[5], b[4]) * scale) - 32.0, + ] + + def get_event(self, ts: int | None = None) -> log.SensorEventData: + ts = time.monotonic_ns() + + # SET - RESET cycle + xyz = self._read_data(SET) + reset_xyz = self._read_data(RESET) + vals = [*xyz, *reset_xyz] + + event = log.SensorEventData.new_message() + event.timestamp = ts + event.version = 1 + event.sensor = 3 # SENSOR_MAGNETOMETER_UNCALIBRATED + event.type = 14 # SENSOR_TYPE_MAGNETIC_FIELD_UNCALIBRATED + event.source = log.SensorEventData.SensorSource.mmc5603nj + + m = event.init('magneticUncalibrated') + m.v = vals + m.status = int(all(int(v) != -32 for v in vals)) + + return event + + def shutdown(self) -> None: + v = self.read(REG_INTERNAL_0, 1)[0] + self.writes(( + # disable auto-reset of measurements + (REG_INTERNAL_0, (v & (~(CMM_FREQ_EN | AUTO_SR_EN)))), + + # disable continuous mode + (REG_ODR, 0), + )) diff --git a/system/sensord/sensors/sensor.h b/system/sensord/sensors/sensor.h deleted file mode 100644 index ccf998d1..00000000 --- a/system/sensord/sensors/sensor.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "cereal/messaging/messaging.h" - -class Sensor { -public: - int gpio_fd = -1; - bool enabled = false; - uint64_t start_ts = 0; - uint64_t init_delay = 500e6; // default dealy 500ms - - virtual ~Sensor() {} - virtual int init() = 0; - virtual bool get_event(MessageBuilder &msg, uint64_t ts = 0) = 0; - virtual bool has_interrupt_enabled() = 0; - virtual int shutdown() = 0; - - virtual bool is_data_valid(uint64_t current_ts) { - if (start_ts == 0) { - start_ts = current_ts; - } - return (current_ts - start_ts) > init_delay; - } -}; diff --git a/system/sensord/sensors_qcom2.cc b/system/sensord/sensors_qcom2.cc deleted file mode 100644 index f9f51539..00000000 --- a/system/sensord/sensors_qcom2.cc +++ /dev/null @@ -1,179 +0,0 @@ -#include - -#include 
-#include -#include -#include -#include -#include - -#include "cereal/services.h" -#include "cereal/messaging/messaging.h" -#include "common/i2c.h" -#include "common/ratekeeper.h" -#include "common/swaglog.h" -#include "common/timing.h" -#include "common/util.h" -#include "system/sensord/sensors/bmx055_accel.h" -#include "system/sensord/sensors/bmx055_gyro.h" -#include "system/sensord/sensors/bmx055_magn.h" -#include "system/sensord/sensors/bmx055_temp.h" -#include "system/sensord/sensors/constants.h" -#include "system/sensord/sensors/lsm6ds3_accel.h" -#include "system/sensord/sensors/lsm6ds3_gyro.h" -#include "system/sensord/sensors/lsm6ds3_temp.h" -#include "system/sensord/sensors/mmc5603nj_magn.h" - -#define I2C_BUS_IMU 1 - -ExitHandler do_exit; - -void interrupt_loop(std::vector> sensors) { - PubMaster pm({"gyroscope", "accelerometer"}); - - int fd = -1; - for (auto &[sensor, msg_name] : sensors) { - if (sensor->has_interrupt_enabled()) { - fd = sensor->gpio_fd; - break; - } - } - - uint64_t offset = nanos_since_epoch() - nanos_since_boot(); - struct pollfd fd_list[1] = {0}; - fd_list[0].fd = fd; - fd_list[0].events = POLLIN | POLLPRI; - - while (!do_exit) { - int err = poll(fd_list, 1, 100); - if (err == -1) { - if (errno == EINTR) { - continue; - } - return; - } else if (err == 0) { - LOGE("poll timed out"); - continue; - } - - if ((fd_list[0].revents & (POLLIN | POLLPRI)) == 0) { - LOGE("no poll events set"); - continue; - } - - // Read all events - struct gpioevent_data evdata[16]; - err = HANDLE_EINTR(read(fd, evdata, sizeof(evdata))); - if (err < 0 || err % sizeof(*evdata) != 0) { - LOGE("error reading event data %d", err); - continue; - } - - uint64_t cur_offset = nanos_since_epoch() - nanos_since_boot(); - uint64_t diff = cur_offset > offset ? cur_offset - offset : offset - cur_offset; - if (diff > 10*1e6) { // 10ms - LOGW("time jumped: %lu %lu", cur_offset, offset); - offset = cur_offset; - - // we don't have a valid timestamp since the - // time jumped, so throw out this measurement. 
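The comment above is the crux of the deleted interrupt loop: GPIO event timestamps arrive on the boot clock, so the loop keeps a boot-to-epoch offset and re-derives it on every wakeup; a divergence of more than 10 ms means the wall clock stepped (NTP sync, RTC correction) and the sample cannot be timestamped reliably. A rough Python sketch of that guard, with time.monotonic_ns() standing in for nanos_since_boot() (names are illustrative, not from the source):

import time

def boot_to_epoch_offset_ns() -> int:
    # epoch time minus boot-clock time; constant unless the wall clock jumps
    return time.time_ns() - time.monotonic_ns()

def timestamp_ok(prev_offset_ns: int) -> tuple[bool, int]:
    cur = boot_to_epoch_offset_ns()
    if abs(cur - prev_offset_ns) > 10_000_000:  # >10 ms: wall clock jumped
        return False, cur  # resync the offset and throw this measurement away
    return True, prev_offset_ns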
- continue; - } - - int num_events = err / sizeof(*evdata); - uint64_t ts = evdata[num_events - 1].timestamp - cur_offset; - - for (auto &[sensor, msg_name] : sensors) { - if (!sensor->has_interrupt_enabled()) { - continue; - } - - MessageBuilder msg; - if (!sensor->get_event(msg, ts)) { - continue; - } - - if (!sensor->is_data_valid(ts)) { - continue; - } - - pm.send(msg_name.c_str(), msg); - } - } -} - -void polling_loop(Sensor *sensor, std::string msg_name) { - PubMaster pm({msg_name.c_str()}); - RateKeeper rk(msg_name, services.at(msg_name).frequency); - while (!do_exit) { - MessageBuilder msg; - if (sensor->get_event(msg) && sensor->is_data_valid(nanos_since_boot())) { - pm.send(msg_name.c_str(), msg); - } - rk.keepTime(); - } -} - -int sensor_loop(I2CBus *i2c_bus_imu) { - // Sensor init - std::vector> sensors_init = { - {new BMX055_Accel(i2c_bus_imu), "accelerometer2"}, - {new BMX055_Gyro(i2c_bus_imu), "gyroscope2"}, - {new BMX055_Magn(i2c_bus_imu), "magnetometer"}, - {new BMX055_Temp(i2c_bus_imu), "temperatureSensor2"}, - - {new LSM6DS3_Accel(i2c_bus_imu, GPIO_LSM_INT), "accelerometer"}, - {new LSM6DS3_Gyro(i2c_bus_imu, GPIO_LSM_INT, true), "gyroscope"}, - {new LSM6DS3_Temp(i2c_bus_imu), "temperatureSensor"}, - - {new MMC5603NJ_Magn(i2c_bus_imu), "magnetometer"}, - }; - - // Initialize sensors - std::vector threads; - for (auto &[sensor, msg_name] : sensors_init) { - int err = sensor->init(); - if (err < 0) { - continue; - } - - if (!sensor->has_interrupt_enabled()) { - threads.emplace_back(polling_loop, sensor, msg_name); - } - } - - // increase interrupt quality by pinning interrupt and process to core 1 - setpriority(PRIO_PROCESS, 0, -18); - util::set_core_affinity({1}); - - // TODO: get the IRQ number from gpiochip - std::string irq_path = "/proc/irq/336/smp_affinity_list"; - if (!util::file_exists(irq_path)) { - irq_path = "/proc/irq/335/smp_affinity_list"; - } - std::system(util::string_format("sudo su -c 'echo 1 > %s'", irq_path.c_str()).c_str()); - - // thread for reading events via interrupts - threads.emplace_back(&interrupt_loop, std::ref(sensors_init)); - - // wait for all threads to finish - for (auto &t : threads) { - t.join(); - } - - for (auto &[sensor, msg_name] : sensors_init) { - sensor->shutdown(); - delete sensor; - } - return 0; -} - -int main(int argc, char *argv[]) { - try { - auto i2c_bus_imu = std::make_unique(I2C_BUS_IMU); - return sensor_loop(i2c_bus_imu.get()); - } catch (std::exception &e) { - LOGE("I2CBus init failed"); - return -1; - } -} diff --git a/system/ubloxd/.gitignore b/system/ubloxd/.gitignore deleted file mode 100644 index 05263ff6..00000000 --- a/system/ubloxd/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -ubloxd -tests/test_glonass_runner diff --git a/system/ubloxd/SConscript b/system/ubloxd/SConscript index ce09e235..9eb50760 100644 --- a/system/ubloxd/SConscript +++ b/system/ubloxd/SConscript @@ -1,20 +1,11 @@ -Import('env', 'common', 'messaging') - -loc_libs = [messaging, common, 'kaitai', 'pthread'] +Import('env') if GetOption('kaitai'): - generated = Dir('generated').srcnode().abspath - cmd = f"kaitai-struct-compiler --target cpp_stl --outdir {generated} $SOURCES" - env.Command(['generated/ubx.cpp', 'generated/ubx.h'], 'ubx.ksy', cmd) - env.Command(['generated/gps.cpp', 'generated/gps.h'], 'gps.ksy', cmd) - glonass = env.Command(['generated/glonass.cpp', 'generated/glonass.h'], 'glonass.ksy', cmd) - + current_dir = Dir('./generated/').srcnode().abspath + python_cmd = f"kaitai-struct-compiler --target python --outdir {current_dir} $SOURCES" 
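For context on what this build block sets up: kaitai-struct-compiler emits one Python module per .ksy grammar via the three env.Command() calls just below, and py_glonass_fix then comments out the generated align_to_byte() call (the kaitai issue linked below) so bit-level parsing continues across the byte boundary. That sed edit is equivalent to this Python snippet (path as used in this SConscript):

from pathlib import Path

p = Path("system/ubloxd/generated/glonass.py")
p.write_text(p.read_text().replace("self._io.align_to_byte()",
                                   "# self._io.align_to_byte()"))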
+ env.Command(File('./generated/ubx.py'), 'ubx.ksy', python_cmd) + env.Command(File('./generated/gps.py'), 'gps.ksy', python_cmd) + env.Command(File('./generated/glonass.py'), 'glonass.ksy', python_cmd) # kaitai issue: https://github.com/kaitai-io/kaitai_struct/issues/910 - patch = env.Command(None, 'glonass_fix.patch', 'git apply $SOURCES') - env.Depends(patch, glonass) - -glonass_obj = env.Object('generated/glonass.cpp') -env.Program("ubloxd", ["ubloxd.cc", "ublox_msg.cc", "generated/ubx.cpp", "generated/gps.cpp", glonass_obj], LIBS=loc_libs) - -if GetOption('extras'): - env.Program("tests/test_glonass_runner", ['tests/test_glonass_runner.cc', 'tests/test_glonass_kaitai.cc', glonass_obj], LIBS=[loc_libs]) \ No newline at end of file + py_glonass_fix = env.Command(None, File('./generated/glonass.py'), "sed -i 's/self._io.align_to_byte()/# self._io.align_to_byte()/' $SOURCES") + env.Depends(py_glonass_fix, File('./generated/glonass.py')) diff --git a/system/ubloxd/generated/glonass.cpp b/system/ubloxd/generated/glonass.cpp deleted file mode 100644 index cd0f96ab..00000000 --- a/system/ubloxd/generated/glonass.cpp +++ /dev/null @@ -1,353 +0,0 @@ -// This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "glonass.h" - -glonass_t::glonass_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = this; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::_read() { - m_idle_chip = m__io->read_bits_int_be(1); - m_string_number = m__io->read_bits_int_be(4); - //m__io->align_to_byte(); - switch (string_number()) { - case 4: { - m_data = new string_4_t(m__io, this, m__root); - break; - } - case 1: { - m_data = new string_1_t(m__io, this, m__root); - break; - } - case 3: { - m_data = new string_3_t(m__io, this, m__root); - break; - } - case 5: { - m_data = new string_5_t(m__io, this, m__root); - break; - } - case 2: { - m_data = new string_2_t(m__io, this, m__root); - break; - } - default: { - m_data = new string_non_immediate_t(m__io, this, m__root); - break; - } - } - m_hamming_code = m__io->read_bits_int_be(8); - m_pad_1 = m__io->read_bits_int_be(11); - m_superframe_number = m__io->read_bits_int_be(16); - m_pad_2 = m__io->read_bits_int_be(8); - m_frame_number = m__io->read_bits_int_be(8); -} - -glonass_t::~glonass_t() { - _clean_up(); -} - -void glonass_t::_clean_up() { - if (m_data) { - delete m_data; m_data = 0; - } -} - -glonass_t::string_4_t::string_4_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_tau_n = false; - f_delta_tau_n = false; - - try { - _read(); - } catch(...) 
{ - _clean_up(); - throw; - } -} - -void glonass_t::string_4_t::_read() { - m_tau_n_sign = m__io->read_bits_int_be(1); - m_tau_n_value = m__io->read_bits_int_be(21); - m_delta_tau_n_sign = m__io->read_bits_int_be(1); - m_delta_tau_n_value = m__io->read_bits_int_be(4); - m_e_n = m__io->read_bits_int_be(5); - m_not_used_1 = m__io->read_bits_int_be(14); - m_p4 = m__io->read_bits_int_be(1); - m_f_t = m__io->read_bits_int_be(4); - m_not_used_2 = m__io->read_bits_int_be(3); - m_n_t = m__io->read_bits_int_be(11); - m_n = m__io->read_bits_int_be(5); - m_m = m__io->read_bits_int_be(2); -} - -glonass_t::string_4_t::~string_4_t() { - _clean_up(); -} - -void glonass_t::string_4_t::_clean_up() { -} - -int32_t glonass_t::string_4_t::tau_n() { - if (f_tau_n) - return m_tau_n; - m_tau_n = ((tau_n_sign()) ? ((tau_n_value() * -1)) : (tau_n_value())); - f_tau_n = true; - return m_tau_n; -} - -int32_t glonass_t::string_4_t::delta_tau_n() { - if (f_delta_tau_n) - return m_delta_tau_n; - m_delta_tau_n = ((delta_tau_n_sign()) ? ((delta_tau_n_value() * -1)) : (delta_tau_n_value())); - f_delta_tau_n = true; - return m_delta_tau_n; -} - -glonass_t::string_non_immediate_t::string_non_immediate_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::string_non_immediate_t::_read() { - m_data_1 = m__io->read_bits_int_be(64); - m_data_2 = m__io->read_bits_int_be(8); -} - -glonass_t::string_non_immediate_t::~string_non_immediate_t() { - _clean_up(); -} - -void glonass_t::string_non_immediate_t::_clean_up() { -} - -glonass_t::string_5_t::string_5_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::string_5_t::_read() { - m_n_a = m__io->read_bits_int_be(11); - m_tau_c = m__io->read_bits_int_be(32); - m_not_used = m__io->read_bits_int_be(1); - m_n_4 = m__io->read_bits_int_be(5); - m_tau_gps = m__io->read_bits_int_be(22); - m_l_n = m__io->read_bits_int_be(1); -} - -glonass_t::string_5_t::~string_5_t() { - _clean_up(); -} - -void glonass_t::string_5_t::_clean_up() { -} - -glonass_t::string_1_t::string_1_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_x_vel = false; - f_x_accel = false; - f_x = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::string_1_t::_read() { - m_not_used = m__io->read_bits_int_be(2); - m_p1 = m__io->read_bits_int_be(2); - m_t_k = m__io->read_bits_int_be(12); - m_x_vel_sign = m__io->read_bits_int_be(1); - m_x_vel_value = m__io->read_bits_int_be(23); - m_x_accel_sign = m__io->read_bits_int_be(1); - m_x_accel_value = m__io->read_bits_int_be(4); - m_x_sign = m__io->read_bits_int_be(1); - m_x_value = m__io->read_bits_int_be(26); -} - -glonass_t::string_1_t::~string_1_t() { - _clean_up(); -} - -void glonass_t::string_1_t::_clean_up() { -} - -int32_t glonass_t::string_1_t::x_vel() { - if (f_x_vel) - return m_x_vel; - m_x_vel = ((x_vel_sign()) ? ((x_vel_value() * -1)) : (x_vel_value())); - f_x_vel = true; - return m_x_vel; -} - -int32_t glonass_t::string_1_t::x_accel() { - if (f_x_accel) - return m_x_accel; - m_x_accel = ((x_accel_sign()) ? 
((x_accel_value() * -1)) : (x_accel_value())); - f_x_accel = true; - return m_x_accel; -} - -int32_t glonass_t::string_1_t::x() { - if (f_x) - return m_x; - m_x = ((x_sign()) ? ((x_value() * -1)) : (x_value())); - f_x = true; - return m_x; -} - -glonass_t::string_2_t::string_2_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_y_vel = false; - f_y_accel = false; - f_y = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::string_2_t::_read() { - m_b_n = m__io->read_bits_int_be(3); - m_p2 = m__io->read_bits_int_be(1); - m_t_b = m__io->read_bits_int_be(7); - m_not_used = m__io->read_bits_int_be(5); - m_y_vel_sign = m__io->read_bits_int_be(1); - m_y_vel_value = m__io->read_bits_int_be(23); - m_y_accel_sign = m__io->read_bits_int_be(1); - m_y_accel_value = m__io->read_bits_int_be(4); - m_y_sign = m__io->read_bits_int_be(1); - m_y_value = m__io->read_bits_int_be(26); -} - -glonass_t::string_2_t::~string_2_t() { - _clean_up(); -} - -void glonass_t::string_2_t::_clean_up() { -} - -int32_t glonass_t::string_2_t::y_vel() { - if (f_y_vel) - return m_y_vel; - m_y_vel = ((y_vel_sign()) ? ((y_vel_value() * -1)) : (y_vel_value())); - f_y_vel = true; - return m_y_vel; -} - -int32_t glonass_t::string_2_t::y_accel() { - if (f_y_accel) - return m_y_accel; - m_y_accel = ((y_accel_sign()) ? ((y_accel_value() * -1)) : (y_accel_value())); - f_y_accel = true; - return m_y_accel; -} - -int32_t glonass_t::string_2_t::y() { - if (f_y) - return m_y; - m_y = ((y_sign()) ? ((y_value() * -1)) : (y_value())); - f_y = true; - return m_y; -} - -glonass_t::string_3_t::string_3_t(kaitai::kstream* p__io, glonass_t* p__parent, glonass_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_gamma_n = false; - f_z_vel = false; - f_z_accel = false; - f_z = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void glonass_t::string_3_t::_read() { - m_p3 = m__io->read_bits_int_be(1); - m_gamma_n_sign = m__io->read_bits_int_be(1); - m_gamma_n_value = m__io->read_bits_int_be(10); - m_not_used = m__io->read_bits_int_be(1); - m_p = m__io->read_bits_int_be(2); - m_l_n = m__io->read_bits_int_be(1); - m_z_vel_sign = m__io->read_bits_int_be(1); - m_z_vel_value = m__io->read_bits_int_be(23); - m_z_accel_sign = m__io->read_bits_int_be(1); - m_z_accel_value = m__io->read_bits_int_be(4); - m_z_sign = m__io->read_bits_int_be(1); - m_z_value = m__io->read_bits_int_be(26); -} - -glonass_t::string_3_t::~string_3_t() { - _clean_up(); -} - -void glonass_t::string_3_t::_clean_up() { -} - -int32_t glonass_t::string_3_t::gamma_n() { - if (f_gamma_n) - return m_gamma_n; - m_gamma_n = ((gamma_n_sign()) ? ((gamma_n_value() * -1)) : (gamma_n_value())); - f_gamma_n = true; - return m_gamma_n; -} - -int32_t glonass_t::string_3_t::z_vel() { - if (f_z_vel) - return m_z_vel; - m_z_vel = ((z_vel_sign()) ? ((z_vel_value() * -1)) : (z_vel_value())); - f_z_vel = true; - return m_z_vel; -} - -int32_t glonass_t::string_3_t::z_accel() { - if (f_z_accel) - return m_z_accel; - m_z_accel = ((z_accel_sign()) ? ((z_accel_value() * -1)) : (z_accel_value())); - f_z_accel = true; - return m_z_accel; -} - -int32_t glonass_t::string_3_t::z() { - if (f_z) - return m_z; - m_z = ((z_sign()) ? 
((z_value() * -1)) : (z_value())); - f_z = true; - return m_z; -} diff --git a/system/ubloxd/generated/glonass.h b/system/ubloxd/generated/glonass.h deleted file mode 100644 index 19867ba2..00000000 --- a/system/ubloxd/generated/glonass.h +++ /dev/null @@ -1,375 +0,0 @@ -#ifndef GLONASS_H_ -#define GLONASS_H_ - -// This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "kaitai/kaitaistruct.h" -#include - -#if KAITAI_STRUCT_VERSION < 9000L -#error "Incompatible Kaitai Struct C++/STL API: version 0.9 or later is required" -#endif - -class glonass_t : public kaitai::kstruct { - -public: - class string_4_t; - class string_non_immediate_t; - class string_5_t; - class string_1_t; - class string_2_t; - class string_3_t; - - glonass_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent = 0, glonass_t* p__root = 0); - -private: - void _read(); - void _clean_up(); - -public: - ~glonass_t(); - - class string_4_t : public kaitai::kstruct { - - public: - - string_4_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_4_t(); - - private: - bool f_tau_n; - int32_t m_tau_n; - - public: - int32_t tau_n(); - - private: - bool f_delta_tau_n; - int32_t m_delta_tau_n; - - public: - int32_t delta_tau_n(); - - private: - bool m_tau_n_sign; - uint64_t m_tau_n_value; - bool m_delta_tau_n_sign; - uint64_t m_delta_tau_n_value; - uint64_t m_e_n; - uint64_t m_not_used_1; - bool m_p4; - uint64_t m_f_t; - uint64_t m_not_used_2; - uint64_t m_n_t; - uint64_t m_n; - uint64_t m_m; - glonass_t* m__root; - glonass_t* m__parent; - - public: - bool tau_n_sign() const { return m_tau_n_sign; } - uint64_t tau_n_value() const { return m_tau_n_value; } - bool delta_tau_n_sign() const { return m_delta_tau_n_sign; } - uint64_t delta_tau_n_value() const { return m_delta_tau_n_value; } - uint64_t e_n() const { return m_e_n; } - uint64_t not_used_1() const { return m_not_used_1; } - bool p4() const { return m_p4; } - uint64_t f_t() const { return m_f_t; } - uint64_t not_used_2() const { return m_not_used_2; } - uint64_t n_t() const { return m_n_t; } - uint64_t n() const { return m_n; } - uint64_t m() const { return m_m; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - - class string_non_immediate_t : public kaitai::kstruct { - - public: - - string_non_immediate_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_non_immediate_t(); - - private: - uint64_t m_data_1; - uint64_t m_data_2; - glonass_t* m__root; - glonass_t* m__parent; - - public: - uint64_t data_1() const { return m_data_1; } - uint64_t data_2() const { return m_data_2; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - - class string_5_t : public kaitai::kstruct { - - public: - - string_5_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_5_t(); - - private: - uint64_t m_n_a; - uint64_t m_tau_c; - bool m_not_used; - uint64_t m_n_4; - uint64_t m_tau_gps; - bool m_l_n; - glonass_t* m__root; - glonass_t* m__parent; - - public: - uint64_t n_a() const { return m_n_a; } - uint64_t tau_c() const { return m_tau_c; } - bool not_used() const { return m_not_used; } - uint64_t n_4() const { return m_n_4; } - uint64_t tau_gps() const { 
return m_tau_gps; } - bool l_n() const { return m_l_n; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - - class string_1_t : public kaitai::kstruct { - - public: - - string_1_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_1_t(); - - private: - bool f_x_vel; - int32_t m_x_vel; - - public: - int32_t x_vel(); - - private: - bool f_x_accel; - int32_t m_x_accel; - - public: - int32_t x_accel(); - - private: - bool f_x; - int32_t m_x; - - public: - int32_t x(); - - private: - uint64_t m_not_used; - uint64_t m_p1; - uint64_t m_t_k; - bool m_x_vel_sign; - uint64_t m_x_vel_value; - bool m_x_accel_sign; - uint64_t m_x_accel_value; - bool m_x_sign; - uint64_t m_x_value; - glonass_t* m__root; - glonass_t* m__parent; - - public: - uint64_t not_used() const { return m_not_used; } - uint64_t p1() const { return m_p1; } - uint64_t t_k() const { return m_t_k; } - bool x_vel_sign() const { return m_x_vel_sign; } - uint64_t x_vel_value() const { return m_x_vel_value; } - bool x_accel_sign() const { return m_x_accel_sign; } - uint64_t x_accel_value() const { return m_x_accel_value; } - bool x_sign() const { return m_x_sign; } - uint64_t x_value() const { return m_x_value; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - - class string_2_t : public kaitai::kstruct { - - public: - - string_2_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_2_t(); - - private: - bool f_y_vel; - int32_t m_y_vel; - - public: - int32_t y_vel(); - - private: - bool f_y_accel; - int32_t m_y_accel; - - public: - int32_t y_accel(); - - private: - bool f_y; - int32_t m_y; - - public: - int32_t y(); - - private: - uint64_t m_b_n; - bool m_p2; - uint64_t m_t_b; - uint64_t m_not_used; - bool m_y_vel_sign; - uint64_t m_y_vel_value; - bool m_y_accel_sign; - uint64_t m_y_accel_value; - bool m_y_sign; - uint64_t m_y_value; - glonass_t* m__root; - glonass_t* m__parent; - - public: - uint64_t b_n() const { return m_b_n; } - bool p2() const { return m_p2; } - uint64_t t_b() const { return m_t_b; } - uint64_t not_used() const { return m_not_used; } - bool y_vel_sign() const { return m_y_vel_sign; } - uint64_t y_vel_value() const { return m_y_vel_value; } - bool y_accel_sign() const { return m_y_accel_sign; } - uint64_t y_accel_value() const { return m_y_accel_value; } - bool y_sign() const { return m_y_sign; } - uint64_t y_value() const { return m_y_value; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - - class string_3_t : public kaitai::kstruct { - - public: - - string_3_t(kaitai::kstream* p__io, glonass_t* p__parent = 0, glonass_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~string_3_t(); - - private: - bool f_gamma_n; - int32_t m_gamma_n; - - public: - int32_t gamma_n(); - - private: - bool f_z_vel; - int32_t m_z_vel; - - public: - int32_t z_vel(); - - private: - bool f_z_accel; - int32_t m_z_accel; - - public: - int32_t z_accel(); - - private: - bool f_z; - int32_t m_z; - - public: - int32_t z(); - - private: - bool m_p3; - bool m_gamma_n_sign; - uint64_t m_gamma_n_value; - bool m_not_used; - uint64_t m_p; - bool m_l_n; - bool m_z_vel_sign; - uint64_t m_z_vel_value; - bool m_z_accel_sign; - uint64_t m_z_accel_value; - bool m_z_sign; - 
uint64_t m_z_value; - glonass_t* m__root; - glonass_t* m__parent; - - public: - bool p3() const { return m_p3; } - bool gamma_n_sign() const { return m_gamma_n_sign; } - uint64_t gamma_n_value() const { return m_gamma_n_value; } - bool not_used() const { return m_not_used; } - uint64_t p() const { return m_p; } - bool l_n() const { return m_l_n; } - bool z_vel_sign() const { return m_z_vel_sign; } - uint64_t z_vel_value() const { return m_z_vel_value; } - bool z_accel_sign() const { return m_z_accel_sign; } - uint64_t z_accel_value() const { return m_z_accel_value; } - bool z_sign() const { return m_z_sign; } - uint64_t z_value() const { return m_z_value; } - glonass_t* _root() const { return m__root; } - glonass_t* _parent() const { return m__parent; } - }; - -private: - bool m_idle_chip; - uint64_t m_string_number; - kaitai::kstruct* m_data; - uint64_t m_hamming_code; - uint64_t m_pad_1; - uint64_t m_superframe_number; - uint64_t m_pad_2; - uint64_t m_frame_number; - glonass_t* m__root; - kaitai::kstruct* m__parent; - -public: - bool idle_chip() const { return m_idle_chip; } - uint64_t string_number() const { return m_string_number; } - kaitai::kstruct* data() const { return m_data; } - uint64_t hamming_code() const { return m_hamming_code; } - uint64_t pad_1() const { return m_pad_1; } - uint64_t superframe_number() const { return m_superframe_number; } - uint64_t pad_2() const { return m_pad_2; } - uint64_t frame_number() const { return m_frame_number; } - glonass_t* _root() const { return m__root; } - kaitai::kstruct* _parent() const { return m__parent; } -}; - -#endif // GLONASS_H_ diff --git a/system/ubloxd/generated/glonass.py b/system/ubloxd/generated/glonass.py new file mode 100644 index 00000000..40aa16bb --- /dev/null +++ b/system/ubloxd/generated/glonass.py @@ -0,0 +1,247 @@ +# This is a generated file! 
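The regenerated glonass.py below replaces the deleted C++ parser with the same field layout. A minimal usage sketch, assuming the generated module is importable as glonass and that raw_string is a placeholder for one navigation-string payload taken from the receiver (neither name is defined in this diff):

    # Hedged sketch: parse one raw GLONASS navigation string and dispatch on
    # its string number. KaitaiStream/BytesIO are the same imports the
    # generated file uses; raw_string is a hypothetical bytes object.
    from kaitaistruct import KaitaiStream, BytesIO
    from glonass import Glonass

    def parse_string(raw_string: bytes):
        msg = Glonass(KaitaiStream(BytesIO(raw_string)))
        if msg.string_number == 1:
            # String 1 carries the X position/velocity/acceleration fields.
            return msg.data.x, msg.data.x_vel, msg.data.x_accel
        return None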
Please edit source .ksy file and use kaitai-struct-compiler to rebuild + +import kaitaistruct +from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO + + +if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9): + raise Exception("Incompatible Kaitai Struct Python API: 0.9 or later is required, but you have %s" % (kaitaistruct.__version__)) + +class Glonass(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.idle_chip = self._io.read_bits_int_be(1) != 0 + self.string_number = self._io.read_bits_int_be(4) + # workaround for kaitai bit alignment issue (see glonass_fix.patch for C++) + # self._io.align_to_byte() + _on = self.string_number + if _on == 4: + self.data = Glonass.String4(self._io, self, self._root) + elif _on == 1: + self.data = Glonass.String1(self._io, self, self._root) + elif _on == 3: + self.data = Glonass.String3(self._io, self, self._root) + elif _on == 5: + self.data = Glonass.String5(self._io, self, self._root) + elif _on == 2: + self.data = Glonass.String2(self._io, self, self._root) + else: + self.data = Glonass.StringNonImmediate(self._io, self, self._root) + self.hamming_code = self._io.read_bits_int_be(8) + self.pad_1 = self._io.read_bits_int_be(11) + self.superframe_number = self._io.read_bits_int_be(16) + self.pad_2 = self._io.read_bits_int_be(8) + self.frame_number = self._io.read_bits_int_be(8) + + class String4(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.tau_n_sign = self._io.read_bits_int_be(1) != 0 + self.tau_n_value = self._io.read_bits_int_be(21) + self.delta_tau_n_sign = self._io.read_bits_int_be(1) != 0 + self.delta_tau_n_value = self._io.read_bits_int_be(4) + self.e_n = self._io.read_bits_int_be(5) + self.not_used_1 = self._io.read_bits_int_be(14) + self.p4 = self._io.read_bits_int_be(1) != 0 + self.f_t = self._io.read_bits_int_be(4) + self.not_used_2 = self._io.read_bits_int_be(3) + self.n_t = self._io.read_bits_int_be(11) + self.n = self._io.read_bits_int_be(5) + self.m = self._io.read_bits_int_be(2) + + @property + def tau_n(self): + if hasattr(self, '_m_tau_n'): + return self._m_tau_n + + self._m_tau_n = ((self.tau_n_value * -1) if self.tau_n_sign else self.tau_n_value) + return getattr(self, '_m_tau_n', None) + + @property + def delta_tau_n(self): + if hasattr(self, '_m_delta_tau_n'): + return self._m_delta_tau_n + + self._m_delta_tau_n = ((self.delta_tau_n_value * -1) if self.delta_tau_n_sign else self.delta_tau_n_value) + return getattr(self, '_m_delta_tau_n', None) + + + class StringNonImmediate(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.data_1 = self._io.read_bits_int_be(64) + self.data_2 = self._io.read_bits_int_be(8) + + + class String5(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.n_a = self._io.read_bits_int_be(11) + self.tau_c = self._io.read_bits_int_be(32) + self.not_used = self._io.read_bits_int_be(1) != 0 + self.n_4 = self._io.read_bits_int_be(5) + self.tau_gps = self._io.read_bits_int_be(22) + self.l_n = self._io.read_bits_int_be(1) != 0 + + + class 
String1(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.not_used = self._io.read_bits_int_be(2) + self.p1 = self._io.read_bits_int_be(2) + self.t_k = self._io.read_bits_int_be(12) + self.x_vel_sign = self._io.read_bits_int_be(1) != 0 + self.x_vel_value = self._io.read_bits_int_be(23) + self.x_accel_sign = self._io.read_bits_int_be(1) != 0 + self.x_accel_value = self._io.read_bits_int_be(4) + self.x_sign = self._io.read_bits_int_be(1) != 0 + self.x_value = self._io.read_bits_int_be(26) + + @property + def x_vel(self): + if hasattr(self, '_m_x_vel'): + return self._m_x_vel + + self._m_x_vel = ((self.x_vel_value * -1) if self.x_vel_sign else self.x_vel_value) + return getattr(self, '_m_x_vel', None) + + @property + def x_accel(self): + if hasattr(self, '_m_x_accel'): + return self._m_x_accel + + self._m_x_accel = ((self.x_accel_value * -1) if self.x_accel_sign else self.x_accel_value) + return getattr(self, '_m_x_accel', None) + + @property + def x(self): + if hasattr(self, '_m_x'): + return self._m_x + + self._m_x = ((self.x_value * -1) if self.x_sign else self.x_value) + return getattr(self, '_m_x', None) + + + class String2(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.b_n = self._io.read_bits_int_be(3) + self.p2 = self._io.read_bits_int_be(1) != 0 + self.t_b = self._io.read_bits_int_be(7) + self.not_used = self._io.read_bits_int_be(5) + self.y_vel_sign = self._io.read_bits_int_be(1) != 0 + self.y_vel_value = self._io.read_bits_int_be(23) + self.y_accel_sign = self._io.read_bits_int_be(1) != 0 + self.y_accel_value = self._io.read_bits_int_be(4) + self.y_sign = self._io.read_bits_int_be(1) != 0 + self.y_value = self._io.read_bits_int_be(26) + + @property + def y_vel(self): + if hasattr(self, '_m_y_vel'): + return self._m_y_vel + + self._m_y_vel = ((self.y_vel_value * -1) if self.y_vel_sign else self.y_vel_value) + return getattr(self, '_m_y_vel', None) + + @property + def y_accel(self): + if hasattr(self, '_m_y_accel'): + return self._m_y_accel + + self._m_y_accel = ((self.y_accel_value * -1) if self.y_accel_sign else self.y_accel_value) + return getattr(self, '_m_y_accel', None) + + @property + def y(self): + if hasattr(self, '_m_y'): + return self._m_y + + self._m_y = ((self.y_value * -1) if self.y_sign else self.y_value) + return getattr(self, '_m_y', None) + + + class String3(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.p3 = self._io.read_bits_int_be(1) != 0 + self.gamma_n_sign = self._io.read_bits_int_be(1) != 0 + self.gamma_n_value = self._io.read_bits_int_be(10) + self.not_used = self._io.read_bits_int_be(1) != 0 + self.p = self._io.read_bits_int_be(2) + self.l_n = self._io.read_bits_int_be(1) != 0 + self.z_vel_sign = self._io.read_bits_int_be(1) != 0 + self.z_vel_value = self._io.read_bits_int_be(23) + self.z_accel_sign = self._io.read_bits_int_be(1) != 0 + self.z_accel_value = self._io.read_bits_int_be(4) + self.z_sign = self._io.read_bits_int_be(1) != 0 + self.z_value = self._io.read_bits_int_be(26) + + @property + def gamma_n(self): + if hasattr(self, '_m_gamma_n'): + return self._m_gamma_n + + self._m_gamma_n = ((self.gamma_n_value * -1) if 
self.gamma_n_sign else self.gamma_n_value) + return getattr(self, '_m_gamma_n', None) + + @property + def z_vel(self): + if hasattr(self, '_m_z_vel'): + return self._m_z_vel + + self._m_z_vel = ((self.z_vel_value * -1) if self.z_vel_sign else self.z_vel_value) + return getattr(self, '_m_z_vel', None) + + @property + def z_accel(self): + if hasattr(self, '_m_z_accel'): + return self._m_z_accel + + self._m_z_accel = ((self.z_accel_value * -1) if self.z_accel_sign else self.z_accel_value) + return getattr(self, '_m_z_accel', None) + + @property + def z(self): + if hasattr(self, '_m_z'): + return self._m_z + + self._m_z = ((self.z_value * -1) if self.z_sign else self.z_value) + return getattr(self, '_m_z', None) + + diff --git a/system/ubloxd/generated/gps.cpp b/system/ubloxd/generated/gps.cpp deleted file mode 100644 index 8e1cb85b..00000000 --- a/system/ubloxd/generated/gps.cpp +++ /dev/null @@ -1,325 +0,0 @@ -// This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "gps.h" -#include "kaitai/exceptions.h" - -gps_t::gps_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = this; - m_tlm = 0; - m_how = 0; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::_read() { - m_tlm = new tlm_t(m__io, this, m__root); - m_how = new how_t(m__io, this, m__root); - n_body = true; - switch (how()->subframe_id()) { - case 1: { - n_body = false; - m_body = new subframe_1_t(m__io, this, m__root); - break; - } - case 2: { - n_body = false; - m_body = new subframe_2_t(m__io, this, m__root); - break; - } - case 3: { - n_body = false; - m_body = new subframe_3_t(m__io, this, m__root); - break; - } - case 4: { - n_body = false; - m_body = new subframe_4_t(m__io, this, m__root); - break; - } - } -} - -gps_t::~gps_t() { - _clean_up(); -} - -void gps_t::_clean_up() { - if (m_tlm) { - delete m_tlm; m_tlm = 0; - } - if (m_how) { - delete m_how; m_how = 0; - } - if (!n_body) { - if (m_body) { - delete m_body; m_body = 0; - } - } -} - -gps_t::subframe_1_t::subframe_1_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_af_0 = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::subframe_1_t::_read() { - m_week_no = m__io->read_bits_int_be(10); - m_code = m__io->read_bits_int_be(2); - m_sv_accuracy = m__io->read_bits_int_be(4); - m_sv_health = m__io->read_bits_int_be(6); - m_iodc_msb = m__io->read_bits_int_be(2); - m_l2_p_data_flag = m__io->read_bits_int_be(1); - m_reserved1 = m__io->read_bits_int_be(23); - m_reserved2 = m__io->read_bits_int_be(24); - m_reserved3 = m__io->read_bits_int_be(24); - m_reserved4 = m__io->read_bits_int_be(16); - m__io->align_to_byte(); - m_t_gd = m__io->read_s1(); - m_iodc_lsb = m__io->read_u1(); - m_t_oc = m__io->read_u2be(); - m_af_2 = m__io->read_s1(); - m_af_1 = m__io->read_s2be(); - m_af_0_sign = m__io->read_bits_int_be(1); - m_af_0_value = m__io->read_bits_int_be(21); - m_reserved5 = m__io->read_bits_int_be(2); -} - -gps_t::subframe_1_t::~subframe_1_t() { - _clean_up(); -} - -void gps_t::subframe_1_t::_clean_up() { -} - -int32_t gps_t::subframe_1_t::af_0() { - if (f_af_0) - return m_af_0; - m_af_0 = ((af_0_sign()) ? 
((af_0_value() - (1 << 21))) : (af_0_value())); - f_af_0 = true; - return m_af_0; -} - -gps_t::subframe_3_t::subframe_3_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - f_omega_dot = false; - f_idot = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::subframe_3_t::_read() { - m_c_ic = m__io->read_s2be(); - m_omega_0 = m__io->read_s4be(); - m_c_is = m__io->read_s2be(); - m_i_0 = m__io->read_s4be(); - m_c_rc = m__io->read_s2be(); - m_omega = m__io->read_s4be(); - m_omega_dot_sign = m__io->read_bits_int_be(1); - m_omega_dot_value = m__io->read_bits_int_be(23); - m__io->align_to_byte(); - m_iode = m__io->read_u1(); - m_idot_sign = m__io->read_bits_int_be(1); - m_idot_value = m__io->read_bits_int_be(13); - m_reserved = m__io->read_bits_int_be(2); -} - -gps_t::subframe_3_t::~subframe_3_t() { - _clean_up(); -} - -void gps_t::subframe_3_t::_clean_up() { -} - -int32_t gps_t::subframe_3_t::omega_dot() { - if (f_omega_dot) - return m_omega_dot; - m_omega_dot = ((omega_dot_sign()) ? ((omega_dot_value() - (1 << 23))) : (omega_dot_value())); - f_omega_dot = true; - return m_omega_dot; -} - -int32_t gps_t::subframe_3_t::idot() { - if (f_idot) - return m_idot; - m_idot = ((idot_sign()) ? ((idot_value() - (1 << 13))) : (idot_value())); - f_idot = true; - return m_idot; -} - -gps_t::subframe_4_t::subframe_4_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::subframe_4_t::_read() { - m_data_id = m__io->read_bits_int_be(2); - m_page_id = m__io->read_bits_int_be(6); - m__io->align_to_byte(); - n_body = true; - switch (page_id()) { - case 56: { - n_body = false; - m_body = new ionosphere_data_t(m__io, this, m__root); - break; - } - } -} - -gps_t::subframe_4_t::~subframe_4_t() { - _clean_up(); -} - -void gps_t::subframe_4_t::_clean_up() { - if (!n_body) { - if (m_body) { - delete m_body; m_body = 0; - } - } -} - -gps_t::subframe_4_t::ionosphere_data_t::ionosphere_data_t(kaitai::kstream* p__io, gps_t::subframe_4_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::subframe_4_t::ionosphere_data_t::_read() { - m_a0 = m__io->read_s1(); - m_a1 = m__io->read_s1(); - m_a2 = m__io->read_s1(); - m_a3 = m__io->read_s1(); - m_b0 = m__io->read_s1(); - m_b1 = m__io->read_s1(); - m_b2 = m__io->read_s1(); - m_b3 = m__io->read_s1(); -} - -gps_t::subframe_4_t::ionosphere_data_t::~ionosphere_data_t() { - _clean_up(); -} - -void gps_t::subframe_4_t::ionosphere_data_t::_clean_up() { -} - -gps_t::how_t::how_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) 
{ - _clean_up(); - throw; - } -} - -void gps_t::how_t::_read() { - m_tow_count = m__io->read_bits_int_be(17); - m_alert = m__io->read_bits_int_be(1); - m_anti_spoof = m__io->read_bits_int_be(1); - m_subframe_id = m__io->read_bits_int_be(3); - m_reserved = m__io->read_bits_int_be(2); -} - -gps_t::how_t::~how_t() { - _clean_up(); -} - -void gps_t::how_t::_clean_up() { -} - -gps_t::tlm_t::tlm_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::tlm_t::_read() { - m_preamble = m__io->read_bytes(1); - if (!(preamble() == std::string("\x8B", 1))) { - throw kaitai::validation_not_equal_error(std::string("\x8B", 1), preamble(), _io(), std::string("/types/tlm/seq/0")); - } - m_tlm = m__io->read_bits_int_be(14); - m_integrity_status = m__io->read_bits_int_be(1); - m_reserved = m__io->read_bits_int_be(1); -} - -gps_t::tlm_t::~tlm_t() { - _clean_up(); -} - -void gps_t::tlm_t::_clean_up() { -} - -gps_t::subframe_2_t::subframe_2_t(kaitai::kstream* p__io, gps_t* p__parent, gps_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void gps_t::subframe_2_t::_read() { - m_iode = m__io->read_u1(); - m_c_rs = m__io->read_s2be(); - m_delta_n = m__io->read_s2be(); - m_m_0 = m__io->read_s4be(); - m_c_uc = m__io->read_s2be(); - m_e = m__io->read_s4be(); - m_c_us = m__io->read_s2be(); - m_sqrt_a = m__io->read_u4be(); - m_t_oe = m__io->read_u2be(); - m_fit_interval_flag = m__io->read_bits_int_be(1); - m_aoda = m__io->read_bits_int_be(5); - m_reserved = m__io->read_bits_int_be(2); -} - -gps_t::subframe_2_t::~subframe_2_t() { - _clean_up(); -} - -void gps_t::subframe_2_t::_clean_up() { -} diff --git a/system/ubloxd/generated/gps.h b/system/ubloxd/generated/gps.h deleted file mode 100644 index 9dfc5031..00000000 --- a/system/ubloxd/generated/gps.h +++ /dev/null @@ -1,359 +0,0 @@ -#ifndef GPS_H_ -#define GPS_H_ - -// This is a generated file! 
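Note the two different sign conventions in this generated code: the GLONASS fields above are sign-magnitude (value * -1 when the sign bit is set), while the GPS fields here and below (af_0, omega_dot, idot) are two's complement (value - (1 << nbits) when the sign bit is set). An illustrative helper, not part of the diff, that captures both decodings:

    # Illustrative only: the two decodings used by the generated parsers.
    def sign_magnitude(sign: bool, value: int) -> int:
        return -value if sign else value  # GLONASS x/y/z, gamma_n, tau_n

    def twos_complement(sign: bool, value: int, bits: int) -> int:
        # GPS af_0 (21 bits), omega_dot (23), idot (13)
        return value - (1 << bits) if sign else value

    assert sign_magnitude(True, 5) == -5
    assert twos_complement(True, (1 << 21) - 1, 21) == -1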
Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "kaitai/kaitaistruct.h" -#include - -#if KAITAI_STRUCT_VERSION < 9000L -#error "Incompatible Kaitai Struct C++/STL API: version 0.9 or later is required" -#endif - -class gps_t : public kaitai::kstruct { - -public: - class subframe_1_t; - class subframe_3_t; - class subframe_4_t; - class how_t; - class tlm_t; - class subframe_2_t; - - gps_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent = 0, gps_t* p__root = 0); - -private: - void _read(); - void _clean_up(); - -public: - ~gps_t(); - - class subframe_1_t : public kaitai::kstruct { - - public: - - subframe_1_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~subframe_1_t(); - - private: - bool f_af_0; - int32_t m_af_0; - - public: - int32_t af_0(); - - private: - uint64_t m_week_no; - uint64_t m_code; - uint64_t m_sv_accuracy; - uint64_t m_sv_health; - uint64_t m_iodc_msb; - bool m_l2_p_data_flag; - uint64_t m_reserved1; - uint64_t m_reserved2; - uint64_t m_reserved3; - uint64_t m_reserved4; - int8_t m_t_gd; - uint8_t m_iodc_lsb; - uint16_t m_t_oc; - int8_t m_af_2; - int16_t m_af_1; - bool m_af_0_sign; - uint64_t m_af_0_value; - uint64_t m_reserved5; - gps_t* m__root; - gps_t* m__parent; - - public: - uint64_t week_no() const { return m_week_no; } - uint64_t code() const { return m_code; } - uint64_t sv_accuracy() const { return m_sv_accuracy; } - uint64_t sv_health() const { return m_sv_health; } - uint64_t iodc_msb() const { return m_iodc_msb; } - bool l2_p_data_flag() const { return m_l2_p_data_flag; } - uint64_t reserved1() const { return m_reserved1; } - uint64_t reserved2() const { return m_reserved2; } - uint64_t reserved3() const { return m_reserved3; } - uint64_t reserved4() const { return m_reserved4; } - int8_t t_gd() const { return m_t_gd; } - uint8_t iodc_lsb() const { return m_iodc_lsb; } - uint16_t t_oc() const { return m_t_oc; } - int8_t af_2() const { return m_af_2; } - int16_t af_1() const { return m_af_1; } - bool af_0_sign() const { return m_af_0_sign; } - uint64_t af_0_value() const { return m_af_0_value; } - uint64_t reserved5() const { return m_reserved5; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - - class subframe_3_t : public kaitai::kstruct { - - public: - - subframe_3_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~subframe_3_t(); - - private: - bool f_omega_dot; - int32_t m_omega_dot; - - public: - int32_t omega_dot(); - - private: - bool f_idot; - int32_t m_idot; - - public: - int32_t idot(); - - private: - int16_t m_c_ic; - int32_t m_omega_0; - int16_t m_c_is; - int32_t m_i_0; - int16_t m_c_rc; - int32_t m_omega; - bool m_omega_dot_sign; - uint64_t m_omega_dot_value; - uint8_t m_iode; - bool m_idot_sign; - uint64_t m_idot_value; - uint64_t m_reserved; - gps_t* m__root; - gps_t* m__parent; - - public: - int16_t c_ic() const { return m_c_ic; } - int32_t omega_0() const { return m_omega_0; } - int16_t c_is() const { return m_c_is; } - int32_t i_0() const { return m_i_0; } - int16_t c_rc() const { return m_c_rc; } - int32_t omega() const { return m_omega; } - bool omega_dot_sign() const { return m_omega_dot_sign; } - uint64_t omega_dot_value() const { return m_omega_dot_value; } - uint8_t iode() const { return m_iode; } - bool idot_sign() const { return m_idot_sign; } - uint64_t idot_value() const { return 
m_idot_value; } - uint64_t reserved() const { return m_reserved; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - - class subframe_4_t : public kaitai::kstruct { - - public: - class ionosphere_data_t; - - subframe_4_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~subframe_4_t(); - - class ionosphere_data_t : public kaitai::kstruct { - - public: - - ionosphere_data_t(kaitai::kstream* p__io, gps_t::subframe_4_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~ionosphere_data_t(); - - private: - int8_t m_a0; - int8_t m_a1; - int8_t m_a2; - int8_t m_a3; - int8_t m_b0; - int8_t m_b1; - int8_t m_b2; - int8_t m_b3; - gps_t* m__root; - gps_t::subframe_4_t* m__parent; - - public: - int8_t a0() const { return m_a0; } - int8_t a1() const { return m_a1; } - int8_t a2() const { return m_a2; } - int8_t a3() const { return m_a3; } - int8_t b0() const { return m_b0; } - int8_t b1() const { return m_b1; } - int8_t b2() const { return m_b2; } - int8_t b3() const { return m_b3; } - gps_t* _root() const { return m__root; } - gps_t::subframe_4_t* _parent() const { return m__parent; } - }; - - private: - uint64_t m_data_id; - uint64_t m_page_id; - ionosphere_data_t* m_body; - bool n_body; - - public: - bool _is_null_body() { body(); return n_body; }; - - private: - gps_t* m__root; - gps_t* m__parent; - - public: - uint64_t data_id() const { return m_data_id; } - uint64_t page_id() const { return m_page_id; } - ionosphere_data_t* body() const { return m_body; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - - class how_t : public kaitai::kstruct { - - public: - - how_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~how_t(); - - private: - uint64_t m_tow_count; - bool m_alert; - bool m_anti_spoof; - uint64_t m_subframe_id; - uint64_t m_reserved; - gps_t* m__root; - gps_t* m__parent; - - public: - uint64_t tow_count() const { return m_tow_count; } - bool alert() const { return m_alert; } - bool anti_spoof() const { return m_anti_spoof; } - uint64_t subframe_id() const { return m_subframe_id; } - uint64_t reserved() const { return m_reserved; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - - class tlm_t : public kaitai::kstruct { - - public: - - tlm_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~tlm_t(); - - private: - std::string m_preamble; - uint64_t m_tlm; - bool m_integrity_status; - bool m_reserved; - gps_t* m__root; - gps_t* m__parent; - - public: - std::string preamble() const { return m_preamble; } - uint64_t tlm() const { return m_tlm; } - bool integrity_status() const { return m_integrity_status; } - bool reserved() const { return m_reserved; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - - class subframe_2_t : public kaitai::kstruct { - - public: - - subframe_2_t(kaitai::kstream* p__io, gps_t* p__parent = 0, gps_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~subframe_2_t(); - - private: - uint8_t m_iode; - int16_t m_c_rs; - int16_t m_delta_n; - int32_t m_m_0; - int16_t m_c_uc; - int32_t m_e; - int16_t m_c_us; - uint32_t m_sqrt_a; - uint16_t m_t_oe; - bool 
m_fit_interval_flag; - uint64_t m_aoda; - uint64_t m_reserved; - gps_t* m__root; - gps_t* m__parent; - - public: - uint8_t iode() const { return m_iode; } - int16_t c_rs() const { return m_c_rs; } - int16_t delta_n() const { return m_delta_n; } - int32_t m_0() const { return m_m_0; } - int16_t c_uc() const { return m_c_uc; } - int32_t e() const { return m_e; } - int16_t c_us() const { return m_c_us; } - uint32_t sqrt_a() const { return m_sqrt_a; } - uint16_t t_oe() const { return m_t_oe; } - bool fit_interval_flag() const { return m_fit_interval_flag; } - uint64_t aoda() const { return m_aoda; } - uint64_t reserved() const { return m_reserved; } - gps_t* _root() const { return m__root; } - gps_t* _parent() const { return m__parent; } - }; - -private: - tlm_t* m_tlm; - how_t* m_how; - kaitai::kstruct* m_body; - bool n_body; - -public: - bool _is_null_body() { body(); return n_body; }; - -private: - gps_t* m__root; - kaitai::kstruct* m__parent; - -public: - tlm_t* tlm() const { return m_tlm; } - how_t* how() const { return m_how; } - kaitai::kstruct* body() const { return m_body; } - gps_t* _root() const { return m__root; } - kaitai::kstruct* _parent() const { return m__parent; } -}; - -#endif // GPS_H_ diff --git a/system/ubloxd/generated/gps.py b/system/ubloxd/generated/gps.py new file mode 100644 index 00000000..a999016f --- /dev/null +++ b/system/ubloxd/generated/gps.py @@ -0,0 +1,193 @@ +# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild + +import kaitaistruct +from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO + + +if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9): + raise Exception("Incompatible Kaitai Struct Python API: 0.9 or later is required, but you have %s" % (kaitaistruct.__version__)) + +class Gps(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.tlm = Gps.Tlm(self._io, self, self._root) + self.how = Gps.How(self._io, self, self._root) + _on = self.how.subframe_id + if _on == 1: + self.body = Gps.Subframe1(self._io, self, self._root) + elif _on == 2: + self.body = Gps.Subframe2(self._io, self, self._root) + elif _on == 3: + self.body = Gps.Subframe3(self._io, self, self._root) + elif _on == 4: + self.body = Gps.Subframe4(self._io, self, self._root) + + class Subframe1(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.week_no = self._io.read_bits_int_be(10) + self.code = self._io.read_bits_int_be(2) + self.sv_accuracy = self._io.read_bits_int_be(4) + self.sv_health = self._io.read_bits_int_be(6) + self.iodc_msb = self._io.read_bits_int_be(2) + self.l2_p_data_flag = self._io.read_bits_int_be(1) != 0 + self.reserved1 = self._io.read_bits_int_be(23) + self.reserved2 = self._io.read_bits_int_be(24) + self.reserved3 = self._io.read_bits_int_be(24) + self.reserved4 = self._io.read_bits_int_be(16) + self._io.align_to_byte() + self.t_gd = self._io.read_s1() + self.iodc_lsb = self._io.read_u1() + self.t_oc = self._io.read_u2be() + self.af_2 = self._io.read_s1() + self.af_1 = self._io.read_s2be() + self.af_0_sign = self._io.read_bits_int_be(1) != 0 + self.af_0_value = self._io.read_bits_int_be(21) + self.reserved5 = self._io.read_bits_int_be(2) + + @property + def af_0(self): + if hasattr(self, '_m_af_0'): + return self._m_af_0 + + 
self._m_af_0 = ((self.af_0_value - (1 << 21)) if self.af_0_sign else self.af_0_value) + return getattr(self, '_m_af_0', None) + + + class Subframe3(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.c_ic = self._io.read_s2be() + self.omega_0 = self._io.read_s4be() + self.c_is = self._io.read_s2be() + self.i_0 = self._io.read_s4be() + self.c_rc = self._io.read_s2be() + self.omega = self._io.read_s4be() + self.omega_dot_sign = self._io.read_bits_int_be(1) != 0 + self.omega_dot_value = self._io.read_bits_int_be(23) + self._io.align_to_byte() + self.iode = self._io.read_u1() + self.idot_sign = self._io.read_bits_int_be(1) != 0 + self.idot_value = self._io.read_bits_int_be(13) + self.reserved = self._io.read_bits_int_be(2) + + @property + def omega_dot(self): + if hasattr(self, '_m_omega_dot'): + return self._m_omega_dot + + self._m_omega_dot = ((self.omega_dot_value - (1 << 23)) if self.omega_dot_sign else self.omega_dot_value) + return getattr(self, '_m_omega_dot', None) + + @property + def idot(self): + if hasattr(self, '_m_idot'): + return self._m_idot + + self._m_idot = ((self.idot_value - (1 << 13)) if self.idot_sign else self.idot_value) + return getattr(self, '_m_idot', None) + + + class Subframe4(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.data_id = self._io.read_bits_int_be(2) + self.page_id = self._io.read_bits_int_be(6) + self._io.align_to_byte() + _on = self.page_id + if _on == 56: + self.body = Gps.Subframe4.IonosphereData(self._io, self, self._root) + + class IonosphereData(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.a0 = self._io.read_s1() + self.a1 = self._io.read_s1() + self.a2 = self._io.read_s1() + self.a3 = self._io.read_s1() + self.b0 = self._io.read_s1() + self.b1 = self._io.read_s1() + self.b2 = self._io.read_s1() + self.b3 = self._io.read_s1() + + + + class How(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.tow_count = self._io.read_bits_int_be(17) + self.alert = self._io.read_bits_int_be(1) != 0 + self.anti_spoof = self._io.read_bits_int_be(1) != 0 + self.subframe_id = self._io.read_bits_int_be(3) + self.reserved = self._io.read_bits_int_be(2) + + + class Tlm(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.preamble = self._io.read_bytes(1) + if not self.preamble == b"\x8B": + raise kaitaistruct.ValidationNotEqualError(b"\x8B", self.preamble, self._io, u"/types/tlm/seq/0") + self.tlm = self._io.read_bits_int_be(14) + self.integrity_status = self._io.read_bits_int_be(1) != 0 + self.reserved = self._io.read_bits_int_be(1) != 0 + + + class Subframe2(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.iode = self._io.read_u1() + self.c_rs = self._io.read_s2be() + self.delta_n = self._io.read_s2be() + self.m_0 = 
self._io.read_s4be() + self.c_uc = self._io.read_s2be() + self.e = self._io.read_s4be() + self.c_us = self._io.read_s2be() + self.sqrt_a = self._io.read_u4be() + self.t_oe = self._io.read_u2be() + self.fit_interval_flag = self._io.read_bits_int_be(1) != 0 + self.aoda = self._io.read_bits_int_be(5) + self.reserved = self._io.read_bits_int_be(2) + + + diff --git a/system/ubloxd/generated/ubx.cpp b/system/ubloxd/generated/ubx.cpp deleted file mode 100644 index 81b82cca..00000000 --- a/system/ubloxd/generated/ubx.cpp +++ /dev/null @@ -1,424 +0,0 @@ -// This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "ubx.h" -#include "kaitai/exceptions.h" - -ubx_t::ubx_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = this; - f_checksum = false; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::_read() { - m_magic = m__io->read_bytes(2); - if (!(magic() == std::string("\xB5\x62", 2))) { - throw kaitai::validation_not_equal_error(std::string("\xB5\x62", 2), magic(), _io(), std::string("/seq/0")); - } - m_msg_type = m__io->read_u2be(); - m_length = m__io->read_u2le(); - n_body = true; - switch (msg_type()) { - case 2569: { - n_body = false; - m_body = new mon_hw_t(m__io, this, m__root); - break; - } - case 533: { - n_body = false; - m_body = new rxm_rawx_t(m__io, this, m__root); - break; - } - case 531: { - n_body = false; - m_body = new rxm_sfrbx_t(m__io, this, m__root); - break; - } - case 309: { - n_body = false; - m_body = new nav_sat_t(m__io, this, m__root); - break; - } - case 2571: { - n_body = false; - m_body = new mon_hw2_t(m__io, this, m__root); - break; - } - case 263: { - n_body = false; - m_body = new nav_pvt_t(m__io, this, m__root); - break; - } - } -} - -ubx_t::~ubx_t() { - _clean_up(); -} - -void ubx_t::_clean_up() { - if (!n_body) { - if (m_body) { - delete m_body; m_body = 0; - } - } - if (f_checksum) { - } -} - -ubx_t::rxm_rawx_t::rxm_rawx_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - m_meas = 0; - m__raw_meas = 0; - m__io__raw_meas = 0; - - try { - _read(); - } catch(...) 
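The generated Gps classes expose only raw unscaled integers; converting them to SI units needs the scale factors from the GPS interface specification, which this diff does not encode. A sketch of the subframe-1 clock-correction polynomial, with the 2^-31 s, 2^-43 s/s, 2^-55 s/s^2 and 2^4 s factors assumed from IS-GPS-200 and week rollover ignored:

    # Assumed IS-GPS-200 scale factors; sf1 is a parsed Gps.Subframe1 whose
    # af_0 property already applies the two's-complement reconstruction.
    def clock_correction(sf1, t):
        t_oc = sf1.t_oc * 2 ** 4                        # seconds
        dt = t - t_oc                                   # no rollover handling
        return (sf1.af_0 * 2 ** -31
                + sf1.af_1 * 2 ** -43 * dt
                + sf1.af_2 * 2 ** -55 * dt ** 2)        # clock bias, seconds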
{ - _clean_up(); - throw; - } -} - -void ubx_t::rxm_rawx_t::_read() { - m_rcv_tow = m__io->read_f8le(); - m_week = m__io->read_u2le(); - m_leap_s = m__io->read_s1(); - m_num_meas = m__io->read_u1(); - m_rec_stat = m__io->read_u1(); - m_reserved1 = m__io->read_bytes(3); - m__raw_meas = new std::vector(); - m__io__raw_meas = new std::vector(); - m_meas = new std::vector(); - const int l_meas = num_meas(); - for (int i = 0; i < l_meas; i++) { - m__raw_meas->push_back(m__io->read_bytes(32)); - kaitai::kstream* io__raw_meas = new kaitai::kstream(m__raw_meas->at(m__raw_meas->size() - 1)); - m__io__raw_meas->push_back(io__raw_meas); - m_meas->push_back(new measurement_t(io__raw_meas, this, m__root)); - } -} - -ubx_t::rxm_rawx_t::~rxm_rawx_t() { - _clean_up(); -} - -void ubx_t::rxm_rawx_t::_clean_up() { - if (m__raw_meas) { - delete m__raw_meas; m__raw_meas = 0; - } - if (m__io__raw_meas) { - for (std::vector::iterator it = m__io__raw_meas->begin(); it != m__io__raw_meas->end(); ++it) { - delete *it; - } - delete m__io__raw_meas; m__io__raw_meas = 0; - } - if (m_meas) { - for (std::vector::iterator it = m_meas->begin(); it != m_meas->end(); ++it) { - delete *it; - } - delete m_meas; m_meas = 0; - } -} - -ubx_t::rxm_rawx_t::measurement_t::measurement_t(kaitai::kstream* p__io, ubx_t::rxm_rawx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::rxm_rawx_t::measurement_t::_read() { - m_pr_mes = m__io->read_f8le(); - m_cp_mes = m__io->read_f8le(); - m_do_mes = m__io->read_f4le(); - m_gnss_id = static_cast(m__io->read_u1()); - m_sv_id = m__io->read_u1(); - m_reserved2 = m__io->read_bytes(1); - m_freq_id = m__io->read_u1(); - m_lock_time = m__io->read_u2le(); - m_cno = m__io->read_u1(); - m_pr_stdev = m__io->read_u1(); - m_cp_stdev = m__io->read_u1(); - m_do_stdev = m__io->read_u1(); - m_trk_stat = m__io->read_u1(); - m_reserved3 = m__io->read_bytes(1); -} - -ubx_t::rxm_rawx_t::measurement_t::~measurement_t() { - _clean_up(); -} - -void ubx_t::rxm_rawx_t::measurement_t::_clean_up() { -} - -ubx_t::rxm_sfrbx_t::rxm_sfrbx_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - m_body = 0; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::rxm_sfrbx_t::_read() { - m_gnss_id = static_cast(m__io->read_u1()); - m_sv_id = m__io->read_u1(); - m_reserved1 = m__io->read_bytes(1); - m_freq_id = m__io->read_u1(); - m_num_words = m__io->read_u1(); - m_reserved2 = m__io->read_bytes(1); - m_version = m__io->read_u1(); - m_reserved3 = m__io->read_bytes(1); - m_body = new std::vector(); - const int l_body = num_words(); - for (int i = 0; i < l_body; i++) { - m_body->push_back(m__io->read_u4le()); - } -} - -ubx_t::rxm_sfrbx_t::~rxm_sfrbx_t() { - _clean_up(); -} - -void ubx_t::rxm_sfrbx_t::_clean_up() { - if (m_body) { - delete m_body; m_body = 0; - } -} - -ubx_t::nav_sat_t::nav_sat_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - m_svs = 0; - m__raw_svs = 0; - m__io__raw_svs = 0; - - try { - _read(); - } catch(...) 
{ - _clean_up(); - throw; - } -} - -void ubx_t::nav_sat_t::_read() { - m_itow = m__io->read_u4le(); - m_version = m__io->read_u1(); - m_num_svs = m__io->read_u1(); - m_reserved = m__io->read_bytes(2); - m__raw_svs = new std::vector(); - m__io__raw_svs = new std::vector(); - m_svs = new std::vector(); - const int l_svs = num_svs(); - for (int i = 0; i < l_svs; i++) { - m__raw_svs->push_back(m__io->read_bytes(12)); - kaitai::kstream* io__raw_svs = new kaitai::kstream(m__raw_svs->at(m__raw_svs->size() - 1)); - m__io__raw_svs->push_back(io__raw_svs); - m_svs->push_back(new nav_t(io__raw_svs, this, m__root)); - } -} - -ubx_t::nav_sat_t::~nav_sat_t() { - _clean_up(); -} - -void ubx_t::nav_sat_t::_clean_up() { - if (m__raw_svs) { - delete m__raw_svs; m__raw_svs = 0; - } - if (m__io__raw_svs) { - for (std::vector::iterator it = m__io__raw_svs->begin(); it != m__io__raw_svs->end(); ++it) { - delete *it; - } - delete m__io__raw_svs; m__io__raw_svs = 0; - } - if (m_svs) { - for (std::vector::iterator it = m_svs->begin(); it != m_svs->end(); ++it) { - delete *it; - } - delete m_svs; m_svs = 0; - } -} - -ubx_t::nav_sat_t::nav_t::nav_t(kaitai::kstream* p__io, ubx_t::nav_sat_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::nav_sat_t::nav_t::_read() { - m_gnss_id = static_cast(m__io->read_u1()); - m_sv_id = m__io->read_u1(); - m_cno = m__io->read_u1(); - m_elev = m__io->read_s1(); - m_azim = m__io->read_s2le(); - m_pr_res = m__io->read_s2le(); - m_flags = m__io->read_u4le(); -} - -ubx_t::nav_sat_t::nav_t::~nav_t() { - _clean_up(); -} - -void ubx_t::nav_sat_t::nav_t::_clean_up() { -} - -ubx_t::nav_pvt_t::nav_pvt_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::nav_pvt_t::_read() { - m_i_tow = m__io->read_u4le(); - m_year = m__io->read_u2le(); - m_month = m__io->read_u1(); - m_day = m__io->read_u1(); - m_hour = m__io->read_u1(); - m_min = m__io->read_u1(); - m_sec = m__io->read_u1(); - m_valid = m__io->read_u1(); - m_t_acc = m__io->read_u4le(); - m_nano = m__io->read_s4le(); - m_fix_type = m__io->read_u1(); - m_flags = m__io->read_u1(); - m_flags2 = m__io->read_u1(); - m_num_sv = m__io->read_u1(); - m_lon = m__io->read_s4le(); - m_lat = m__io->read_s4le(); - m_height = m__io->read_s4le(); - m_h_msl = m__io->read_s4le(); - m_h_acc = m__io->read_u4le(); - m_v_acc = m__io->read_u4le(); - m_vel_n = m__io->read_s4le(); - m_vel_e = m__io->read_s4le(); - m_vel_d = m__io->read_s4le(); - m_g_speed = m__io->read_s4le(); - m_head_mot = m__io->read_s4le(); - m_s_acc = m__io->read_s4le(); - m_head_acc = m__io->read_u4le(); - m_p_dop = m__io->read_u2le(); - m_flags3 = m__io->read_u1(); - m_reserved1 = m__io->read_bytes(5); - m_head_veh = m__io->read_s4le(); - m_mag_dec = m__io->read_s2le(); - m_mag_acc = m__io->read_u2le(); -} - -ubx_t::nav_pvt_t::~nav_pvt_t() { - _clean_up(); -} - -void ubx_t::nav_pvt_t::_clean_up() { -} - -ubx_t::mon_hw2_t::mon_hw2_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) 
{ - _clean_up(); - throw; - } -} - -void ubx_t::mon_hw2_t::_read() { - m_ofs_i = m__io->read_s1(); - m_mag_i = m__io->read_u1(); - m_ofs_q = m__io->read_s1(); - m_mag_q = m__io->read_u1(); - m_cfg_source = static_cast(m__io->read_u1()); - m_reserved1 = m__io->read_bytes(3); - m_low_lev_cfg = m__io->read_u4le(); - m_reserved2 = m__io->read_bytes(8); - m_post_status = m__io->read_u4le(); - m_reserved3 = m__io->read_bytes(4); -} - -ubx_t::mon_hw2_t::~mon_hw2_t() { - _clean_up(); -} - -void ubx_t::mon_hw2_t::_clean_up() { -} - -ubx_t::mon_hw_t::mon_hw_t(kaitai::kstream* p__io, ubx_t* p__parent, ubx_t* p__root) : kaitai::kstruct(p__io) { - m__parent = p__parent; - m__root = p__root; - - try { - _read(); - } catch(...) { - _clean_up(); - throw; - } -} - -void ubx_t::mon_hw_t::_read() { - m_pin_sel = m__io->read_u4le(); - m_pin_bank = m__io->read_u4le(); - m_pin_dir = m__io->read_u4le(); - m_pin_val = m__io->read_u4le(); - m_noise_per_ms = m__io->read_u2le(); - m_agc_cnt = m__io->read_u2le(); - m_a_status = static_cast(m__io->read_u1()); - m_a_power = static_cast(m__io->read_u1()); - m_flags = m__io->read_u1(); - m_reserved1 = m__io->read_bytes(1); - m_used_mask = m__io->read_u4le(); - m_vp = m__io->read_bytes(17); - m_jam_ind = m__io->read_u1(); - m_reserved2 = m__io->read_bytes(2); - m_pin_irq = m__io->read_u4le(); - m_pull_h = m__io->read_u4le(); - m_pull_l = m__io->read_u4le(); -} - -ubx_t::mon_hw_t::~mon_hw_t() { - _clean_up(); -} - -void ubx_t::mon_hw_t::_clean_up() { -} - -uint16_t ubx_t::checksum() { - if (f_checksum) - return m_checksum; - std::streampos _pos = m__io->pos(); - m__io->seek((length() + 6)); - m_checksum = m__io->read_u2le(); - m__io->seek(_pos); - f_checksum = true; - return m_checksum; -} diff --git a/system/ubloxd/generated/ubx.h b/system/ubloxd/generated/ubx.h deleted file mode 100644 index 02210848..00000000 --- a/system/ubloxd/generated/ubx.h +++ /dev/null @@ -1,484 +0,0 @@ -#ifndef UBX_H_ -#define UBX_H_ - -// This is a generated file! 
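The checksum() getter above only reads the two stored trailer bytes at offset length + 6 (sync, class/id and length make up the 6-byte header); it does not verify them. Verification would use the standard UBX 8-bit Fletcher algorithm over class, id, length and payload, sketched here as a standalone helper rather than anything present in the generated code:

    # Hedged sketch of UBX Fletcher-8 verification over the bytes between
    # the 2-byte sync (0xB5 0x62) and the 2-byte trailing checksum.
    def ubx_checksum_ok(frame: bytes) -> bool:
        ck_a = ck_b = 0
        for b in frame[2:-2]:             # class, id, length, payload
            ck_a = (ck_a + b) & 0xFF
            ck_b = (ck_b + ck_a) & 0xFF
        return frame[-2] == ck_a and frame[-1] == ck_b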
Please edit source .ksy file and use kaitai-struct-compiler to rebuild - -#include "kaitai/kaitaistruct.h" -#include -#include - -#if KAITAI_STRUCT_VERSION < 9000L -#error "Incompatible Kaitai Struct C++/STL API: version 0.9 or later is required" -#endif - -class ubx_t : public kaitai::kstruct { - -public: - class rxm_rawx_t; - class rxm_sfrbx_t; - class nav_sat_t; - class nav_pvt_t; - class mon_hw2_t; - class mon_hw_t; - - enum gnss_type_t { - GNSS_TYPE_GPS = 0, - GNSS_TYPE_SBAS = 1, - GNSS_TYPE_GALILEO = 2, - GNSS_TYPE_BEIDOU = 3, - GNSS_TYPE_IMES = 4, - GNSS_TYPE_QZSS = 5, - GNSS_TYPE_GLONASS = 6 - }; - - ubx_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent = 0, ubx_t* p__root = 0); - -private: - void _read(); - void _clean_up(); - -public: - ~ubx_t(); - - class rxm_rawx_t : public kaitai::kstruct { - - public: - class measurement_t; - - rxm_rawx_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~rxm_rawx_t(); - - class measurement_t : public kaitai::kstruct { - - public: - - measurement_t(kaitai::kstream* p__io, ubx_t::rxm_rawx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~measurement_t(); - - private: - double m_pr_mes; - double m_cp_mes; - float m_do_mes; - gnss_type_t m_gnss_id; - uint8_t m_sv_id; - std::string m_reserved2; - uint8_t m_freq_id; - uint16_t m_lock_time; - uint8_t m_cno; - uint8_t m_pr_stdev; - uint8_t m_cp_stdev; - uint8_t m_do_stdev; - uint8_t m_trk_stat; - std::string m_reserved3; - ubx_t* m__root; - ubx_t::rxm_rawx_t* m__parent; - - public: - double pr_mes() const { return m_pr_mes; } - double cp_mes() const { return m_cp_mes; } - float do_mes() const { return m_do_mes; } - gnss_type_t gnss_id() const { return m_gnss_id; } - uint8_t sv_id() const { return m_sv_id; } - std::string reserved2() const { return m_reserved2; } - uint8_t freq_id() const { return m_freq_id; } - uint16_t lock_time() const { return m_lock_time; } - uint8_t cno() const { return m_cno; } - uint8_t pr_stdev() const { return m_pr_stdev; } - uint8_t cp_stdev() const { return m_cp_stdev; } - uint8_t do_stdev() const { return m_do_stdev; } - uint8_t trk_stat() const { return m_trk_stat; } - std::string reserved3() const { return m_reserved3; } - ubx_t* _root() const { return m__root; } - ubx_t::rxm_rawx_t* _parent() const { return m__parent; } - }; - - private: - double m_rcv_tow; - uint16_t m_week; - int8_t m_leap_s; - uint8_t m_num_meas; - uint8_t m_rec_stat; - std::string m_reserved1; - std::vector* m_meas; - ubx_t* m__root; - ubx_t* m__parent; - std::vector* m__raw_meas; - std::vector* m__io__raw_meas; - - public: - double rcv_tow() const { return m_rcv_tow; } - uint16_t week() const { return m_week; } - int8_t leap_s() const { return m_leap_s; } - uint8_t num_meas() const { return m_num_meas; } - uint8_t rec_stat() const { return m_rec_stat; } - std::string reserved1() const { return m_reserved1; } - std::vector* meas() const { return m_meas; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - std::vector* _raw_meas() const { return m__raw_meas; } - std::vector* _io__raw_meas() const { return m__io__raw_meas; } - }; - - class rxm_sfrbx_t : public kaitai::kstruct { - - public: - - rxm_sfrbx_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~rxm_sfrbx_t(); - - private: - gnss_type_t m_gnss_id; - uint8_t m_sv_id; - std::string 
m_reserved1; - uint8_t m_freq_id; - uint8_t m_num_words; - std::string m_reserved2; - uint8_t m_version; - std::string m_reserved3; - std::vector* m_body; - ubx_t* m__root; - ubx_t* m__parent; - - public: - gnss_type_t gnss_id() const { return m_gnss_id; } - uint8_t sv_id() const { return m_sv_id; } - std::string reserved1() const { return m_reserved1; } - uint8_t freq_id() const { return m_freq_id; } - uint8_t num_words() const { return m_num_words; } - std::string reserved2() const { return m_reserved2; } - uint8_t version() const { return m_version; } - std::string reserved3() const { return m_reserved3; } - std::vector* body() const { return m_body; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - }; - - class nav_sat_t : public kaitai::kstruct { - - public: - class nav_t; - - nav_sat_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~nav_sat_t(); - - class nav_t : public kaitai::kstruct { - - public: - - nav_t(kaitai::kstream* p__io, ubx_t::nav_sat_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~nav_t(); - - private: - gnss_type_t m_gnss_id; - uint8_t m_sv_id; - uint8_t m_cno; - int8_t m_elev; - int16_t m_azim; - int16_t m_pr_res; - uint32_t m_flags; - ubx_t* m__root; - ubx_t::nav_sat_t* m__parent; - - public: - gnss_type_t gnss_id() const { return m_gnss_id; } - uint8_t sv_id() const { return m_sv_id; } - uint8_t cno() const { return m_cno; } - int8_t elev() const { return m_elev; } - int16_t azim() const { return m_azim; } - int16_t pr_res() const { return m_pr_res; } - uint32_t flags() const { return m_flags; } - ubx_t* _root() const { return m__root; } - ubx_t::nav_sat_t* _parent() const { return m__parent; } - }; - - private: - uint32_t m_itow; - uint8_t m_version; - uint8_t m_num_svs; - std::string m_reserved; - std::vector* m_svs; - ubx_t* m__root; - ubx_t* m__parent; - std::vector* m__raw_svs; - std::vector* m__io__raw_svs; - - public: - uint32_t itow() const { return m_itow; } - uint8_t version() const { return m_version; } - uint8_t num_svs() const { return m_num_svs; } - std::string reserved() const { return m_reserved; } - std::vector* svs() const { return m_svs; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - std::vector* _raw_svs() const { return m__raw_svs; } - std::vector* _io__raw_svs() const { return m__io__raw_svs; } - }; - - class nav_pvt_t : public kaitai::kstruct { - - public: - - nav_pvt_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~nav_pvt_t(); - - private: - uint32_t m_i_tow; - uint16_t m_year; - uint8_t m_month; - uint8_t m_day; - uint8_t m_hour; - uint8_t m_min; - uint8_t m_sec; - uint8_t m_valid; - uint32_t m_t_acc; - int32_t m_nano; - uint8_t m_fix_type; - uint8_t m_flags; - uint8_t m_flags2; - uint8_t m_num_sv; - int32_t m_lon; - int32_t m_lat; - int32_t m_height; - int32_t m_h_msl; - uint32_t m_h_acc; - uint32_t m_v_acc; - int32_t m_vel_n; - int32_t m_vel_e; - int32_t m_vel_d; - int32_t m_g_speed; - int32_t m_head_mot; - int32_t m_s_acc; - uint32_t m_head_acc; - uint16_t m_p_dop; - uint8_t m_flags3; - std::string m_reserved1; - int32_t m_head_veh; - int16_t m_mag_dec; - uint16_t m_mag_acc; - ubx_t* m__root; - ubx_t* m__parent; - - public: - uint32_t i_tow() const { return m_i_tow; } - uint16_t year() const { return m_year; } - uint8_t month() 
const { return m_month; } - uint8_t day() const { return m_day; } - uint8_t hour() const { return m_hour; } - uint8_t min() const { return m_min; } - uint8_t sec() const { return m_sec; } - uint8_t valid() const { return m_valid; } - uint32_t t_acc() const { return m_t_acc; } - int32_t nano() const { return m_nano; } - uint8_t fix_type() const { return m_fix_type; } - uint8_t flags() const { return m_flags; } - uint8_t flags2() const { return m_flags2; } - uint8_t num_sv() const { return m_num_sv; } - int32_t lon() const { return m_lon; } - int32_t lat() const { return m_lat; } - int32_t height() const { return m_height; } - int32_t h_msl() const { return m_h_msl; } - uint32_t h_acc() const { return m_h_acc; } - uint32_t v_acc() const { return m_v_acc; } - int32_t vel_n() const { return m_vel_n; } - int32_t vel_e() const { return m_vel_e; } - int32_t vel_d() const { return m_vel_d; } - int32_t g_speed() const { return m_g_speed; } - int32_t head_mot() const { return m_head_mot; } - int32_t s_acc() const { return m_s_acc; } - uint32_t head_acc() const { return m_head_acc; } - uint16_t p_dop() const { return m_p_dop; } - uint8_t flags3() const { return m_flags3; } - std::string reserved1() const { return m_reserved1; } - int32_t head_veh() const { return m_head_veh; } - int16_t mag_dec() const { return m_mag_dec; } - uint16_t mag_acc() const { return m_mag_acc; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - }; - - class mon_hw2_t : public kaitai::kstruct { - - public: - - enum config_source_t { - CONFIG_SOURCE_FLASH = 102, - CONFIG_SOURCE_OTP = 111, - CONFIG_SOURCE_CONFIG_PINS = 112, - CONFIG_SOURCE_ROM = 113 - }; - - mon_hw2_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~mon_hw2_t(); - - private: - int8_t m_ofs_i; - uint8_t m_mag_i; - int8_t m_ofs_q; - uint8_t m_mag_q; - config_source_t m_cfg_source; - std::string m_reserved1; - uint32_t m_low_lev_cfg; - std::string m_reserved2; - uint32_t m_post_status; - std::string m_reserved3; - ubx_t* m__root; - ubx_t* m__parent; - - public: - int8_t ofs_i() const { return m_ofs_i; } - uint8_t mag_i() const { return m_mag_i; } - int8_t ofs_q() const { return m_ofs_q; } - uint8_t mag_q() const { return m_mag_q; } - config_source_t cfg_source() const { return m_cfg_source; } - std::string reserved1() const { return m_reserved1; } - uint32_t low_lev_cfg() const { return m_low_lev_cfg; } - std::string reserved2() const { return m_reserved2; } - uint32_t post_status() const { return m_post_status; } - std::string reserved3() const { return m_reserved3; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - }; - - class mon_hw_t : public kaitai::kstruct { - - public: - - enum antenna_status_t { - ANTENNA_STATUS_INIT = 0, - ANTENNA_STATUS_DONTKNOW = 1, - ANTENNA_STATUS_OK = 2, - ANTENNA_STATUS_SHORT = 3, - ANTENNA_STATUS_OPEN = 4 - }; - - enum antenna_power_t { - ANTENNA_POWER_FALSE = 0, - ANTENNA_POWER_TRUE = 1, - ANTENNA_POWER_DONTKNOW = 2 - }; - - mon_hw_t(kaitai::kstream* p__io, ubx_t* p__parent = 0, ubx_t* p__root = 0); - - private: - void _read(); - void _clean_up(); - - public: - ~mon_hw_t(); - - private: - uint32_t m_pin_sel; - uint32_t m_pin_bank; - uint32_t m_pin_dir; - uint32_t m_pin_val; - uint16_t m_noise_per_ms; - uint16_t m_agc_cnt; - antenna_status_t m_a_status; - antenna_power_t m_a_power; - uint8_t m_flags; - std::string m_reserved1; - uint32_t m_used_mask; - std::string 
m_vp; - uint8_t m_jam_ind; - std::string m_reserved2; - uint32_t m_pin_irq; - uint32_t m_pull_h; - uint32_t m_pull_l; - ubx_t* m__root; - ubx_t* m__parent; - - public: - uint32_t pin_sel() const { return m_pin_sel; } - uint32_t pin_bank() const { return m_pin_bank; } - uint32_t pin_dir() const { return m_pin_dir; } - uint32_t pin_val() const { return m_pin_val; } - uint16_t noise_per_ms() const { return m_noise_per_ms; } - uint16_t agc_cnt() const { return m_agc_cnt; } - antenna_status_t a_status() const { return m_a_status; } - antenna_power_t a_power() const { return m_a_power; } - uint8_t flags() const { return m_flags; } - std::string reserved1() const { return m_reserved1; } - uint32_t used_mask() const { return m_used_mask; } - std::string vp() const { return m_vp; } - uint8_t jam_ind() const { return m_jam_ind; } - std::string reserved2() const { return m_reserved2; } - uint32_t pin_irq() const { return m_pin_irq; } - uint32_t pull_h() const { return m_pull_h; } - uint32_t pull_l() const { return m_pull_l; } - ubx_t* _root() const { return m__root; } - ubx_t* _parent() const { return m__parent; } - }; - -private: - bool f_checksum; - uint16_t m_checksum; - -public: - uint16_t checksum(); - -private: - std::string m_magic; - uint16_t m_msg_type; - uint16_t m_length; - kaitai::kstruct* m_body; - bool n_body; - -public: - bool _is_null_body() { body(); return n_body; }; - -private: - ubx_t* m__root; - kaitai::kstruct* m__parent; - -public: - std::string magic() const { return m_magic; } - uint16_t msg_type() const { return m_msg_type; } - uint16_t length() const { return m_length; } - kaitai::kstruct* body() const { return m_body; } - ubx_t* _root() const { return m__root; } - kaitai::kstruct* _parent() const { return m__parent; } -}; - -#endif // UBX_H_ diff --git a/system/ubloxd/generated/ubx.py b/system/ubloxd/generated/ubx.py new file mode 100644 index 00000000..99465843 --- /dev/null +++ b/system/ubloxd/generated/ubx.py @@ -0,0 +1,273 @@ +# This is a generated file! 
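As with glonass.py and gps.py, the new ubx.py below mirrors the deleted C++ API. A minimal parse-and-dispatch sketch, assuming the module is importable as ubx and that frame is a placeholder for one complete UBX frame (it must start with the 0xB5 0x62 sync bytes or the magic check raises ValidationNotEqualError); the 1e-7 degree scaling for lat/lon is assumed from the u-blox NAV-PVT spec, not from this diff:

    # Hedged usage sketch for the generated Ubx parser.
    from kaitaistruct import KaitaiStream, BytesIO
    from ubx import Ubx

    def handle_frame(frame: bytes):
        msg = Ubx(KaitaiStream(BytesIO(frame)))
        if isinstance(msg.body, Ubx.NavPvt):
            return msg.body.lat * 1e-7, msg.body.lon * 1e-7   # degrees
        if isinstance(msg.body, Ubx.RxmRawx):
            # Each measurement is a fixed 32-byte record parsed from its
            # own substream (see _raw_meas handling below).
            return [(m.gnss_id, m.sv_id, m.pr_mes) for m in msg.body.meas]
        return None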
Please edit source .ksy file and use kaitai-struct-compiler to rebuild + +import kaitaistruct +from kaitaistruct import KaitaiStruct, KaitaiStream, BytesIO +from enum import Enum + + +if getattr(kaitaistruct, 'API_VERSION', (0, 9)) < (0, 9): + raise Exception("Incompatible Kaitai Struct Python API: 0.9 or later is required, but you have %s" % (kaitaistruct.__version__)) + +class Ubx(KaitaiStruct): + + class GnssType(Enum): + gps = 0 + sbas = 1 + galileo = 2 + beidou = 3 + imes = 4 + qzss = 5 + glonass = 6 + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.magic = self._io.read_bytes(2) + if not self.magic == b"\xB5\x62": + raise kaitaistruct.ValidationNotEqualError(b"\xB5\x62", self.magic, self._io, u"/seq/0") + self.msg_type = self._io.read_u2be() + self.length = self._io.read_u2le() + _on = self.msg_type + if _on == 2569: + self.body = Ubx.MonHw(self._io, self, self._root) + elif _on == 533: + self.body = Ubx.RxmRawx(self._io, self, self._root) + elif _on == 531: + self.body = Ubx.RxmSfrbx(self._io, self, self._root) + elif _on == 309: + self.body = Ubx.NavSat(self._io, self, self._root) + elif _on == 2571: + self.body = Ubx.MonHw2(self._io, self, self._root) + elif _on == 263: + self.body = Ubx.NavPvt(self._io, self, self._root) + + class RxmRawx(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.rcv_tow = self._io.read_f8le() + self.week = self._io.read_u2le() + self.leap_s = self._io.read_s1() + self.num_meas = self._io.read_u1() + self.rec_stat = self._io.read_u1() + self.reserved1 = self._io.read_bytes(3) + self._raw_meas = [] + self.meas = [] + for i in range(self.num_meas): + self._raw_meas.append(self._io.read_bytes(32)) + _io__raw_meas = KaitaiStream(BytesIO(self._raw_meas[i])) + self.meas.append(Ubx.RxmRawx.Measurement(_io__raw_meas, self, self._root)) + + + class Measurement(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.pr_mes = self._io.read_f8le() + self.cp_mes = self._io.read_f8le() + self.do_mes = self._io.read_f4le() + self.gnss_id = KaitaiStream.resolve_enum(Ubx.GnssType, self._io.read_u1()) + self.sv_id = self._io.read_u1() + self.reserved2 = self._io.read_bytes(1) + self.freq_id = self._io.read_u1() + self.lock_time = self._io.read_u2le() + self.cno = self._io.read_u1() + self.pr_stdev = self._io.read_u1() + self.cp_stdev = self._io.read_u1() + self.do_stdev = self._io.read_u1() + self.trk_stat = self._io.read_u1() + self.reserved3 = self._io.read_bytes(1) + + + + class RxmSfrbx(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.gnss_id = KaitaiStream.resolve_enum(Ubx.GnssType, self._io.read_u1()) + self.sv_id = self._io.read_u1() + self.reserved1 = self._io.read_bytes(1) + self.freq_id = self._io.read_u1() + self.num_words = self._io.read_u1() + self.reserved2 = self._io.read_bytes(1) + self.version = self._io.read_u1() + self.reserved3 = self._io.read_bytes(1) + self.body = [] + for i in range(self.num_words): + self.body.append(self._io.read_u4le()) + + + + class NavSat(KaitaiStruct): + def __init__(self, _io, 
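# --- Aside (not part of the generated file): a minimal sketch of how the
# msg_type values dispatched in Ubx._read above are formed. The UBX class and
# message id bytes read as one big-endian u16, so NAV-PVT (class 0x01, id
# 0x07) dispatches on 0x0107; the length field is little-endian and the
# Fletcher checksum covers class..payload. `payload` below is a stand-in.
import struct

def build_ubx_frame(msg_class: int, msg_id: int, payload: bytes) -> bytes:
    header = struct.pack('<BBBBH', 0xB5, 0x62, msg_class, msg_id, len(payload))
    ck_a = ck_b = 0
    for b in header[2:] + payload:   # checksum covers class, id, length, payload
        ck_a = (ck_a + b) & 0xFF
        ck_b = (ck_b + ck_a) & 0xFF
    return header + payload + bytes([ck_a, ck_b])

frame = build_ubx_frame(0x01, 0x07, b'\x00' * 92)
assert int.from_bytes(frame[2:4], 'big') == 0x0107  # NAV-PVT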
_parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.itow = self._io.read_u4le() + self.version = self._io.read_u1() + self.num_svs = self._io.read_u1() + self.reserved = self._io.read_bytes(2) + self._raw_svs = [] + self.svs = [] + for i in range(self.num_svs): + self._raw_svs.append(self._io.read_bytes(12)) + _io__raw_svs = KaitaiStream(BytesIO(self._raw_svs[i])) + self.svs.append(Ubx.NavSat.Nav(_io__raw_svs, self, self._root)) + + + class Nav(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.gnss_id = KaitaiStream.resolve_enum(Ubx.GnssType, self._io.read_u1()) + self.sv_id = self._io.read_u1() + self.cno = self._io.read_u1() + self.elev = self._io.read_s1() + self.azim = self._io.read_s2le() + self.pr_res = self._io.read_s2le() + self.flags = self._io.read_u4le() + + + + class NavPvt(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.i_tow = self._io.read_u4le() + self.year = self._io.read_u2le() + self.month = self._io.read_u1() + self.day = self._io.read_u1() + self.hour = self._io.read_u1() + self.min = self._io.read_u1() + self.sec = self._io.read_u1() + self.valid = self._io.read_u1() + self.t_acc = self._io.read_u4le() + self.nano = self._io.read_s4le() + self.fix_type = self._io.read_u1() + self.flags = self._io.read_u1() + self.flags2 = self._io.read_u1() + self.num_sv = self._io.read_u1() + self.lon = self._io.read_s4le() + self.lat = self._io.read_s4le() + self.height = self._io.read_s4le() + self.h_msl = self._io.read_s4le() + self.h_acc = self._io.read_u4le() + self.v_acc = self._io.read_u4le() + self.vel_n = self._io.read_s4le() + self.vel_e = self._io.read_s4le() + self.vel_d = self._io.read_s4le() + self.g_speed = self._io.read_s4le() + self.head_mot = self._io.read_s4le() + self.s_acc = self._io.read_s4le() + self.head_acc = self._io.read_u4le() + self.p_dop = self._io.read_u2le() + self.flags3 = self._io.read_u1() + self.reserved1 = self._io.read_bytes(5) + self.head_veh = self._io.read_s4le() + self.mag_dec = self._io.read_s2le() + self.mag_acc = self._io.read_u2le() + + + class MonHw2(KaitaiStruct): + + class ConfigSource(Enum): + flash = 102 + otp = 111 + config_pins = 112 + rom = 113 + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.ofs_i = self._io.read_s1() + self.mag_i = self._io.read_u1() + self.ofs_q = self._io.read_s1() + self.mag_q = self._io.read_u1() + self.cfg_source = KaitaiStream.resolve_enum(Ubx.MonHw2.ConfigSource, self._io.read_u1()) + self.reserved1 = self._io.read_bytes(3) + self.low_lev_cfg = self._io.read_u4le() + self.reserved2 = self._io.read_bytes(8) + self.post_status = self._io.read_u4le() + self.reserved3 = self._io.read_bytes(4) + + + class MonHw(KaitaiStruct): + + class AntennaStatus(Enum): + init = 0 + dontknow = 1 + ok = 2 + short = 3 + open = 4 + + class AntennaPower(Enum): + false = 0 + true = 1 + dontknow = 2 + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self._read() + + def _read(self): + self.pin_sel = self._io.read_u4le() + self.pin_bank = 
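# --- Aside (illustrative, not generated code): the NavPvt reader above yields
# raw integers; the standard UBX-NAV-PVT scale factors, also applied by the
# consumer later in this patch, convert them to physical units. `pvt` is
# assumed to be any object with the NavPvt field names.
def nav_pvt_units(pvt) -> dict:
    return {
        'lat_deg': pvt.lat * 1e-7,           # 1e-7 degrees
        'lon_deg': pvt.lon * 1e-7,
        'height_m': pvt.height * 1e-3,       # mm -> m
        'speed_ms': pvt.g_speed * 1e-3,      # mm/s -> m/s
        'bearing_deg': pvt.head_mot * 1e-5,  # 1e-5 degrees
        'h_acc_m': pvt.h_acc * 1e-3,
    }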
self._io.read_u4le() + self.pin_dir = self._io.read_u4le() + self.pin_val = self._io.read_u4le() + self.noise_per_ms = self._io.read_u2le() + self.agc_cnt = self._io.read_u2le() + self.a_status = KaitaiStream.resolve_enum(Ubx.MonHw.AntennaStatus, self._io.read_u1()) + self.a_power = KaitaiStream.resolve_enum(Ubx.MonHw.AntennaPower, self._io.read_u1()) + self.flags = self._io.read_u1() + self.reserved1 = self._io.read_bytes(1) + self.used_mask = self._io.read_u4le() + self.vp = self._io.read_bytes(17) + self.jam_ind = self._io.read_u1() + self.reserved2 = self._io.read_bytes(2) + self.pin_irq = self._io.read_u4le() + self.pull_h = self._io.read_u4le() + self.pull_l = self._io.read_u4le() + + + @property + def checksum(self): + if hasattr(self, '_m_checksum'): + return self._m_checksum + + _pos = self._io.pos() + self._io.seek((self.length + 6)) + self._m_checksum = self._io.read_u2le() + self._io.seek(_pos) + return getattr(self, '_m_checksum', None) + + diff --git a/system/ubloxd/glonass_fix.patch b/system/ubloxd/glonass_fix.patch deleted file mode 100644 index 7eb973a3..00000000 --- a/system/ubloxd/glonass_fix.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/system/ubloxd/generated/glonass.cpp b/system/ubloxd/generated/glonass.cpp -index 5b17bc327..b5c6aa610 100644 ---- a/system/ubloxd/generated/glonass.cpp -+++ b/system/ubloxd/generated/glonass.cpp -@@ -17,7 +17,7 @@ glonass_t::glonass_t(kaitai::kstream* p__io, kaitai::kstruct* p__parent, glonass - void glonass_t::_read() { - m_idle_chip = m__io->read_bits_int_be(1); - m_string_number = m__io->read_bits_int_be(4); -- m__io->align_to_byte(); -+ //m__io->align_to_byte(); - switch (string_number()) { - case 4: { - m_data = new string_4_t(m__io, this, m__root); diff --git a/system/ubloxd/pigeond.py b/system/ubloxd/pigeond.py index 2194a4b9..e458a9d6 100755 --- a/system/ubloxd/pigeond.py +++ b/system/ubloxd/pigeond.py @@ -41,7 +41,7 @@ def add_ubx_checksum(msg: bytes) -> bytes: B = (B + A) % 256 return msg + bytes([A, B]) -def get_assistnow_messages(token: bytes) -> list[bytes]: +def get_assistnow_messages(token: str) -> list[bytes]: # make request # TODO: implement adding the last known location r = requests.get("https://online-live2.services.u-blox.com/GetOnlineData.ashx", params=urllib.parse.urlencode({ @@ -136,6 +136,17 @@ class TTYPigeon: return True return False +def save_almanac(pigeon: TTYPigeon) -> None: + # store almanac in flash + pigeon.send(b"\xB5\x62\x09\x14\x04\x00\x00\x00\x00\x00\x21\xEC") + try: + if pigeon.wait_for_ack(ack=UBLOX_SOS_ACK, nack=UBLOX_SOS_NACK): + cloudlog.info("Done storing almanac") + else: + cloudlog.error("Error storing almanac") + except TimeoutError: + pass + def init_baudrate(pigeon: TTYPigeon): # ublox default setting on startup is 9600 baudrate pigeon.set_baud(9600) @@ -146,7 +157,7 @@ def init_baudrate(pigeon: TTYPigeon): pigeon.set_baud(460800) -def initialize_pigeon(pigeon: TTYPigeon) -> bool: +def init_pigeon(pigeon: TTYPigeon) -> bool: # try initializing a few times for _ in range(10): try: @@ -239,32 +250,17 @@ def initialize_pigeon(pigeon: TTYPigeon) -> bool: return True def deinitialize_and_exit(pigeon: TTYPigeon | None): - cloudlog.warning("Storing almanac in ublox flash") - if pigeon is not None: # controlled GNSS stop pigeon.send(b"\xB5\x62\x06\x04\x04\x00\x00\x00\x08\x00\x16\x74") - # store almanac in flash - pigeon.send(b"\xB5\x62\x09\x14\x04\x00\x00\x00\x00\x00\x21\xEC") - try: - if pigeon.wait_for_ack(ack=UBLOX_SOS_ACK, nack=UBLOX_SOS_NACK): - cloudlog.warning("Done storing almanac") - 
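# --- Aside (sketch): the trailing two bytes of the hardcoded UBX-UPD-SOS
# "create backup" command in save_almanac above (0x21, 0xEC) are its Fletcher
# checksum; add_ubx_checksum in this file computes the same sum when building
# frames. Quick self-contained verification:
def ubx_checksum(frame: bytes) -> bytes:
    ck_a = ck_b = 0
    for b in frame[2:]:
        ck_a = (ck_a + b) & 0xFF
        ck_b = (ck_b + ck_a) & 0xFF
    return bytes([ck_a, ck_b])

assert ubx_checksum(b"\xB5\x62\x09\x14\x04\x00\x00\x00\x00\x00") == b"\x21\xEC"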
else: - cloudlog.error("Error storing almanac") - except TimeoutError: - pass - # turn off power and exit cleanly set_power(False) sys.exit(0) -def create_pigeon() -> tuple[TTYPigeon, messaging.PubMaster]: - pigeon = None - +def init(pigeon: TTYPigeon) -> None: # register exit handler signal.signal(signal.SIGINT, lambda sig, frame: deinitialize_and_exit(pigeon)) - pm = messaging.PubMaster(['ubloxRaw']) # power cycle ublox set_power(False) @@ -272,28 +268,34 @@ def create_pigeon() -> tuple[TTYPigeon, messaging.PubMaster]: set_power(True) time.sleep(0.5) - pigeon = TTYPigeon() - return pigeon, pm + init_baudrate(pigeon) + init_pigeon(pigeon) -def run_receiving(pigeon: TTYPigeon, pm: messaging.PubMaster, duration: int = 0): +def run_receiving(duration: int = 0): + pm = messaging.PubMaster(['ubloxRaw']) + + pigeon = TTYPigeon() + init(pigeon) start_time = time.monotonic() - def end_condition(): - return True if duration == 0 else time.monotonic() - start_time < duration - - while end_condition(): + last_almanac_save = time.monotonic() + while (duration == 0) or (time.monotonic() - start_time < duration): dat = pigeon.receive() if len(dat) > 0: if dat[0] == 0x00: cloudlog.warning("received invalid data from ublox, re-initing!") - init_baudrate(pigeon) - initialize_pigeon(pigeon) + init(pigeon) continue # send out to socket msg = messaging.new_message('ubloxRaw', len(dat), valid=True) msg.ubloxRaw = dat[:] pm.send('ubloxRaw', msg) + + # save almanac every 5 minutes + if (time.monotonic() - last_almanac_save) > 60*5: + save_almanac(pigeon) + last_almanac_save = time.monotonic() else: # prevent locking up a CPU core if ublox disconnects time.sleep(0.001) @@ -301,13 +303,7 @@ def run_receiving(pigeon: TTYPigeon, pm: messaging.PubMaster, duration: int = 0) def main(): assert TICI, "unsupported hardware for pigeond" - - pigeon, pm = create_pigeon() - init_baudrate(pigeon) - initialize_pigeon(pigeon) - - # start receiving data - run_receiving(pigeon, pm) + run_receiving() if __name__ == "__main__": main() diff --git a/system/ubloxd/tests/print_gps_stats.py b/system/ubloxd/tests/print_gps_stats.py deleted file mode 100755 index 8d190f9e..00000000 --- a/system/ubloxd/tests/print_gps_stats.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python3 -import time -import cereal.messaging as messaging - -if __name__ == "__main__": - sm = messaging.SubMaster(['ubloxGnss', 'gpsLocationExternal']) - - while 1: - ug = sm['ubloxGnss'] - gle = sm['gpsLocationExternal'] - - try: - cnos = [] - for m in ug.measurementReport.measurements: - cnos.append(m.cno) - print(f"Sats: {ug.measurementReport.numMeas} Accuracy: {gle.horizontalAccuracy:.2f} m cnos", sorted(cnos)) - except Exception: - pass - sm.update() - time.sleep(0.1) diff --git a/system/ubloxd/tests/test_glonass_kaitai.cc b/system/ubloxd/tests/test_glonass_kaitai.cc deleted file mode 100644 index 96f43742..00000000 --- a/system/ubloxd/tests/test_glonass_kaitai.cc +++ /dev/null @@ -1,360 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "catch2/catch.hpp" -#include "system/ubloxd/generated/glonass.h" - -typedef std::vector> string_data; - -#define IDLE_CHIP_IDX 0 -#define STRING_NUMBER_IDX 1 -// string data 1-5 -#define HC_IDX 0 -#define PAD1_IDX 1 -#define SUPERFRAME_IDX 2 -#define PAD2_IDX 3 -#define FRAME_IDX 4 - -// Indexes for string number 1 -#define ST1_NU_IDX 2 -#define ST1_P1_IDX 3 -#define ST1_T_K_IDX 4 -#define ST1_X_VEL_S_IDX 5 -#define ST1_X_VEL_V_IDX 6 -#define ST1_X_ACCEL_S_IDX 7 -#define ST1_X_ACCEL_V_IDX 8 -#define 
ST1_X_S_IDX 9 -#define ST1_X_V_IDX 10 -#define ST1_HC_OFF 11 - -// Indexes for string number 2 -#define ST2_BN_IDX 2 -#define ST2_P2_IDX 3 -#define ST2_TB_IDX 4 -#define ST2_NU_IDX 5 -#define ST2_Y_VEL_S_IDX 6 -#define ST2_Y_VEL_V_IDX 7 -#define ST2_Y_ACCEL_S_IDX 8 -#define ST2_Y_ACCEL_V_IDX 9 -#define ST2_Y_S_IDX 10 -#define ST2_Y_V_IDX 11 -#define ST2_HC_OFF 12 - -// Indexes for string number 3 -#define ST3_P3_IDX 2 -#define ST3_GAMMA_N_S_IDX 3 -#define ST3_GAMMA_N_V_IDX 4 -#define ST3_NU_1_IDX 5 -#define ST3_P_IDX 6 -#define ST3_L_N_IDX 7 -#define ST3_Z_VEL_S_IDX 8 -#define ST3_Z_VEL_V_IDX 9 -#define ST3_Z_ACCEL_S_IDX 10 -#define ST3_Z_ACCEL_V_IDX 11 -#define ST3_Z_S_IDX 12 -#define ST3_Z_V_IDX 13 -#define ST3_HC_OFF 14 - -// Indexes for string number 4 -#define ST4_TAU_N_S_IDX 2 -#define ST4_TAU_N_V_IDX 3 -#define ST4_DELTA_TAU_N_S_IDX 4 -#define ST4_DELTA_TAU_N_V_IDX 5 -#define ST4_E_N_IDX 6 -#define ST4_NU_1_IDX 7 -#define ST4_P4_IDX 8 -#define ST4_F_T_IDX 9 -#define ST4_NU_2_IDX 10 -#define ST4_N_T_IDX 11 -#define ST4_N_IDX 12 -#define ST4_M_IDX 13 -#define ST4_HC_OFF 14 - -// Indexes for string number 5 -#define ST5_N_A_IDX 2 -#define ST5_TAU_C_IDX 3 -#define ST5_NU_IDX 4 -#define ST5_N_4_IDX 5 -#define ST5_TAU_GPS_IDX 6 -#define ST5_L_N_IDX 7 -#define ST5_HC_OFF 8 - -// Indexes for non immediate -#define ST6_DATA_1_IDX 2 -#define ST6_DATA_2_IDX 3 -#define ST6_HC_OFF 4 - - -std::string generate_inp_data(string_data& data) { - std::string inp_data = ""; - for (auto& [b, v] : data) { - std::string tmp = std::bitset<64>(v).to_string(); - inp_data += tmp.substr(64-b, b); - } - assert(inp_data.size() == 128); - - std::string string_data; - string_data.reserve(16); - for (int i = 0; i < 128; i+=8) { - std::string substr = inp_data.substr(i, 8); - string_data.push_back((uint8_t)std::stoi(substr.c_str(), 0, 2)); - } - - return string_data; -} - -string_data generate_string_data(uint8_t string_number) { - - srand((unsigned)time(0)); - string_data data; // - data.push_back({1, 0}); // idle chip - data.push_back({4, string_number}); // string number - - if (string_number == 1) { - data.push_back({2, 3}); // not_used - data.push_back({2, 1}); // p1 - data.push_back({12, 113}); // t_k - data.push_back({1, rand() & 1}); // x_vel_sign - data.push_back({23, 7122}); // x_vel_value - data.push_back({1, rand() & 1}); // x_accel_sign - data.push_back({4, 3}); // x_accel_value - data.push_back({1, rand() & 1}); // x_sign - data.push_back({26, 33554431}); // x_value - } else if (string_number == 2) { - data.push_back({3, 3}); // b_n - data.push_back({1, 1}); // p2 - data.push_back({7, 123}); // t_b - data.push_back({5, 31}); // not_used - data.push_back({1, rand() & 1}); // y_vel_sign - data.push_back({23, 7422}); // y_vel_value - data.push_back({1, rand() & 1}); // y_accel_sign - data.push_back({4, 3}); // y_accel_value - data.push_back({1, rand() & 1}); // y_sign - data.push_back({26, 67108863}); // y_value - } else if (string_number == 3) { - data.push_back({1, 0}); // p3 - data.push_back({1, 1}); // gamma_n_sign - data.push_back({10, 123}); // gamma_n_value - data.push_back({1, 0}); // not_used - data.push_back({2, 2}); // p - data.push_back({1, 1}); // l_n - data.push_back({1, rand() & 1}); // z_vel_sign - data.push_back({23, 1337}); // z_vel_value - data.push_back({1, rand() & 1}); // z_accel_sign - data.push_back({4, 9}); // z_accel_value - data.push_back({1, rand() & 1}); // z_sign - data.push_back({26, 100023}); // z_value - } else if (string_number == 4) { - data.push_back({1, rand() & 1}); // 
tau_n_sign - data.push_back({21, 197152}); // tau_n_value - data.push_back({1, rand() & 1}); // delta_tau_n_sign - data.push_back({4, 4}); // delta_tau_n_value - data.push_back({5, 0}); // e_n - data.push_back({14, 2}); // not_used_1 - data.push_back({1, 1}); // p4 - data.push_back({4, 9}); // f_t - data.push_back({3, 3}); // not_used_2 - data.push_back({11, 2047}); // n_t - data.push_back({5, 2}); // n - data.push_back({2, 1}); // m - } else if (string_number == 5) { - data.push_back({11, 2047}); // n_a - data.push_back({32, 4294767295}); // tau_c - data.push_back({1, 0}); // not_used_1 - data.push_back({5, 2}); // n_4 - data.push_back({22, 4114304}); // tau_gps - data.push_back({1, 0}); // l_n - } else { // non-immediate data is not parsed - data.push_back({64, rand()}); // data_1 - data.push_back({8, 6}); // data_2 - } - - data.push_back({8, rand() & 0xFF}); // hamming code - data.push_back({11, rand() & 0x7FF}); // pad - data.push_back({16, rand() & 0xFFFF}); // superframe - data.push_back({8, rand() & 0xFF}); // pad - data.push_back({8, rand() & 0xFF}); // frame - return data; -} - -TEST_CASE("parse_string_number_1"){ - string_data data = generate_string_data(1); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST1_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST1_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST1_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST1_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST1_HC_OFF + FRAME_IDX].second); - - kaitai::kstream str1(inp_data); - glonass_t str1_data(&str1); - glonass_t::string_1_t* s1 = static_cast(str1_data.data()); - - REQUIRE(s1->not_used() == data[ST1_NU_IDX].second); - REQUIRE(s1->p1() == data[ST1_P1_IDX].second); - REQUIRE(s1->t_k() == data[ST1_T_K_IDX].second); - - int mul = s1->x_vel_sign() ? (-1) : 1; - REQUIRE(s1->x_vel() == (data[ST1_X_VEL_V_IDX].second * mul)); - mul = s1->x_accel_sign() ? (-1) : 1; - REQUIRE(s1->x_accel() == (data[ST1_X_ACCEL_V_IDX].second * mul)); - mul = s1->x_sign() ? (-1) : 1; - REQUIRE(s1->x() == (data[ST1_X_V_IDX].second * mul)); -} - -TEST_CASE("parse_string_number_2"){ - string_data data = generate_string_data(2); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST2_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST2_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST2_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST2_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST2_HC_OFF + FRAME_IDX].second); - - kaitai::kstream str2(inp_data); - glonass_t str2_data(&str2); - glonass_t::string_2_t* s2 = static_cast(str2_data.data()); - - REQUIRE(s2->b_n() == data[ST2_BN_IDX].second); - REQUIRE(s2->not_used() == data[ST2_NU_IDX].second); - REQUIRE(s2->p2() == data[ST2_P2_IDX].second); - REQUIRE(s2->t_b() == data[ST2_TB_IDX].second); - int mul = s2->y_vel_sign() ? 
(-1) : 1; - REQUIRE(s2->y_vel() == (data[ST2_Y_VEL_V_IDX].second * mul)); - mul = s2->y_accel_sign() ? (-1) : 1; - REQUIRE(s2->y_accel() == (data[ST2_Y_ACCEL_V_IDX].second * mul)); - mul = s2->y_sign() ? (-1) : 1; - REQUIRE(s2->y() == (data[ST2_Y_V_IDX].second * mul)); -} - -TEST_CASE("parse_string_number_3"){ - string_data data = generate_string_data(3); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST3_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST3_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST3_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST3_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST3_HC_OFF + FRAME_IDX].second); - - kaitai::kstream str3(inp_data); - glonass_t str3_data(&str3); - glonass_t::string_3_t* s3 = static_cast(str3_data.data()); - - REQUIRE(s3->p3() == data[ST3_P3_IDX].second); - int mul = s3->gamma_n_sign() ? (-1) : 1; - REQUIRE(s3->gamma_n() == (data[ST3_GAMMA_N_V_IDX].second * mul)); - REQUIRE(s3->not_used() == data[ST3_NU_1_IDX].second); - REQUIRE(s3->p() == data[ST3_P_IDX].second); - REQUIRE(s3->l_n() == data[ST3_L_N_IDX].second); - mul = s3->z_vel_sign() ? (-1) : 1; - REQUIRE(s3->z_vel() == (data[ST3_Z_VEL_V_IDX].second * mul)); - mul = s3->z_accel_sign() ? (-1) : 1; - REQUIRE(s3->z_accel() == (data[ST3_Z_ACCEL_V_IDX].second * mul)); - mul = s3->z_sign() ? (-1) : 1; - REQUIRE(s3->z() == (data[ST3_Z_V_IDX].second * mul)); -} - -TEST_CASE("parse_string_number_4"){ - string_data data = generate_string_data(4); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST4_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST4_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST4_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST4_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST4_HC_OFF + FRAME_IDX].second); - - kaitai::kstream str4(inp_data); - glonass_t str4_data(&str4); - glonass_t::string_4_t* s4 = static_cast(str4_data.data()); - - int mul = s4->tau_n_sign() ? (-1) : 1; - REQUIRE(s4->tau_n() == (data[ST4_TAU_N_V_IDX].second * mul)); - mul = s4->delta_tau_n_sign() ? 
(-1) : 1; - REQUIRE(s4->delta_tau_n() == (data[ST4_DELTA_TAU_N_V_IDX].second * mul)); - REQUIRE(s4->e_n() == data[ST4_E_N_IDX].second); - REQUIRE(s4->not_used_1() == data[ST4_NU_1_IDX].second); - REQUIRE(s4->p4() == data[ST4_P4_IDX].second); - REQUIRE(s4->f_t() == data[ST4_F_T_IDX].second); - REQUIRE(s4->not_used_2() == data[ST4_NU_2_IDX].second); - REQUIRE(s4->n_t() == data[ST4_N_T_IDX].second); - REQUIRE(s4->n() == data[ST4_N_IDX].second); - REQUIRE(s4->m() == data[ST4_M_IDX].second); -} - -TEST_CASE("parse_string_number_5"){ - string_data data = generate_string_data(5); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST5_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST5_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST5_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST5_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST5_HC_OFF + FRAME_IDX].second); - - kaitai::kstream str5(inp_data); - glonass_t str5_data(&str5); - glonass_t::string_5_t* s5 = static_cast(str5_data.data()); - - REQUIRE(s5->n_a() == data[ST5_N_A_IDX].second); - REQUIRE(s5->tau_c() == data[ST5_TAU_C_IDX].second); - REQUIRE(s5->not_used() == data[ST5_NU_IDX].second); - REQUIRE(s5->n_4() == data[ST5_N_4_IDX].second); - REQUIRE(s5->tau_gps() == data[ST5_TAU_GPS_IDX].second); - REQUIRE(s5->l_n() == data[ST5_L_N_IDX].second); -} - -TEST_CASE("parse_string_number_NI"){ - string_data data = generate_string_data((rand() % 10) + 6); - std::string inp_data = generate_inp_data(data); - - kaitai::kstream stream(inp_data); - glonass_t gl_string(&stream); - - REQUIRE(gl_string.idle_chip() == data[IDLE_CHIP_IDX].second); - REQUIRE(gl_string.string_number() == data[STRING_NUMBER_IDX].second); - REQUIRE(gl_string.hamming_code() == data[ST6_HC_OFF + HC_IDX].second); - REQUIRE(gl_string.pad_1() == data[ST6_HC_OFF + PAD1_IDX].second); - REQUIRE(gl_string.superframe_number() == data[ST6_HC_OFF + SUPERFRAME_IDX].second); - REQUIRE(gl_string.pad_2() == data[ST6_HC_OFF + PAD2_IDX].second); - REQUIRE(gl_string.frame_number() == data[ST6_HC_OFF + FRAME_IDX].second); - - kaitai::kstream strni(inp_data); - glonass_t strni_data(&strni); - glonass_t::string_non_immediate_t* sni = static_cast(strni_data.data()); - - REQUIRE(sni->data_1() == data[ST6_DATA_1_IDX].second); - REQUIRE(sni->data_2() == data[ST6_DATA_2_IDX].second); -} diff --git a/system/ubloxd/tests/test_glonass_runner.cc b/system/ubloxd/tests/test_glonass_runner.cc deleted file mode 100644 index 62bf7476..00000000 --- a/system/ubloxd/tests/test_glonass_runner.cc +++ /dev/null @@ -1,2 +0,0 @@ -#define CATCH_CONFIG_MAIN -#include "catch2/catch.hpp" diff --git a/system/ubloxd/tests/ubloxd.py b/system/ubloxd/tests/ubloxd.py deleted file mode 100755 index c1738711..00000000 --- a/system/ubloxd/tests/ubloxd.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -# type: ignore - -from openpilot.selfdrive.locationd.test import ublox -import struct - -baudrate = 460800 -rate = 100 # send new data every 100ms - - -def configure_ublox(dev): - # configure ports and solution parameters and rate - dev.configure_port(port=ublox.PORT_USB, inMask=1, outMask=1) # enable only UBX on USB - dev.configure_port(port=0, inMask=0, outMask=0) # disable DDC - - payload = 
struct.pack(' - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/swaglog.h" - -const double gpsPi = 3.1415926535898; -#define UBLOX_MSG_SIZE(hdr) (*(uint16_t *)&hdr[4]) - -inline static bool bit_to_bool(uint8_t val, int shifts) { - return (bool)(val & (1 << shifts)); -} - -inline int UbloxMsgParser::needed_bytes() { - // Msg header incomplete? - if (bytes_in_parse_buf < ublox::UBLOX_HEADER_SIZE) - return ublox::UBLOX_HEADER_SIZE + ublox::UBLOX_CHECKSUM_SIZE - bytes_in_parse_buf; - uint16_t needed = UBLOX_MSG_SIZE(msg_parse_buf) + ublox::UBLOX_HEADER_SIZE + ublox::UBLOX_CHECKSUM_SIZE; - // too much data - if (needed < (uint16_t)bytes_in_parse_buf) - return -1; - return needed - (uint16_t)bytes_in_parse_buf; -} - -inline bool UbloxMsgParser::valid_cheksum() { - uint8_t ck_a = 0, ck_b = 0; - for (int i = 2; i < bytes_in_parse_buf - ublox::UBLOX_CHECKSUM_SIZE; i++) { - ck_a = (ck_a + msg_parse_buf[i]) & 0xFF; - ck_b = (ck_b + ck_a) & 0xFF; - } - if (ck_a != msg_parse_buf[bytes_in_parse_buf - 2]) { - LOGD("Checksum a mismatch: %02X, %02X", ck_a, msg_parse_buf[6]); - return false; - } - if (ck_b != msg_parse_buf[bytes_in_parse_buf - 1]) { - LOGD("Checksum b mismatch: %02X, %02X", ck_b, msg_parse_buf[7]); - return false; - } - return true; -} - -inline bool UbloxMsgParser::valid() { - return bytes_in_parse_buf >= ublox::UBLOX_HEADER_SIZE + ublox::UBLOX_CHECKSUM_SIZE && - needed_bytes() == 0 && valid_cheksum(); -} - -inline bool UbloxMsgParser::valid_so_far() { - if (bytes_in_parse_buf > 0 && msg_parse_buf[0] != ublox::PREAMBLE1) { - return false; - } - if (bytes_in_parse_buf > 1 && msg_parse_buf[1] != ublox::PREAMBLE2) { - return false; - } - if (needed_bytes() == 0 && !valid()) { - return false; - } - return true; -} - -bool UbloxMsgParser::add_data(float log_time, const uint8_t *incoming_data, uint32_t incoming_data_len, size_t &bytes_consumed) { - last_log_time = log_time; - int needed = needed_bytes(); - if (needed > 0) { - bytes_consumed = std::min((uint32_t)needed, incoming_data_len); - // Add data to buffer - memcpy(msg_parse_buf + bytes_in_parse_buf, incoming_data, bytes_consumed); - bytes_in_parse_buf += bytes_consumed; - } else { - bytes_consumed = incoming_data_len; - } - - // Validate msg format, detect invalid header and invalid checksum. - while (!valid_so_far() && bytes_in_parse_buf != 0) { - // Corrupted msg, drop a byte. - bytes_in_parse_buf -= 1; - if (bytes_in_parse_buf > 0) - memmove(&msg_parse_buf[0], &msg_parse_buf[1], bytes_in_parse_buf); - } - - // There is redundant data at the end of buffer, reset the buffer. 
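// Worked example for the framing math in add_data()/needed_bytes() above
// (illustrative, not from the original source): for a header b5 62 01 07 5c 00,
// UBLOX_MSG_SIZE reads the little-endian length field 0x005c = 92, so a full
// frame needs 92 + UBLOX_HEADER_SIZE (6) + UBLOX_CHECKSUM_SIZE (2) = 100 bytes.
// needed_bytes() returns -1 only when the buffer already holds more than that,
// which is treated as a framing error: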
-  if (needed_bytes() == -1) {
-    bytes_in_parse_buf = 0;
-  }
-  return valid();
-}
-
-
-std::pair<std::string, kj::Array<capnp::word>> UbloxMsgParser::gen_msg() {
-  std::string dat = data();
-  kaitai::kstream stream(dat);
-
-  ubx_t ubx_message(&stream);
-  auto body = ubx_message.body();
-
-  switch (ubx_message.msg_type()) {
-  case 0x0107:
-    return {"gpsLocationExternal", gen_nav_pvt(static_cast<ubx_t::nav_pvt_t*>(body))};
-  case 0x0213: // UBX-RXM-SFRB (Broadcast Navigation Data Subframe)
-    return {"ubloxGnss", gen_rxm_sfrbx(static_cast<ubx_t::rxm_sfrbx_t*>(body))};
-  case 0x0215: // UBX-RXM-RAW (Multi-GNSS Raw Measurement Data)
-    return {"ubloxGnss", gen_rxm_rawx(static_cast<ubx_t::rxm_rawx_t*>(body))};
-  case 0x0a09:
-    return {"ubloxGnss", gen_mon_hw(static_cast<ubx_t::mon_hw_t*>(body))};
-  case 0x0a0b:
-    return {"ubloxGnss", gen_mon_hw2(static_cast<ubx_t::mon_hw2_t*>(body))};
-  case 0x0135:
-    return {"ubloxGnss", gen_nav_sat(static_cast<ubx_t::nav_sat_t*>(body))};
-  default:
-    LOGE("Unknown message type %x", ubx_message.msg_type());
-    return {"ubloxGnss", kj::Array<capnp::word>()};
-  }
-}
-
-
-kj::Array<capnp::word> UbloxMsgParser::gen_nav_pvt(ubx_t::nav_pvt_t *msg) {
-  MessageBuilder msg_builder;
-  auto gpsLoc = msg_builder.initEvent().initGpsLocationExternal();
-  gpsLoc.setSource(cereal::GpsLocationData::SensorSource::UBLOX);
-  gpsLoc.setFlags(msg->flags());
-  gpsLoc.setHasFix((msg->flags() % 2) == 1);
-  gpsLoc.setLatitude(msg->lat() * 1e-07);
-  gpsLoc.setLongitude(msg->lon() * 1e-07);
-  gpsLoc.setAltitude(msg->height() * 1e-03);
-  gpsLoc.setSpeed(msg->g_speed() * 1e-03);
-  gpsLoc.setBearingDeg(msg->head_mot() * 1e-5);
-  gpsLoc.setHorizontalAccuracy(msg->h_acc() * 1e-03);
-  gpsLoc.setSatelliteCount(msg->num_sv());
-  std::tm timeinfo = std::tm();
-  timeinfo.tm_year = msg->year() - 1900;
-  timeinfo.tm_mon = msg->month() - 1;
-  timeinfo.tm_mday = msg->day();
-  timeinfo.tm_hour = msg->hour();
-  timeinfo.tm_min = msg->min();
-  timeinfo.tm_sec = msg->sec();
-
-  std::time_t utc_tt = timegm(&timeinfo);
-  gpsLoc.setUnixTimestampMillis(utc_tt * 1e+03 + msg->nano() * 1e-06);
-  float f[] = { msg->vel_n() * 1e-03f, msg->vel_e() * 1e-03f, msg->vel_d() * 1e-03f };
-  gpsLoc.setVNED(f);
-  gpsLoc.setVerticalAccuracy(msg->v_acc() * 1e-03);
-  gpsLoc.setSpeedAccuracy(msg->s_acc() * 1e-03);
-  gpsLoc.setBearingAccuracyDeg(msg->head_acc() * 1e-05);
-  return capnp::messageToFlatArray(msg_builder);
-}
-
-kj::Array<capnp::word> UbloxMsgParser::parse_gps_ephemeris(ubx_t::rxm_sfrbx_t *msg) {
-  // GPS subframes are packed into 10x 4 bytes, each containing 3 actual bytes
-  // We will first need to separate the data from the padding and parity
-  auto body = *msg->body();
-  assert(body.size() == 10);
-
-  std::string subframe_data;
-  subframe_data.reserve(30);
-  for (uint32_t word : body) {
-    word = word >> 6; // TODO: Verify parity
-    subframe_data.push_back(word >> 16);
-    subframe_data.push_back(word >> 8);
-    subframe_data.push_back(word >> 0);
-  }
-
-  // Collect subframes in map and parse when we have all the parts
-  {
-    kaitai::kstream stream(subframe_data);
-    gps_t subframe(&stream);
-
-    int subframe_id = subframe.how()->subframe_id();
-    if (subframe_id > 3 || subframe_id < 1) {
-      // don't parse almanac subframes
-      return kj::Array<capnp::word>();
-    }
-    gps_subframes[msg->sv_id()][subframe_id] = subframe_data;
-  }
-
-  // publish if subframes 1-3 have been collected
-  if (gps_subframes[msg->sv_id()].size() == 3) {
-    MessageBuilder msg_builder;
-    auto eph = msg_builder.initEvent().initUbloxGnss().initEphemeris();
-    eph.setSvId(msg->sv_id());
-
-    int iode_s2 = 0;
-    int iode_s3 = 0;
-    int iodc_lsb = 0;
-    int week;
-
-    // Subframe 1
-    {
-      kaitai::kstream stream(gps_subframes[msg->sv_id()][1]);
-      gps_t subframe(&stream);
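// Bit-packing example for the subframe assembly above (illustrative): each of
// the 10 GPS words arrives as 32 bits with the 24 data bits sitting above the
// 6 parity bits, so `word >> 6` drops the parity and the three pushed bytes
// are those 24 data bits MSB-first; 10 words x 3 bytes yields the 30-byte
// buffer handed to the kaitai gps_t parser here.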
- gps_t::subframe_1_t* subframe_1 = static_cast(subframe.body()); - - // Each message is incremented to be greater or equal than week 1877 (2015-12-27). - // To skip this use the current_time argument - week = subframe_1->week_no(); - week += 1024; - if (week < 1877) { - week += 1024; - } - //eph.setGpsWeek(subframe_1->week_no()); - eph.setTgd(subframe_1->t_gd() * pow(2, -31)); - eph.setToc(subframe_1->t_oc() * pow(2, 4)); - eph.setAf2(subframe_1->af_2() * pow(2, -55)); - eph.setAf1(subframe_1->af_1() * pow(2, -43)); - eph.setAf0(subframe_1->af_0() * pow(2, -31)); - eph.setSvHealth(subframe_1->sv_health()); - eph.setTowCount(subframe.how()->tow_count()); - iodc_lsb = subframe_1->iodc_lsb(); - } - - // Subframe 2 - { - kaitai::kstream stream(gps_subframes[msg->sv_id()][2]); - gps_t subframe(&stream); - gps_t::subframe_2_t* subframe_2 = static_cast(subframe.body()); - - // GPS week refers to current week, the ephemeris can be valid for the next - // if toe equals 0, this can be verified by the TOW count if it is within the - // last 2 hours of the week (gps ephemeris valid for 4hours) - if (subframe_2->t_oe() == 0 and subframe.how()->tow_count()*6 >= (SECS_IN_WEEK - 2*SECS_IN_HR)){ - week += 1; - } - eph.setCrs(subframe_2->c_rs() * pow(2, -5)); - eph.setDeltaN(subframe_2->delta_n() * pow(2, -43) * gpsPi); - eph.setM0(subframe_2->m_0() * pow(2, -31) * gpsPi); - eph.setCuc(subframe_2->c_uc() * pow(2, -29)); - eph.setEcc(subframe_2->e() * pow(2, -33)); - eph.setCus(subframe_2->c_us() * pow(2, -29)); - eph.setA(pow(subframe_2->sqrt_a() * pow(2, -19), 2.0)); - eph.setToe(subframe_2->t_oe() * pow(2, 4)); - iode_s2 = subframe_2->iode(); - } - - // Subframe 3 - { - kaitai::kstream stream(gps_subframes[msg->sv_id()][3]); - gps_t subframe(&stream); - gps_t::subframe_3_t* subframe_3 = static_cast(subframe.body()); - - eph.setCic(subframe_3->c_ic() * pow(2, -29)); - eph.setOmega0(subframe_3->omega_0() * pow(2, -31) * gpsPi); - eph.setCis(subframe_3->c_is() * pow(2, -29)); - eph.setI0(subframe_3->i_0() * pow(2, -31) * gpsPi); - eph.setCrc(subframe_3->c_rc() * pow(2, -5)); - eph.setOmega(subframe_3->omega() * pow(2, -31) * gpsPi); - eph.setOmegaDot(subframe_3->omega_dot() * pow(2, -43) * gpsPi); - eph.setIode(subframe_3->iode()); - eph.setIDot(subframe_3->idot() * pow(2, -43) * gpsPi); - iode_s3 = subframe_3->iode(); - } - - eph.setToeWeek(week); - eph.setTocWeek(week); - - gps_subframes[msg->sv_id()].clear(); - if (iodc_lsb != iode_s2 || iodc_lsb != iode_s3) { - // data set cutover, reject ephemeris - return kj::Array(); - } - return capnp::messageToFlatArray(msg_builder); - } - return kj::Array(); -} - -kj::Array UbloxMsgParser::parse_glonass_ephemeris(ubx_t::rxm_sfrbx_t *msg) { - // This parser assumes that no 2 satellites of the same frequency - // can be in view at the same time - auto body = *msg->body(); - assert(body.size() == 4); - { - std::string string_data; - string_data.reserve(16); - for (uint32_t word : body) { - for (int i = 3; i >= 0; i--) - string_data.push_back(word >> 8*i); - } - - kaitai::kstream stream(string_data); - glonass_t gl_string(&stream); - int string_number = gl_string.string_number(); - if (string_number < 1 || string_number > 5 || gl_string.idle_chip()) { - // don't parse non immediate data, idle_chip == 0 - return kj::Array(); - } - - // Check if new string either has same superframe_id or log transmission times make sense - bool superframe_unknown = false; - bool needs_clear = false; - for (int i = 1; i <= 5; i++) { - if (glonass_strings[msg->freq_id()].find(i) == 
glonass_strings[msg->freq_id()].end()) - continue; - if (glonass_string_superframes[msg->freq_id()][i] == 0 || gl_string.superframe_number() == 0) { - superframe_unknown = true; - } else if (glonass_string_superframes[msg->freq_id()][i] != gl_string.superframe_number()) { - needs_clear = true; - } - // Check if string times add up to being from the same frame - // If superframe is known this is redundant - // Strings are sent 2s apart and frames are 30s apart - if (superframe_unknown && - std::abs((glonass_string_times[msg->freq_id()][i] - 2.0 * i) - (last_log_time - 2.0 * string_number)) > 10) - needs_clear = true; - } - if (needs_clear) { - glonass_strings[msg->freq_id()].clear(); - glonass_string_superframes[msg->freq_id()].clear(); - glonass_string_times[msg->freq_id()].clear(); - } - glonass_strings[msg->freq_id()][string_number] = string_data; - glonass_string_superframes[msg->freq_id()][string_number] = gl_string.superframe_number(); - glonass_string_times[msg->freq_id()][string_number] = last_log_time; - } - if (msg->sv_id() == 255) { - // data can be decoded before identifying the SV number, in this case 255 - // is returned, which means "unknown" (ublox p32) - return kj::Array(); - } - - // publish if strings 1-5 have been collected - if (glonass_strings[msg->freq_id()].size() != 5) { - return kj::Array(); - } - - MessageBuilder msg_builder; - auto eph = msg_builder.initEvent().initUbloxGnss().initGlonassEphemeris(); - eph.setSvId(msg->sv_id()); - eph.setFreqNum(msg->freq_id() - 7); - - uint16_t current_day = 0; - uint16_t tk = 0; - - // string number 1 - { - kaitai::kstream stream(glonass_strings[msg->freq_id()][1]); - glonass_t gl_stream(&stream); - glonass_t::string_1_t* data = static_cast(gl_stream.data()); - - eph.setP1(data->p1()); - tk = data->t_k(); - eph.setTkDEPRECATED(tk); - eph.setXVel(data->x_vel() * pow(2, -20)); - eph.setXAccel(data->x_accel() * pow(2, -30)); - eph.setX(data->x() * pow(2, -11)); - } - - // string number 2 - { - kaitai::kstream stream(glonass_strings[msg->freq_id()][2]); - glonass_t gl_stream(&stream); - glonass_t::string_2_t* data = static_cast(gl_stream.data()); - - eph.setSvHealth(data->b_n()>>2); // MSB indicates health - eph.setP2(data->p2()); - eph.setTb(data->t_b()); - eph.setYVel(data->y_vel() * pow(2, -20)); - eph.setYAccel(data->y_accel() * pow(2, -30)); - eph.setY(data->y() * pow(2, -11)); - } - - // string number 3 - { - kaitai::kstream stream(glonass_strings[msg->freq_id()][3]); - glonass_t gl_stream(&stream); - glonass_t::string_3_t* data = static_cast(gl_stream.data()); - - eph.setP3(data->p3()); - eph.setGammaN(data->gamma_n() * pow(2, -40)); - eph.setSvHealth(eph.getSvHealth() | data->l_n()); - eph.setZVel(data->z_vel() * pow(2, -20)); - eph.setZAccel(data->z_accel() * pow(2, -30)); - eph.setZ(data->z() * pow(2, -11)); - } - - // string number 4 - { - kaitai::kstream stream(glonass_strings[msg->freq_id()][4]); - glonass_t gl_stream(&stream); - glonass_t::string_4_t* data = static_cast(gl_stream.data()); - - current_day = data->n_t(); - eph.setNt(current_day); - eph.setTauN(data->tau_n() * pow(2, -30)); - eph.setDeltaTauN(data->delta_tau_n() * pow(2, -30)); - eph.setAge(data->e_n()); - eph.setP4(data->p4()); - eph.setSvURA(glonass_URA_lookup.at(data->f_t())); - if (msg->sv_id() != data->n()) { - LOGE("SV_ID != SLOT_NUMBER: %d %" PRIu64, msg->sv_id(), data->n()); - } - eph.setSvType(data->m()); - } - - // string number 5 - { - kaitai::kstream stream(glonass_strings[msg->freq_id()][5]); - glonass_t gl_stream(&stream); - 
glonass_t::string_5_t* data = static_cast(gl_stream.data()); - - // string5 parsing is only needed to get the year, this can be removed and - // the year can be fetched later in laika (note rollovers and leap year) - eph.setN4(data->n_4()); - int tk_seconds = SECS_IN_HR * ((tk>>7) & 0x1F) + SECS_IN_MIN * ((tk>>1) & 0x3F) + (tk & 0x1) * 30; - eph.setTkSeconds(tk_seconds); - } - - glonass_strings[msg->freq_id()].clear(); - return capnp::messageToFlatArray(msg_builder); -} - - -kj::Array UbloxMsgParser::gen_rxm_sfrbx(ubx_t::rxm_sfrbx_t *msg) { - switch (msg->gnss_id()) { - case ubx_t::gnss_type_t::GNSS_TYPE_GPS: - return parse_gps_ephemeris(msg); - case ubx_t::gnss_type_t::GNSS_TYPE_GLONASS: - return parse_glonass_ephemeris(msg); - default: - return kj::Array(); - } -} - -kj::Array UbloxMsgParser::gen_rxm_rawx(ubx_t::rxm_rawx_t *msg) { - MessageBuilder msg_builder; - auto mr = msg_builder.initEvent().initUbloxGnss().initMeasurementReport(); - mr.setRcvTow(msg->rcv_tow()); - mr.setGpsWeek(msg->week()); - mr.setLeapSeconds(msg->leap_s()); - mr.setGpsWeek(msg->week()); - - auto mb = mr.initMeasurements(msg->num_meas()); - auto measurements = *msg->meas(); - for (int8_t i = 0; i < msg->num_meas(); i++) { - mb[i].setSvId(measurements[i]->sv_id()); - mb[i].setPseudorange(measurements[i]->pr_mes()); - mb[i].setCarrierCycles(measurements[i]->cp_mes()); - mb[i].setDoppler(measurements[i]->do_mes()); - mb[i].setGnssId(measurements[i]->gnss_id()); - mb[i].setGlonassFrequencyIndex(measurements[i]->freq_id()); - mb[i].setLocktime(measurements[i]->lock_time()); - mb[i].setCno(measurements[i]->cno()); - mb[i].setPseudorangeStdev(0.01 * (pow(2, (measurements[i]->pr_stdev() & 15)))); // weird scaling, might be wrong - mb[i].setCarrierPhaseStdev(0.004 * (measurements[i]->cp_stdev() & 15)); - mb[i].setDopplerStdev(0.002 * (pow(2, (measurements[i]->do_stdev() & 15)))); // weird scaling, might be wrong - - auto ts = mb[i].initTrackingStatus(); - auto trk_stat = measurements[i]->trk_stat(); - ts.setPseudorangeValid(bit_to_bool(trk_stat, 0)); - ts.setCarrierPhaseValid(bit_to_bool(trk_stat, 1)); - ts.setHalfCycleValid(bit_to_bool(trk_stat, 2)); - ts.setHalfCycleSubtracted(bit_to_bool(trk_stat, 3)); - } - - mr.setNumMeas(msg->num_meas()); - auto rs = mr.initReceiverStatus(); - rs.setLeapSecValid(bit_to_bool(msg->rec_stat(), 0)); - rs.setClkReset(bit_to_bool(msg->rec_stat(), 2)); - return capnp::messageToFlatArray(msg_builder); -} - -kj::Array UbloxMsgParser::gen_nav_sat(ubx_t::nav_sat_t *msg) { - MessageBuilder msg_builder; - auto sr = msg_builder.initEvent().initUbloxGnss().initSatReport(); - sr.setITow(msg->itow()); - - auto svs = sr.initSvs(msg->num_svs()); - auto svs_data = *msg->svs(); - for (int8_t i = 0; i < msg->num_svs(); i++) { - svs[i].setSvId(svs_data[i]->sv_id()); - svs[i].setGnssId(svs_data[i]->gnss_id()); - svs[i].setFlagsBitfield(svs_data[i]->flags()); - svs[i].setCno(svs_data[i]->cno()); - svs[i].setElevationDeg(svs_data[i]->elev()); - svs[i].setAzimuthDeg(svs_data[i]->azim()); - svs[i].setPseudorangeResidual(svs_data[i]->pr_res() * 0.1); - } - - return capnp::messageToFlatArray(msg_builder); -} - -kj::Array UbloxMsgParser::gen_mon_hw(ubx_t::mon_hw_t *msg) { - MessageBuilder msg_builder; - auto hwStatus = msg_builder.initEvent().initUbloxGnss().initHwStatus(); - hwStatus.setNoisePerMS(msg->noise_per_ms()); - hwStatus.setFlags(msg->flags()); - hwStatus.setAgcCnt(msg->agc_cnt()); - hwStatus.setAStatus((cereal::UbloxGnss::HwStatus::AntennaSupervisorState) msg->a_status()); - 
hwStatus.setAPower((cereal::UbloxGnss::HwStatus::AntennaPowerStatus) msg->a_power()); - hwStatus.setJamInd(msg->jam_ind()); - return capnp::messageToFlatArray(msg_builder); -} - -kj::Array UbloxMsgParser::gen_mon_hw2(ubx_t::mon_hw2_t *msg) { - MessageBuilder msg_builder; - auto hwStatus = msg_builder.initEvent().initUbloxGnss().initHwStatus2(); - hwStatus.setOfsI(msg->ofs_i()); - hwStatus.setMagI(msg->mag_i()); - hwStatus.setOfsQ(msg->ofs_q()); - hwStatus.setMagQ(msg->mag_q()); - - switch (msg->cfg_source()) { - case ubx_t::mon_hw2_t::config_source_t::CONFIG_SOURCE_ROM: - hwStatus.setCfgSource(cereal::UbloxGnss::HwStatus2::ConfigSource::ROM); - break; - case ubx_t::mon_hw2_t::config_source_t::CONFIG_SOURCE_OTP: - hwStatus.setCfgSource(cereal::UbloxGnss::HwStatus2::ConfigSource::OTP); - break; - case ubx_t::mon_hw2_t::config_source_t::CONFIG_SOURCE_CONFIG_PINS: - hwStatus.setCfgSource(cereal::UbloxGnss::HwStatus2::ConfigSource::CONFIGPINS); - break; - case ubx_t::mon_hw2_t::config_source_t::CONFIG_SOURCE_FLASH: - hwStatus.setCfgSource(cereal::UbloxGnss::HwStatus2::ConfigSource::FLASH); - break; - default: - hwStatus.setCfgSource(cereal::UbloxGnss::HwStatus2::ConfigSource::UNDEFINED); - break; - } - - hwStatus.setLowLevCfg(msg->low_lev_cfg()); - hwStatus.setPostStatus(msg->post_status()); - - return capnp::messageToFlatArray(msg_builder); -} diff --git a/system/ubloxd/ublox_msg.h b/system/ubloxd/ublox_msg.h deleted file mode 100644 index d21760ed..00000000 --- a/system/ubloxd/ublox_msg.h +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "cereal/messaging/messaging.h" -#include "common/util.h" -#include "system/ubloxd/generated/gps.h" -#include "system/ubloxd/generated/glonass.h" -#include "system/ubloxd/generated/ubx.h" - -using namespace std::string_literals; - -const int SECS_IN_MIN = 60; -const int SECS_IN_HR = 60 * SECS_IN_MIN; -const int SECS_IN_DAY = 24 * SECS_IN_HR; -const int SECS_IN_WEEK = 7 * SECS_IN_DAY; - -// protocol constants -namespace ublox { - const uint8_t PREAMBLE1 = 0xb5; - const uint8_t PREAMBLE2 = 0x62; - - const int UBLOX_HEADER_SIZE = 6; - const int UBLOX_CHECKSUM_SIZE = 2; - const int UBLOX_MAX_MSG_SIZE = 65536; - - struct ubx_mga_ini_time_utc_t { - uint8_t type; - uint8_t version; - uint8_t ref; - int8_t leapSecs; - uint16_t year; - uint8_t month; - uint8_t day; - uint8_t hour; - uint8_t minute; - uint8_t second; - uint8_t reserved1; - uint32_t ns; - uint16_t tAccS; - uint16_t reserved2; - uint32_t tAccNs; - } __attribute__((packed)); - - inline std::string ubx_add_checksum(const std::string &msg) { - assert(msg.size() > 2); - - uint8_t ck_a = 0, ck_b = 0; - for (int i = 2; i < msg.size(); i++) { - ck_a = (ck_a + msg[i]) & 0xFF; - ck_b = (ck_b + ck_a) & 0xFF; - } - - std::string r = msg; - r.push_back(ck_a); - r.push_back(ck_b); - return r; - } - - inline std::string build_ubx_mga_ini_time_utc(struct tm time) { - ublox::ubx_mga_ini_time_utc_t payload = { - .type = 0x10, - .version = 0x0, - .ref = 0x0, - .leapSecs = -128, // Unknown - .year = (uint16_t)(1900 + time.tm_year), - .month = (uint8_t)(1 + time.tm_mon), - .day = (uint8_t)time.tm_mday, - .hour = (uint8_t)time.tm_hour, - .minute = (uint8_t)time.tm_min, - .second = (uint8_t)time.tm_sec, - .reserved1 = 0x0, - .ns = 0, - .tAccS = 30, - .reserved2 = 0x0, - .tAccNs = 0, - }; - assert(sizeof(payload) == 24); - - std::string msg = "\xb5\x62\x13\x40\x18\x00"s; - msg += std::string((char*)&payload, sizeof(payload)); - - return 
ubx_add_checksum(msg);
-  }
-}
-
-class UbloxMsgParser {
-  public:
-    bool add_data(float log_time, const uint8_t *incoming_data, uint32_t incoming_data_len, size_t &bytes_consumed);
-    inline void reset() {bytes_in_parse_buf = 0;}
-    inline int needed_bytes();
-    inline std::string data() {return std::string((const char*)msg_parse_buf, bytes_in_parse_buf);}
-
-    std::pair<std::string, kj::Array<capnp::word>> gen_msg();
-    kj::Array<capnp::word> gen_nav_pvt(ubx_t::nav_pvt_t *msg);
-    kj::Array<capnp::word> gen_rxm_sfrbx(ubx_t::rxm_sfrbx_t *msg);
-    kj::Array<capnp::word> gen_rxm_rawx(ubx_t::rxm_rawx_t *msg);
-    kj::Array<capnp::word> gen_mon_hw(ubx_t::mon_hw_t *msg);
-    kj::Array<capnp::word> gen_mon_hw2(ubx_t::mon_hw2_t *msg);
-    kj::Array<capnp::word> gen_nav_sat(ubx_t::nav_sat_t *msg);
-
-  private:
-    inline bool valid_cheksum();
-    inline bool valid();
-    inline bool valid_so_far();
-
-    kj::Array<capnp::word> parse_gps_ephemeris(ubx_t::rxm_sfrbx_t *msg);
-    kj::Array<capnp::word> parse_glonass_ephemeris(ubx_t::rxm_sfrbx_t *msg);
-
-    std::unordered_map<int, std::unordered_map<int, std::string>> gps_subframes;
-
-    float last_log_time = 0.0;
-    size_t bytes_in_parse_buf = 0;
-    uint8_t msg_parse_buf[ublox::UBLOX_HEADER_SIZE + ublox::UBLOX_MAX_MSG_SIZE];
-
-    // user range accuracy in meters
-    const std::unordered_map<uint16_t, float> glonass_URA_lookup =
-      {{ 0,  1}, { 1,   2}, { 2, 2.5}, { 3,   4}, { 4,   5}, {5, 7},
-       { 6, 10}, { 7,  12}, { 8,  14}, { 9,  16}, {10,  32},
-       {11, 64}, {12, 128}, {13, 256}, {14, 512}, {15, 1024}};
-
-    std::unordered_map<int, std::unordered_map<int, std::string>> glonass_strings;
-    std::unordered_map<int, std::unordered_map<int, float>> glonass_string_times;
-    std::unordered_map<int, std::unordered_map<int, uint16_t>> glonass_string_superframes;
-};
diff --git a/system/ubloxd/ubloxd.cc b/system/ubloxd/ubloxd.cc
deleted file mode 100644
index 4e7e91f8..00000000
--- a/system/ubloxd/ubloxd.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-#include
-
-#include
-
-#include "cereal/messaging/messaging.h"
-#include "common/swaglog.h"
-#include "common/util.h"
-#include "system/ubloxd/ublox_msg.h"
-
-ExitHandler do_exit;
-using namespace ublox;
-
-int main() {
-  LOGW("starting ubloxd");
-  AlignedBuffer aligned_buf;
-  UbloxMsgParser parser;
-
-  PubMaster pm({"ubloxGnss", "gpsLocationExternal"});
-
-  std::unique_ptr<Context> context(Context::create());
-  std::unique_ptr<SubSocket> subscriber(SubSocket::create(context.get(), "ubloxRaw"));
-  assert(subscriber != NULL);
-  subscriber->setTimeout(100);
-
-
-  while (!do_exit) {
-    std::unique_ptr<Message> msg(subscriber->receive());
-    if (!msg) {
-      continue;
-    }
-
-    capnp::FlatArrayMessageReader cmsg(aligned_buf.align(msg.get()));
-    cereal::Event::Reader event = cmsg.getRoot<cereal::Event>();
-    auto ubloxRaw = event.getUbloxRaw();
-    float log_time = 1e-9 * event.getLogMonoTime();
-
-    const uint8_t *data = ubloxRaw.begin();
-    size_t len = ubloxRaw.size();
-    size_t bytes_consumed = 0;
-
-    while (bytes_consumed < len && !do_exit) {
-      size_t bytes_consumed_this_time = 0U;
-      if (parser.add_data(log_time, data + bytes_consumed, (uint32_t)(len - bytes_consumed), bytes_consumed_this_time)) {
-
-        try {
-          auto ublox_msg = parser.gen_msg();
-          if (ublox_msg.second.size() > 0) {
-            auto bytes = ublox_msg.second.asBytes();
-            pm.send(ublox_msg.first.c_str(), bytes.begin(), bytes.size());
-          }
-        } catch (const std::exception& e) {
-          LOGE("Error parsing ublox message %s", e.what());
-        }
-
-        parser.reset();
-      }
-      bytes_consumed += bytes_consumed_this_time;
-    }
-  }
-
-  return 0;
-}
diff --git a/system/ubloxd/ubloxd.py b/system/ubloxd/ubloxd.py
new file mode 100644
index 00000000..6882ad09
--- /dev/null
+++ b/system/ubloxd/ubloxd.py
@@ -0,0 +1,519 @@
+#!/usr/bin/env python3
+import math
+import capnp
+import calendar
+import numpy as np
+from collections import defaultdict
+from dataclasses import dataclass
+
+from cereal import log
+from
cereal import messaging +from openpilot.system.ubloxd.generated.ubx import Ubx +from openpilot.system.ubloxd.generated.gps import Gps +from openpilot.system.ubloxd.generated.glonass import Glonass + + +SECS_IN_MIN = 60 +SECS_IN_HR = 60 * SECS_IN_MIN +SECS_IN_DAY = 24 * SECS_IN_HR +SECS_IN_WEEK = 7 * SECS_IN_DAY + + +class UbxFramer: + PREAMBLE1 = 0xB5 + PREAMBLE2 = 0x62 + HEADER_SIZE = 6 + CHECKSUM_SIZE = 2 + + def __init__(self) -> None: + self.buf = bytearray() + self.last_log_time = 0.0 + + def reset(self) -> None: + self.buf.clear() + + @staticmethod + def _checksum_ok(frame: bytes) -> bool: + ck_a = 0 + ck_b = 0 + for b in frame[2:-2]: + ck_a = (ck_a + b) & 0xFF + ck_b = (ck_b + ck_a) & 0xFF + return ck_a == frame[-2] and ck_b == frame[-1] + + def add_data(self, log_time: float, incoming: bytes) -> list[bytes]: + self.last_log_time = log_time + out: list[bytes] = [] + if not incoming: + return out + self.buf += incoming + + while True: + # find preamble + if len(self.buf) < 2: + break + start = self.buf.find(b"\xB5\x62") + if start < 0: + # no preamble in buffer + self.buf.clear() + break + if start > 0: + # drop garbage before preamble + self.buf = self.buf[start:] + + if len(self.buf) < self.HEADER_SIZE: + break + + length_le = int.from_bytes(self.buf[4:6], 'little', signed=False) + total_len = self.HEADER_SIZE + length_le + self.CHECKSUM_SIZE + if len(self.buf) < total_len: + break + + candidate = bytes(self.buf[:total_len]) + if self._checksum_ok(candidate): + out.append(candidate) + # consume this frame + self.buf = self.buf[total_len:] + else: + # drop first byte and retry + self.buf = self.buf[1:] + + return out + + +def _bit(b: int, shift: int) -> bool: + return (b & (1 << shift)) != 0 + + +@dataclass +class EphemerisCaches: + gps_subframes: defaultdict[int, dict[int, bytes]] + glonass_strings: defaultdict[int, dict[int, bytes]] + glonass_string_times: defaultdict[int, dict[int, float]] + glonass_string_superframes: defaultdict[int, dict[int, int]] + + +class UbloxMsgParser: + gpsPi = 3.1415926535898 + + # user range accuracy in meters + glonass_URA_lookup: dict[int, float] = { + 0: 1, 1: 2, 2: 2.5, 3: 4, 4: 5, 5: 7, + 6: 10, 7: 12, 8: 14, 9: 16, 10: 32, + 11: 64, 12: 128, 13: 256, 14: 512, 15: 1024, + } + + def __init__(self) -> None: + self.framer = UbxFramer() + self.caches = EphemerisCaches( + gps_subframes=defaultdict(dict), + glonass_strings=defaultdict(dict), + glonass_string_times=defaultdict(dict), + glonass_string_superframes=defaultdict(dict), + ) + + # Message generation entry point + def parse_frame(self, frame: bytes) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder] | None: + # Quick header parse + msg_type = int.from_bytes(frame[2:4], 'big') + payload = frame[6:-2] + if msg_type == 0x0107: + body = Ubx.NavPvt.from_bytes(payload) + return self._gen_nav_pvt(body) + if msg_type == 0x0213: + # Manually parse RXM-SFRBX to avoid Kaitai EOF on some frames + if len(payload) < 8: + return None + gnss_id = payload[0] + sv_id = payload[1] + freq_id = payload[3] + num_words = payload[4] + exp = 8 + 4 * num_words + if exp != len(payload): + return None + words: list[int] = [] + off = 8 + for _ in range(num_words): + words.append(int.from_bytes(payload[off:off+4], 'little')) + off += 4 + + class _SfrbxView: + def __init__(self, gid: int, sid: int, fid: int, body: list[int]): + self.gnss_id = Ubx.GnssType(gid) + self.sv_id = sid + self.freq_id = fid + self.body = body + view = _SfrbxView(gnss_id, sv_id, freq_id, words) + return self._gen_rxm_sfrbx(view) + if msg_type == 
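# --- Usage sketch (hypothetical wiring; the real I/O loop lives elsewhere in
# this file): raw ubloxRaw bytes are fed to the framer, and every
# checksum-valid frame is handed to UbloxMsgParser.parse_frame, which returns
# a (service, capnp message) tuple or None.
def handle_raw(parser: "UbloxMsgParser", log_time: float, raw: bytes) -> list:
    out = []
    for frame in parser.framer.add_data(log_time, raw):
        result = parser.parse_frame(frame)
        if result is not None:
            out.append(result)
    return out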
0x0215: + body = Ubx.RxmRawx.from_bytes(payload) + return self._gen_rxm_rawx(body) + if msg_type == 0x0A09: + body = Ubx.MonHw.from_bytes(payload) + return self._gen_mon_hw(body) + if msg_type == 0x0A0B: + body = Ubx.MonHw2.from_bytes(payload) + return self._gen_mon_hw2(body) + if msg_type == 0x0135: + body = Ubx.NavSat.from_bytes(payload) + return self._gen_nav_sat(body) + return None + + # NAV-PVT -> gpsLocationExternal + def _gen_nav_pvt(self, msg: Ubx.NavPvt) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder]: + dat = messaging.new_message('gpsLocationExternal', valid=True) + gps = dat.gpsLocationExternal + gps.source = log.GpsLocationData.SensorSource.ublox + gps.flags = msg.flags + gps.hasFix = (msg.flags % 2) == 1 + gps.latitude = msg.lat * 1e-07 + gps.longitude = msg.lon * 1e-07 + gps.altitude = msg.height * 1e-03 + gps.speed = msg.g_speed * 1e-03 + gps.bearingDeg = msg.head_mot * 1e-5 + gps.horizontalAccuracy = msg.h_acc * 1e-03 + gps.satelliteCount = msg.num_sv + + # build UTC timestamp millis (NAV-PVT is in UTC) + # tolerate invalid or unset date values like C++ timegm + try: + utc_tt = calendar.timegm((msg.year, msg.month, msg.day, msg.hour, msg.min, msg.sec, 0, 0, 0)) + except Exception: + utc_tt = 0 + gps.unixTimestampMillis = int(utc_tt * 1e3 + (msg.nano * 1e-6)) + + # match C++ float32 rounding semantics exactly + gps.vNED = [ + float(np.float32(msg.vel_n) * np.float32(1e-03)), + float(np.float32(msg.vel_e) * np.float32(1e-03)), + float(np.float32(msg.vel_d) * np.float32(1e-03)), + ] + gps.verticalAccuracy = msg.v_acc * 1e-03 + gps.speedAccuracy = msg.s_acc * 1e-03 + gps.bearingAccuracyDeg = msg.head_acc * 1e-05 + return ('gpsLocationExternal', dat) + + # RXM-SFRBX dispatch to GPS or GLONASS ephemeris + def _gen_rxm_sfrbx(self, msg) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder] | None: + if msg.gnss_id == Ubx.GnssType.gps: + return self._parse_gps_ephemeris(msg) + if msg.gnss_id == Ubx.GnssType.glonass: + return self._parse_glonass_ephemeris(msg) + return None + + def _parse_gps_ephemeris(self, msg: Ubx.RxmSfrbx) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder] | None: + # body is list of 10 words; convert to 30-byte subframe (strip parity/padding) + body = msg.body + if len(body) != 10: + return None + subframe_data = bytearray() + for word in body: + word >>= 6 + subframe_data.append((word >> 16) & 0xFF) + subframe_data.append((word >> 8) & 0xFF) + subframe_data.append(word & 0xFF) + + sf = Gps.from_bytes(bytes(subframe_data)) + subframe_id = sf.how.subframe_id + if subframe_id < 1 or subframe_id > 3: + return None + self.caches.gps_subframes[msg.sv_id][subframe_id] = bytes(subframe_data) + + if len(self.caches.gps_subframes[msg.sv_id]) != 3: + return None + + dat = messaging.new_message('ubloxGnss', valid=True) + eph = dat.ubloxGnss.init('ephemeris') + eph.svId = msg.sv_id + + iode_s2 = 0 + iode_s3 = 0 + iodc_lsb = 0 + week = 0 + + # Subframe 1 + sf1 = Gps.from_bytes(self.caches.gps_subframes[msg.sv_id][1]) + s1 = sf1.body + assert isinstance(s1, Gps.Subframe1) + week = s1.week_no + week += 1024 + if week < 1877: + week += 1024 + eph.tgd = s1.t_gd * math.pow(2, -31) + eph.toc = s1.t_oc * math.pow(2, 4) + eph.af2 = s1.af_2 * math.pow(2, -55) + eph.af1 = s1.af_1 * math.pow(2, -43) + eph.af0 = s1.af_0 * math.pow(2, -31) + eph.svHealth = s1.sv_health + eph.towCount = sf1.how.tow_count + iodc_lsb = s1.iodc_lsb + + # Subframe 2 + sf2 = Gps.from_bytes(self.caches.gps_subframes[msg.sv_id][2]) + s2 = sf2.body + assert isinstance(s2, Gps.Subframe2) + if s2.t_oe == 0 
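# Week-number aside for the arithmetic above (informal): week_no from
# subframe 1 is a 10-bit (mod-1024) counter, so the code pins it to the epoch
# starting at GPS week 1877 (2015-12-27): e.g. a broadcast week_no of 101
# becomes 1125, still < 1877, so another 1024 is added, giving 2149 (early
# 2021). The t_oe == 0 check here handles an ephemeris that only takes effect
# in the following week.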
and sf2.how.tow_count * 6 >= (SECS_IN_WEEK - 2 * SECS_IN_HR): + week += 1 + eph.crs = s2.c_rs * math.pow(2, -5) + eph.deltaN = s2.delta_n * math.pow(2, -43) * self.gpsPi + eph.m0 = s2.m_0 * math.pow(2, -31) * self.gpsPi + eph.cuc = s2.c_uc * math.pow(2, -29) + eph.ecc = s2.e * math.pow(2, -33) + eph.cus = s2.c_us * math.pow(2, -29) + eph.a = math.pow(s2.sqrt_a * math.pow(2, -19), 2.0) + eph.toe = s2.t_oe * math.pow(2, 4) + iode_s2 = s2.iode + + # Subframe 3 + sf3 = Gps.from_bytes(self.caches.gps_subframes[msg.sv_id][3]) + s3 = sf3.body + assert isinstance(s3, Gps.Subframe3) + eph.cic = s3.c_ic * math.pow(2, -29) + eph.omega0 = s3.omega_0 * math.pow(2, -31) * self.gpsPi + eph.cis = s3.c_is * math.pow(2, -29) + eph.i0 = s3.i_0 * math.pow(2, -31) * self.gpsPi + eph.crc = s3.c_rc * math.pow(2, -5) + eph.omega = s3.omega * math.pow(2, -31) * self.gpsPi + eph.omegaDot = s3.omega_dot * math.pow(2, -43) * self.gpsPi + eph.iode = s3.iode + eph.iDot = s3.idot * math.pow(2, -43) * self.gpsPi + iode_s3 = s3.iode + + eph.toeWeek = week + eph.tocWeek = week + + # clear cache for this SV + self.caches.gps_subframes[msg.sv_id].clear() + if not (iodc_lsb == iode_s2 == iode_s3): + return None + return ('ubloxGnss', dat) + + def _parse_glonass_ephemeris(self, msg: Ubx.RxmSfrbx) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder] | None: + # words are 4 bytes each; Glonass parser expects 16 bytes (string) + body = msg.body + if len(body) != 4: + return None + string_bytes = bytearray() + for word in body: + for i in (3, 2, 1, 0): + string_bytes.append((word >> (8 * i)) & 0xFF) + + gl = Glonass.from_bytes(bytes(string_bytes)) + string_number = gl.string_number + if string_number < 1 or string_number > 5 or gl.idle_chip: + return None + + # correlate by superframe and timing, similar to C++ logic + freq_id = msg.freq_id + superframe_unknown = False + needs_clear = False + for i in range(1, 6): + if i not in self.caches.glonass_strings[freq_id]: + continue + sf_prev = self.caches.glonass_string_superframes[freq_id].get(i, 0) + if sf_prev == 0 or gl.superframe_number == 0: + superframe_unknown = True + elif sf_prev != gl.superframe_number: + needs_clear = True + if superframe_unknown: + prev_time = self.caches.glonass_string_times[freq_id].get(i, 0.0) + if abs((prev_time - 2.0 * i) - (self.framer.last_log_time - 2.0 * string_number)) > 10: + needs_clear = True + + if needs_clear: + self.caches.glonass_strings[freq_id].clear() + self.caches.glonass_string_superframes[freq_id].clear() + self.caches.glonass_string_times[freq_id].clear() + + self.caches.glonass_strings[freq_id][string_number] = bytes(string_bytes) + self.caches.glonass_string_superframes[freq_id][string_number] = gl.superframe_number + self.caches.glonass_string_times[freq_id][string_number] = self.framer.last_log_time + + if msg.sv_id == 255: + # unknown SV id + return None + if len(self.caches.glonass_strings[freq_id]) != 5: + return None + + dat = messaging.new_message('ubloxGnss', valid=True) + eph = dat.ubloxGnss.init('glonassEphemeris') + eph.svId = msg.sv_id + eph.freqNum = msg.freq_id - 7 + + current_day = 0 + tk = 0 + + # string 1 + try: + s1 = Glonass.from_bytes(self.caches.glonass_strings[freq_id][1]).data + except Exception: + return None + assert isinstance(s1, Glonass.String1) + eph.p1 = int(s1.p1) + tk = int(s1.t_k) + eph.tkDEPRECATED = tk + eph.xVel = float(s1.x_vel) * math.pow(2, -20) + eph.xAccel = float(s1.x_accel) * math.pow(2, -30) + eph.x = float(s1.x) * math.pow(2, -11) + + # string 2 + try: + s2 = 
Glonass.from_bytes(self.caches.glonass_strings[freq_id][2]).data + except Exception: + return None + assert isinstance(s2, Glonass.String2) + eph.svHealth = int(s2.b_n >> 2) + eph.p2 = int(s2.p2) + eph.tb = int(s2.t_b) + eph.yVel = float(s2.y_vel) * math.pow(2, -20) + eph.yAccel = float(s2.y_accel) * math.pow(2, -30) + eph.y = float(s2.y) * math.pow(2, -11) + + # string 3 + try: + s3 = Glonass.from_bytes(self.caches.glonass_strings[freq_id][3]).data + except Exception: + return None + assert isinstance(s3, Glonass.String3) + eph.p3 = int(s3.p3) + eph.gammaN = float(s3.gamma_n) * math.pow(2, -40) + eph.svHealth = int(eph.svHealth | (1 if s3.l_n else 0)) + eph.zVel = float(s3.z_vel) * math.pow(2, -20) + eph.zAccel = float(s3.z_accel) * math.pow(2, -30) + eph.z = float(s3.z) * math.pow(2, -11) + + # string 4 + try: + s4 = Glonass.from_bytes(self.caches.glonass_strings[freq_id][4]).data + except Exception: + return None + assert isinstance(s4, Glonass.String4) + current_day = int(s4.n_t) + eph.nt = current_day + eph.tauN = float(s4.tau_n) * math.pow(2, -30) + eph.deltaTauN = float(s4.delta_tau_n) * math.pow(2, -30) + eph.age = int(s4.e_n) + eph.p4 = int(s4.p4) + eph.svURA = float(self.glonass_URA_lookup.get(int(s4.f_t), 0.0)) + # consistency check: SV slot number + # if it doesn't match, keep going but note mismatch (no logging here) + eph.svType = int(s4.m) + + # string 5 + try: + s5 = Glonass.from_bytes(self.caches.glonass_strings[freq_id][5]).data + except Exception: + return None + assert isinstance(s5, Glonass.String5) + eph.n4 = int(s5.n_4) + tk_seconds = int(SECS_IN_HR * ((tk >> 7) & 0x1F) + SECS_IN_MIN * ((tk >> 1) & 0x3F) + (tk & 0x1) * 30) + eph.tkSeconds = tk_seconds + + self.caches.glonass_strings[freq_id].clear() + return ('ubloxGnss', dat) + + def _gen_rxm_rawx(self, msg: Ubx.RxmRawx) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder]: + dat = messaging.new_message('ubloxGnss', valid=True) + mr = dat.ubloxGnss.init('measurementReport') + mr.rcvTow = msg.rcv_tow + mr.gpsWeek = msg.week + mr.leapSeconds = msg.leap_s + + mb = mr.init('measurements', msg.num_meas) + for i, m in enumerate(msg.meas): + mb[i].svId = m.sv_id + mb[i].pseudorange = m.pr_mes + mb[i].carrierCycles = m.cp_mes + mb[i].doppler = m.do_mes + mb[i].gnssId = int(m.gnss_id.value) + mb[i].glonassFrequencyIndex = m.freq_id + mb[i].locktime = m.lock_time + mb[i].cno = m.cno + mb[i].pseudorangeStdev = 0.01 * (math.pow(2, (m.pr_stdev & 15))) + mb[i].carrierPhaseStdev = 0.004 * (m.cp_stdev & 15) + mb[i].dopplerStdev = 0.002 * (math.pow(2, (m.do_stdev & 15))) + + ts = mb[i].init('trackingStatus') + trk = m.trk_stat + ts.pseudorangeValid = _bit(trk, 0) + ts.carrierPhaseValid = _bit(trk, 1) + ts.halfCycleValid = _bit(trk, 2) + ts.halfCycleSubtracted = _bit(trk, 3) + + mr.numMeas = msg.num_meas + rs = mr.init('receiverStatus') + rs.leapSecValid = _bit(msg.rec_stat, 0) + rs.clkReset = _bit(msg.rec_stat, 2) + return ('ubloxGnss', dat) + + def _gen_nav_sat(self, msg: Ubx.NavSat) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder]: + dat = messaging.new_message('ubloxGnss', valid=True) + sr = dat.ubloxGnss.init('satReport') + sr.iTow = msg.itow + svs = sr.init('svs', msg.num_svs) + for i, s in enumerate(msg.svs): + svs[i].svId = s.sv_id + svs[i].gnssId = int(s.gnss_id.value) + svs[i].flagsBitfield = s.flags + svs[i].cno = s.cno + svs[i].elevationDeg = s.elev + svs[i].azimuthDeg = s.azim + svs[i].pseudorangeResidual = s.pr_res * 0.1 + return ('ubloxGnss', dat) + + def _gen_mon_hw(self, msg: Ubx.MonHw) -> tuple[str, 
capnp.lib.capnp._DynamicStructBuilder]: + dat = messaging.new_message('ubloxGnss', valid=True) + hw = dat.ubloxGnss.init('hwStatus') + hw.noisePerMS = msg.noise_per_ms + hw.flags = msg.flags + hw.agcCnt = msg.agc_cnt + hw.aStatus = int(msg.a_status.value) + hw.aPower = int(msg.a_power.value) + hw.jamInd = msg.jam_ind + return ('ubloxGnss', dat) + + def _gen_mon_hw2(self, msg: Ubx.MonHw2) -> tuple[str, capnp.lib.capnp._DynamicStructBuilder]: + dat = messaging.new_message('ubloxGnss', valid=True) + hw = dat.ubloxGnss.init('hwStatus2') + hw.ofsI = msg.ofs_i + hw.magI = msg.mag_i + hw.ofsQ = msg.ofs_q + hw.magQ = msg.mag_q + # Map Ubx enum to cereal enum {undefined=0, rom=1, otp=2, configpins=3, flash=4} + cfg_map = { + Ubx.MonHw2.ConfigSource.rom: 1, + Ubx.MonHw2.ConfigSource.otp: 2, + Ubx.MonHw2.ConfigSource.config_pins: 3, + Ubx.MonHw2.ConfigSource.flash: 4, + } + hw.cfgSource = cfg_map.get(msg.cfg_source, 0) + hw.lowLevCfg = msg.low_lev_cfg + hw.postStatus = msg.post_status + return ('ubloxGnss', dat) + + +def main(): + parser = UbloxMsgParser() + pm = messaging.PubMaster(['ubloxGnss', 'gpsLocationExternal']) + sock = messaging.sub_sock('ubloxRaw', timeout=100, conflate=False) + + while True: + msg = messaging.recv_one(sock) + if msg is None: + continue + + data = bytes(msg.ubloxRaw) + log_time = msg.logMonoTime * 1e-9 + frames = parser.framer.add_data(log_time, data) + for frame in frames: + try: + res = parser.parse_frame(frame) + except Exception: + continue + if not res: + continue + service, dat = res + pm.send(service, dat) + +if __name__ == '__main__': + main() diff --git a/system/updated/updated.py b/system/updated/updated.py index 0759c0a7..d8879b4e 100755 --- a/system/updated/updated.py +++ b/system/updated/updated.py @@ -31,8 +31,8 @@ FINALIZED = os.path.join(STAGING_ROOT, "finalized") OVERLAY_INIT = Path(os.path.join(BASEDIR, ".overlay_init")) -DAYS_NO_CONNECTIVITY_MAX = 14 # do not allow to engage after this many days -DAYS_NO_CONNECTIVITY_PROMPT = 10 # send an offroad prompt after this many days +DAYS_NO_CONNECTIVITY_MAX = 1400 # do not allow to engage after this many days +DAYS_NO_CONNECTIVITY_PROMPT = 1000 # send an offroad prompt after this many days class UserRequest: NONE = 0 diff --git a/third_party/SConscript b/third_party/SConscript index 507c17c4..3a7497d1 100644 --- a/third_party/SConscript +++ b/third_party/SConscript @@ -1,4 +1,3 @@ Import('env') env.Library('json11', ['json11/json11.cpp'], CCFLAGS=env['CCFLAGS'] + ['-Wno-unqualified-std-cast-call']) -env.Library('kaitai', ['kaitai/kaitaistream.cpp'], CPPDEFINES=['KS_STR_ENCODING_NONE']) diff --git a/third_party/kaitai/custom_decoder.h b/third_party/kaitai/custom_decoder.h deleted file mode 100644 index 6da7f5fd..00000000 --- a/third_party/kaitai/custom_decoder.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef KAITAI_CUSTOM_DECODER_H -#define KAITAI_CUSTOM_DECODER_H - -#include - -namespace kaitai { - -class custom_decoder { -public: - virtual ~custom_decoder() {}; - virtual std::string decode(std::string src) = 0; -}; - -} - -#endif diff --git a/third_party/kaitai/exceptions.h b/third_party/kaitai/exceptions.h deleted file mode 100644 index 5c09c467..00000000 --- a/third_party/kaitai/exceptions.h +++ /dev/null @@ -1,189 +0,0 @@ -#ifndef KAITAI_EXCEPTIONS_H -#define KAITAI_EXCEPTIONS_H - -#include - -#include -#include - -// We need to use "noexcept" in virtual destructor of our exceptions -// subclasses. 
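
For reference, the framing that `UbxFramer` above implements is the standard u-blox UBX layout: preamble bytes 0xB5 0x62, one-byte message class and ID, a little-endian 16-bit payload length, the payload, and a two-byte Fletcher-style checksum computed over everything between the preamble and the checksum. A minimal standalone sketch (the helper names are illustrative, not part of the diff):

```python
# Minimal sketch of UBX framing, mirroring UbxFramer._checksum_ok above.
def ubx_checksum(body: bytes) -> bytes:
    """8-bit Fletcher checksum over class, id, length and payload."""
    ck_a = ck_b = 0
    for b in body:
        ck_a = (ck_a + b) & 0xFF
        ck_b = (ck_b + ck_a) & 0xFF
    return bytes([ck_a, ck_b])

def build_ubx_frame(msg_class: int, msg_id: int, payload: bytes) -> bytes:
    body = bytes([msg_class, msg_id]) + len(payload).to_bytes(2, 'little') + payload
    return b"\xB5\x62" + body + ubx_checksum(body)

# e.g. an empty-payload NAV-PVT (class 0x01, id 0x07) poll round-trips:
frame = build_ubx_frame(0x01, 0x07, b"")
assert ubx_checksum(frame[2:-2]) == frame[-2:]
```
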
Different compilers have different ideas on how to -// achieve that: C++98 compilers prefer `throw()`, C++11 and later -// use `noexcept`. We define KS_NOEXCEPT macro for that. - -#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1900) -#define KS_NOEXCEPT noexcept -#else -#define KS_NOEXCEPT throw() -#endif - -namespace kaitai { - -/** - * Common ancestor for all error originating from Kaitai Struct usage. - * Stores KSY source path, pointing to an element supposedly guilty of - * an error. - */ -class kstruct_error: public std::runtime_error { -public: - kstruct_error(const std::string what, const std::string src_path): - std::runtime_error(src_path + ": " + what), - m_src_path(src_path) - { - } - - virtual ~kstruct_error() KS_NOEXCEPT {}; - -protected: - const std::string m_src_path; -}; - -/** - * Error that occurs when default endianness should be decided with - * a switch, but nothing matches (although using endianness expression - * implies that there should be some positive result). - */ -class undecided_endianness_error: public kstruct_error { -public: - undecided_endianness_error(const std::string src_path): - kstruct_error("unable to decide on endianness for a type", src_path) - { - } - - virtual ~undecided_endianness_error() KS_NOEXCEPT {}; -}; - -/** - * Common ancestor for all validation failures. Stores pointer to - * KaitaiStream IO object which was involved in an error. - */ -class validation_failed_error: public kstruct_error { -public: - validation_failed_error(const std::string what, kstream* io, const std::string src_path): - kstruct_error("at pos " + kstream::to_string(static_cast(io->pos())) + ": validation failed: " + what, src_path), - m_io(io) - { - } - -// "at pos #{io.pos}: validation failed: #{msg}" - - virtual ~validation_failed_error() KS_NOEXCEPT {}; - -protected: - kstream* m_io; -}; - -/** - * Signals validation failure: we required "actual" value to be equal to - * "expected", but it turned out that it's not. - */ -template -class validation_not_equal_error: public validation_failed_error { -public: - validation_not_equal_error(const T& expected, const T& actual, kstream* io, const std::string src_path): - validation_failed_error("not equal", io, src_path), - m_expected(expected), - m_actual(actual) - { - } - - // "not equal, expected #{expected.inspect}, but got #{actual.inspect}" - - virtual ~validation_not_equal_error() KS_NOEXCEPT {}; - -protected: - const T& m_expected; - const T& m_actual; -}; - -/** - * Signals validation failure: we required "actual" value to be greater - * than or equal to "min", but it turned out that it's not. - */ -template -class validation_less_than_error: public validation_failed_error { -public: - validation_less_than_error(const T& min, const T& actual, kstream* io, const std::string src_path): - validation_failed_error("not in range", io, src_path), - m_min(min), - m_actual(actual) - { - } - - // "not in range, min #{min.inspect}, but got #{actual.inspect}" - - virtual ~validation_less_than_error() KS_NOEXCEPT {}; - -protected: - const T& m_min; - const T& m_actual; -}; - -/** - * Signals validation failure: we required "actual" value to be less - * than or equal to "max", but it turned out that it's not. 
- */ -template -class validation_greater_than_error: public validation_failed_error { -public: - validation_greater_than_error(const T& max, const T& actual, kstream* io, const std::string src_path): - validation_failed_error("not in range", io, src_path), - m_max(max), - m_actual(actual) - { - } - - // "not in range, max #{max.inspect}, but got #{actual.inspect}" - - virtual ~validation_greater_than_error() KS_NOEXCEPT {}; - -protected: - const T& m_max; - const T& m_actual; -}; - -/** - * Signals validation failure: we required "actual" value to be from - * the list, but it turned out that it's not. - */ -template -class validation_not_any_of_error: public validation_failed_error { -public: - validation_not_any_of_error(const T& actual, kstream* io, const std::string src_path): - validation_failed_error("not any of the list", io, src_path), - m_actual(actual) - { - } - - // "not any of the list, got #{actual.inspect}" - - virtual ~validation_not_any_of_error() KS_NOEXCEPT {}; - -protected: - const T& m_actual; -}; - -/** - * Signals validation failure: we required "actual" value to match - * the expression, but it turned out that it doesn't. - */ -template -class validation_expr_error: public validation_failed_error { -public: - validation_expr_error(const T& actual, kstream* io, const std::string src_path): - validation_failed_error("not matching the expression", io, src_path), - m_actual(actual) - { - } - - // "not matching the expression, got #{actual.inspect}" - - virtual ~validation_expr_error() KS_NOEXCEPT {}; - -protected: - const T& m_actual; -}; - -} - -#endif diff --git a/third_party/kaitai/kaitaistream.cpp b/third_party/kaitai/kaitaistream.cpp deleted file mode 100644 index d82ddb7e..00000000 --- a/third_party/kaitai/kaitaistream.cpp +++ /dev/null @@ -1,689 +0,0 @@ -#include - -#if defined(__APPLE__) -#include -#include -#define bswap_16(x) OSSwapInt16(x) -#define bswap_32(x) OSSwapInt32(x) -#define bswap_64(x) OSSwapInt64(x) -#define __BYTE_ORDER BYTE_ORDER -#define __BIG_ENDIAN BIG_ENDIAN -#define __LITTLE_ENDIAN LITTLE_ENDIAN -#elif defined(_MSC_VER) // !__APPLE__ -#include -#define __LITTLE_ENDIAN 1234 -#define __BIG_ENDIAN 4321 -#define __BYTE_ORDER __LITTLE_ENDIAN -#define bswap_16(x) _byteswap_ushort(x) -#define bswap_32(x) _byteswap_ulong(x) -#define bswap_64(x) _byteswap_uint64(x) -#else // !__APPLE__ or !_MSC_VER -#include -#include -#endif - -#include -#include -#include - -kaitai::kstream::kstream(std::istream* io) { - m_io = io; - init(); -} - -kaitai::kstream::kstream(std::string& data): m_io_str(data) { - m_io = &m_io_str; - init(); -} - -void kaitai::kstream::init() { - exceptions_enable(); - align_to_byte(); -} - -void kaitai::kstream::close() { - // m_io->close(); -} - -void kaitai::kstream::exceptions_enable() const { - m_io->exceptions( - std::istream::eofbit | - std::istream::failbit | - std::istream::badbit - ); -} - -// ======================================================================== -// Stream positioning -// ======================================================================== - -bool kaitai::kstream::is_eof() const { - if (m_bits_left > 0) { - return false; - } - char t; - m_io->exceptions( - std::istream::badbit - ); - m_io->get(t); - if (m_io->eof()) { - m_io->clear(); - exceptions_enable(); - return true; - } else { - m_io->unget(); - exceptions_enable(); - return false; - } -} - -void kaitai::kstream::seek(uint64_t pos) { - m_io->seekg(pos); -} - -uint64_t kaitai::kstream::pos() { - return m_io->tellg(); -} - -uint64_t 
kaitai::kstream::size() { - std::iostream::pos_type cur_pos = m_io->tellg(); - m_io->seekg(0, std::ios::end); - std::iostream::pos_type len = m_io->tellg(); - m_io->seekg(cur_pos); - return len; -} - -// ======================================================================== -// Integer numbers -// ======================================================================== - -// ------------------------------------------------------------------------ -// Signed -// ------------------------------------------------------------------------ - -int8_t kaitai::kstream::read_s1() { - char t; - m_io->get(t); - return t; -} - -// ........................................................................ -// Big-endian -// ........................................................................ - -int16_t kaitai::kstream::read_s2be() { - int16_t t; - m_io->read(reinterpret_cast(&t), 2); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_16(t); -#endif - return t; -} - -int32_t kaitai::kstream::read_s4be() { - int32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_32(t); -#endif - return t; -} - -int64_t kaitai::kstream::read_s8be() { - int64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_64(t); -#endif - return t; -} - -// ........................................................................ -// Little-endian -// ........................................................................ - -int16_t kaitai::kstream::read_s2le() { - int16_t t; - m_io->read(reinterpret_cast(&t), 2); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_16(t); -#endif - return t; -} - -int32_t kaitai::kstream::read_s4le() { - int32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_32(t); -#endif - return t; -} - -int64_t kaitai::kstream::read_s8le() { - int64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_64(t); -#endif - return t; -} - -// ------------------------------------------------------------------------ -// Unsigned -// ------------------------------------------------------------------------ - -uint8_t kaitai::kstream::read_u1() { - char t; - m_io->get(t); - return t; -} - -// ........................................................................ -// Big-endian -// ........................................................................ - -uint16_t kaitai::kstream::read_u2be() { - uint16_t t; - m_io->read(reinterpret_cast(&t), 2); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_16(t); -#endif - return t; -} - -uint32_t kaitai::kstream::read_u4be() { - uint32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_32(t); -#endif - return t; -} - -uint64_t kaitai::kstream::read_u8be() { - uint64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_64(t); -#endif - return t; -} - -// ........................................................................ -// Little-endian -// ........................................................................ 
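
The endian-specific readers being deleted here are what the pure-Python ubloxd parser above replaces with `int.from_bytes`. For comparison, Python's `struct` module expresses the same reads directly (a sketch, not part of the diff):

```python
import struct

# Equivalents of the deleted kstream integer readers.
buf = bytes([0x34, 0x12, 0x00, 0x80])

u2le, = struct.unpack_from('<H', buf, 0)  # read_u2le -> 0x1234
s2be, = struct.unpack_from('>h', buf, 0)  # read_s2be -> 0x3412 (13330)
s4le, = struct.unpack_from('<i', buf, 0)  # read_s4le -> -2147478988

assert u2le == 0x1234
assert int.from_bytes(buf[:2], 'little') == u2le
```
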
- -uint16_t kaitai::kstream::read_u2le() { - uint16_t t; - m_io->read(reinterpret_cast(&t), 2); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_16(t); -#endif - return t; -} - -uint32_t kaitai::kstream::read_u4le() { - uint32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_32(t); -#endif - return t; -} - -uint64_t kaitai::kstream::read_u8le() { - uint64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_64(t); -#endif - return t; -} - -// ======================================================================== -// Floating point numbers -// ======================================================================== - -// ........................................................................ -// Big-endian -// ........................................................................ - -float kaitai::kstream::read_f4be() { - uint32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_32(t); -#endif - return reinterpret_cast(t); -} - -double kaitai::kstream::read_f8be() { - uint64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __LITTLE_ENDIAN - t = bswap_64(t); -#endif - return reinterpret_cast(t); -} - -// ........................................................................ -// Little-endian -// ........................................................................ - -float kaitai::kstream::read_f4le() { - uint32_t t; - m_io->read(reinterpret_cast(&t), 4); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_32(t); -#endif - return reinterpret_cast(t); -} - -double kaitai::kstream::read_f8le() { - uint64_t t; - m_io->read(reinterpret_cast(&t), 8); -#if __BYTE_ORDER == __BIG_ENDIAN - t = bswap_64(t); -#endif - return reinterpret_cast(t); -} - -// ======================================================================== -// Unaligned bit values -// ======================================================================== - -void kaitai::kstream::align_to_byte() { - m_bits_left = 0; - m_bits = 0; -} - -uint64_t kaitai::kstream::read_bits_int_be(int n) { - int bits_needed = n - m_bits_left; - if (bits_needed > 0) { - // 1 bit => 1 byte - // 8 bits => 1 byte - // 9 bits => 2 bytes - int bytes_needed = ((bits_needed - 1) / 8) + 1; - if (bytes_needed > 8) - throw std::runtime_error("read_bits_int: more than 8 bytes requested"); - char buf[8]; - m_io->read(buf, bytes_needed); - for (int i = 0; i < bytes_needed; i++) { - uint8_t b = buf[i]; - m_bits <<= 8; - m_bits |= b; - m_bits_left += 8; - } - } - - // raw mask with required number of 1s, starting from lowest bit - uint64_t mask = get_mask_ones(n); - // shift mask to align with highest bits available in @bits - int shift_bits = m_bits_left - n; - mask <<= shift_bits; - // derive reading result - uint64_t res = (m_bits & mask) >> shift_bits; - // clear top bits that we've just read => AND with 1s - m_bits_left -= n; - mask = get_mask_ones(m_bits_left); - m_bits &= mask; - - return res; -} - -// Deprecated, use read_bits_int_be() instead. 
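
`read_bits_int_be` above buffers whole bytes into a 64-bit accumulator, then peels the `n` highest bits off the top. The same idea in Python, as a rough sketch (illustrative only; it leans on Python's unbounded ints instead of a 64-bit accumulator):

```python
# Rough Python sketch of the big-endian unaligned bit reader deleted above.
class BitReaderBE:
    def __init__(self, data: bytes):
        self.data = data
        self.pos = 0       # next byte to consume
        self.bits = 0      # accumulator, most significant bits first
        self.bits_left = 0

    def read_bits(self, n: int) -> int:
        # top up the accumulator one byte at a time
        while self.bits_left < n:
            self.bits = (self.bits << 8) | self.data[self.pos]
            self.pos += 1
            self.bits_left += 8
        # take the n highest buffered bits, keep the rest
        shift = self.bits_left - n
        res = (self.bits >> shift) & ((1 << n) - 1)
        self.bits_left = shift
        self.bits &= (1 << shift) - 1
        return res

r = BitReaderBE(bytes([0b10110100]))
assert r.read_bits(3) == 0b101 and r.read_bits(5) == 0b10100
```
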
-uint64_t kaitai::kstream::read_bits_int(int n) { - return read_bits_int_be(n); -} - -uint64_t kaitai::kstream::read_bits_int_le(int n) { - int bits_needed = n - m_bits_left; - if (bits_needed > 0) { - // 1 bit => 1 byte - // 8 bits => 1 byte - // 9 bits => 2 bytes - int bytes_needed = ((bits_needed - 1) / 8) + 1; - if (bytes_needed > 8) - throw std::runtime_error("read_bits_int_le: more than 8 bytes requested"); - char buf[8]; - m_io->read(buf, bytes_needed); - for (int i = 0; i < bytes_needed; i++) { - uint8_t b = buf[i]; - m_bits |= (static_cast(b) << m_bits_left); - m_bits_left += 8; - } - } - - // raw mask with required number of 1s, starting from lowest bit - uint64_t mask = get_mask_ones(n); - // derive reading result - uint64_t res = m_bits & mask; - // remove bottom bits that we've just read by shifting - m_bits >>= n; - m_bits_left -= n; - - return res; -} - -uint64_t kaitai::kstream::get_mask_ones(int n) { - if (n == 64) { - return 0xFFFFFFFFFFFFFFFF; - } else { - return ((uint64_t) 1 << n) - 1; - } -} - -// ======================================================================== -// Byte arrays -// ======================================================================== - -std::string kaitai::kstream::read_bytes(std::streamsize len) { - std::vector result(len); - - // NOTE: streamsize type is signed, negative values are only *supposed* to not be used. - // http://en.cppreference.com/w/cpp/io/streamsize - if (len < 0) { - throw std::runtime_error("read_bytes: requested a negative amount"); - } - - if (len > 0) { - m_io->read(&result[0], len); - } - - return std::string(result.begin(), result.end()); -} - -std::string kaitai::kstream::read_bytes_full() { - std::iostream::pos_type p1 = m_io->tellg(); - m_io->seekg(0, std::ios::end); - std::iostream::pos_type p2 = m_io->tellg(); - size_t len = p2 - p1; - - // Note: this requires a std::string to be backed with a - // contiguous buffer. Officially, it's a only requirement since - // C++11 (C++98 and C++03 didn't have this requirement), but all - // major implementations had contiguous buffers anyway. - std::string result(len, ' '); - m_io->seekg(p1); - m_io->read(&result[0], len); - - return result; -} - -std::string kaitai::kstream::read_bytes_term(char term, bool include, bool consume, bool eos_error) { - std::string result; - std::getline(*m_io, result, term); - if (m_io->eof()) { - // encountered EOF - if (eos_error) { - throw std::runtime_error("read_bytes_term: encountered EOF"); - } - } else { - // encountered terminator - if (include) - result.push_back(term); - if (!consume) - m_io->unget(); - } - return result; -} - -std::string kaitai::kstream::ensure_fixed_contents(std::string expected) { - std::string actual = read_bytes(expected.length()); - - if (actual != expected) { - // NOTE: I think printing it outright is not best idea, it could contain non-ascii charactes like backspace and beeps and whatnot. It would be better to print hexlified version, and also to redirect it to stderr. 
- throw std::runtime_error("ensure_fixed_contents: actual data does not match expected data"); - } - - return actual; -} - -std::string kaitai::kstream::bytes_strip_right(std::string src, char pad_byte) { - std::size_t new_len = src.length(); - - while (new_len > 0 && src[new_len - 1] == pad_byte) - new_len--; - - return src.substr(0, new_len); -} - -std::string kaitai::kstream::bytes_terminate(std::string src, char term, bool include) { - std::size_t new_len = 0; - std::size_t max_len = src.length(); - - while (new_len < max_len && src[new_len] != term) - new_len++; - - if (include && new_len < max_len) - new_len++; - - return src.substr(0, new_len); -} - -// ======================================================================== -// Byte array processing -// ======================================================================== - -std::string kaitai::kstream::process_xor_one(std::string data, uint8_t key) { - size_t len = data.length(); - std::string result(len, ' '); - - for (size_t i = 0; i < len; i++) - result[i] = data[i] ^ key; - - return result; -} - -std::string kaitai::kstream::process_xor_many(std::string data, std::string key) { - size_t len = data.length(); - size_t kl = key.length(); - std::string result(len, ' '); - - size_t ki = 0; - for (size_t i = 0; i < len; i++) { - result[i] = data[i] ^ key[ki]; - ki++; - if (ki >= kl) - ki = 0; - } - - return result; -} - -std::string kaitai::kstream::process_rotate_left(std::string data, int amount) { - size_t len = data.length(); - std::string result(len, ' '); - - for (size_t i = 0; i < len; i++) { - uint8_t bits = data[i]; - result[i] = (bits << amount) | (bits >> (8 - amount)); - } - - return result; -} - -#ifdef KS_ZLIB -#include - -std::string kaitai::kstream::process_zlib(std::string data) { - int ret; - - unsigned char *src_ptr = reinterpret_cast(&data[0]); - std::stringstream dst_strm; - - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; - - ret = inflateInit(&strm); - if (ret != Z_OK) - throw std::runtime_error("process_zlib: inflateInit error"); - - strm.next_in = src_ptr; - strm.avail_in = data.length(); - - unsigned char outbuffer[ZLIB_BUF_SIZE]; - std::string outstring; - - // get the decompressed bytes blockwise using repeated calls to inflate - do { - strm.next_out = reinterpret_cast(outbuffer); - strm.avail_out = sizeof(outbuffer); - - ret = inflate(&strm, 0); - - if (outstring.size() < strm.total_out) - outstring.append(reinterpret_cast(outbuffer), strm.total_out - outstring.size()); - } while (ret == Z_OK); - - if (ret != Z_STREAM_END) { // an error occurred that was not EOF - std::ostringstream exc_msg; - exc_msg << "process_zlib: error #" << ret << "): " << strm.msg; - throw std::runtime_error(exc_msg.str()); - } - - if (inflateEnd(&strm) != Z_OK) - throw std::runtime_error("process_zlib: inflateEnd error"); - - return outstring; -} -#endif - -// ======================================================================== -// Misc utility methods -// ======================================================================== - -int kaitai::kstream::mod(int a, int b) { - if (b <= 0) - throw std::invalid_argument("mod: divisor b <= 0"); - int r = a % b; - if (r < 0) - r += b; - return r; -} - -#include -std::string kaitai::kstream::to_string(int val) { - // if int is 32 bits, "-2147483648" is the longest string representation - // => 11 chars + zero => 12 chars - // if int is 64 bits, "-9223372036854775808" is the longest - // => 20 chars + zero => 21 chars - char buf[25]; - int 
got_len = snprintf(buf, sizeof(buf), "%d", val); - - // should never happen, but check nonetheless - if (got_len > sizeof(buf)) - throw std::invalid_argument("to_string: integer is longer than string buffer"); - - return std::string(buf); -} - -#include -std::string kaitai::kstream::reverse(std::string val) { - std::reverse(val.begin(), val.end()); - - return val; -} - -uint8_t kaitai::kstream::byte_array_min(const std::string val) { - uint8_t min = 0xff; // UINT8_MAX - std::string::const_iterator end = val.end(); - for (std::string::const_iterator it = val.begin(); it != end; ++it) { - uint8_t cur = static_cast(*it); - if (cur < min) { - min = cur; - } - } - return min; -} - -uint8_t kaitai::kstream::byte_array_max(const std::string val) { - uint8_t max = 0; // UINT8_MIN - std::string::const_iterator end = val.end(); - for (std::string::const_iterator it = val.begin(); it != end; ++it) { - uint8_t cur = static_cast(*it); - if (cur > max) { - max = cur; - } - } - return max; -} - -// ======================================================================== -// Other internal methods -// ======================================================================== - -#ifndef KS_STR_DEFAULT_ENCODING -#define KS_STR_DEFAULT_ENCODING "UTF-8" -#endif - -#ifdef KS_STR_ENCODING_ICONV - -#include -#include -#include - -std::string kaitai::kstream::bytes_to_str(std::string src, std::string src_enc) { - iconv_t cd = iconv_open(KS_STR_DEFAULT_ENCODING, src_enc.c_str()); - - if (cd == (iconv_t) -1) { - if (errno == EINVAL) { - throw std::runtime_error("bytes_to_str: invalid encoding pair conversion requested"); - } else { - throw std::runtime_error("bytes_to_str: error opening iconv"); - } - } - - size_t src_len = src.length(); - size_t src_left = src_len; - - // Start with a buffer length of double the source length. - size_t dst_len = src_len * 2; - std::string dst(dst_len, ' '); - size_t dst_left = dst_len; - - char *src_ptr = &src[0]; - char *dst_ptr = &dst[0]; - - while (true) { - size_t res = iconv(cd, &src_ptr, &src_left, &dst_ptr, &dst_left); - - if (res == (size_t) -1) { - if (errno == E2BIG) { - // dst buffer is not enough to accomodate whole string - // enlarge the buffer and try again - size_t dst_used = dst_len - dst_left; - dst_left += dst_len; - dst_len += dst_len; - dst.resize(dst_len); - - // dst.resize might have allocated destination buffer in another area - // of memory, thus our previous pointer "dst" will be invalid; re-point - // it using "dst_used". 
- dst_ptr = &dst[dst_used]; - } else { - throw std::runtime_error("bytes_to_str: iconv error"); - } - } else { - // conversion successful - dst.resize(dst_len - dst_left); - break; - } - } - - if (iconv_close(cd) != 0) { - throw std::runtime_error("bytes_to_str: iconv close error"); - } - - return dst; -} -#elif defined(KS_STR_ENCODING_NONE) -std::string kaitai::kstream::bytes_to_str(std::string src, std::string src_enc) { - return src; -} -#else -#error Need to decide how to handle strings: please define one of: KS_STR_ENCODING_ICONV, KS_STR_ENCODING_NONE -#endif diff --git a/third_party/kaitai/kaitaistream.h b/third_party/kaitai/kaitaistream.h deleted file mode 100644 index e7f4c6ce..00000000 --- a/third_party/kaitai/kaitaistream.h +++ /dev/null @@ -1,268 +0,0 @@ -#ifndef KAITAI_STREAM_H -#define KAITAI_STREAM_H - -// Kaitai Struct runtime API version: x.y.z = 'xxxyyyzzz' decimal -#define KAITAI_STRUCT_VERSION 9000L - -#include -#include -#include -#include - -namespace kaitai { - -/** - * Kaitai Stream class (kaitai::kstream) is an implementation of - * Kaitai Struct stream API - * for C++/STL. It's implemented as a wrapper over generic STL std::istream. - * - * It provides a wide variety of simple methods to read (parse) binary - * representations of primitive types, such as integer and floating - * point numbers, byte arrays and strings, and also provides stream - * positioning / navigation methods with unified cross-language and - * cross-toolkit semantics. - * - * Typically, end users won't access Kaitai Stream class manually, but would - * describe a binary structure format using .ksy language and then would use - * Kaitai Struct compiler to generate source code in desired target language. - * That code, in turn, would use this class and API to do the actual parsing - * job. - */ -class kstream { -public: - /** - * Constructs new Kaitai Stream object, wrapping a given std::istream. - * \param io istream object to use for this Kaitai Stream - */ - kstream(std::istream* io); - - /** - * Constructs new Kaitai Stream object, wrapping a given in-memory data - * buffer. - * \param data data buffer to use for this Kaitai Stream - */ - kstream(std::string& data); - - void close(); - - /** @name Stream positioning */ - //@{ - /** - * Check if stream pointer is at the end of stream. Note that the semantics - * are different from traditional STL semantics: one does *not* need to do a - * read (which will fail) after the actual end of the stream to trigger EOF - * flag, which can be accessed after that read. It is sufficient to just be - * at the end of the stream for this method to return true. - * \return "true" if we are located at the end of the stream. - */ - bool is_eof() const; - - /** - * Set stream pointer to designated position. - * \param pos new position (offset in bytes from the beginning of the stream) - */ - void seek(uint64_t pos); - - /** - * Get current position of a stream pointer. - * \return pointer position, number of bytes from the beginning of the stream - */ - uint64_t pos(); - - /** - * Get total size of the stream in bytes. - * \return size of the stream in bytes - */ - uint64_t size(); - //@} - - /** @name Integer numbers */ - //@{ - - // ------------------------------------------------------------------------ - // Signed - // ------------------------------------------------------------------------ - - int8_t read_s1(); - - // ........................................................................ 
- // Big-endian - // ........................................................................ - - int16_t read_s2be(); - int32_t read_s4be(); - int64_t read_s8be(); - - // ........................................................................ - // Little-endian - // ........................................................................ - - int16_t read_s2le(); - int32_t read_s4le(); - int64_t read_s8le(); - - // ------------------------------------------------------------------------ - // Unsigned - // ------------------------------------------------------------------------ - - uint8_t read_u1(); - - // ........................................................................ - // Big-endian - // ........................................................................ - - uint16_t read_u2be(); - uint32_t read_u4be(); - uint64_t read_u8be(); - - // ........................................................................ - // Little-endian - // ........................................................................ - - uint16_t read_u2le(); - uint32_t read_u4le(); - uint64_t read_u8le(); - - //@} - - /** @name Floating point numbers */ - //@{ - - // ........................................................................ - // Big-endian - // ........................................................................ - - float read_f4be(); - double read_f8be(); - - // ........................................................................ - // Little-endian - // ........................................................................ - - float read_f4le(); - double read_f8le(); - - //@} - - /** @name Unaligned bit values */ - //@{ - - void align_to_byte(); - uint64_t read_bits_int_be(int n); - uint64_t read_bits_int(int n); - uint64_t read_bits_int_le(int n); - - //@} - - /** @name Byte arrays */ - //@{ - - std::string read_bytes(std::streamsize len); - std::string read_bytes_full(); - std::string read_bytes_term(char term, bool include, bool consume, bool eos_error); - std::string ensure_fixed_contents(std::string expected); - - static std::string bytes_strip_right(std::string src, char pad_byte); - static std::string bytes_terminate(std::string src, char term, bool include); - static std::string bytes_to_str(std::string src, std::string src_enc); - - //@} - - /** @name Byte array processing */ - //@{ - - /** - * Performs a XOR processing with given data, XORing every byte of input with a single - * given value. - * @param data data to process - * @param key value to XOR with - * @return processed data - */ - static std::string process_xor_one(std::string data, uint8_t key); - - /** - * Performs a XOR processing with given data, XORing every byte of input with a key - * array, repeating key array many times, if necessary (i.e. if data array is longer - * than key array). - * @param data data to process - * @param key array of bytes to XOR with - * @return processed data - */ - static std::string process_xor_many(std::string data, std::string key); - - /** - * Performs a circular left rotation shift for a given buffer by a given amount of bits, - * using groups of 1 bytes each time. Right circular rotation should be performed - * using this procedure with corrected amount. 
- * @param data source data to process - * @param amount number of bits to shift by - * @return copy of source array with requested shift applied - */ - static std::string process_rotate_left(std::string data, int amount); - -#ifdef KS_ZLIB - /** - * Performs an unpacking ("inflation") of zlib-compressed data with usual zlib headers. - * @param data data to unpack - * @return unpacked data - * @throws IOException - */ - static std::string process_zlib(std::string data); -#endif - - //@} - - /** - * Performs modulo operation between two integers: dividend `a` - * and divisor `b`. Divisor `b` is expected to be positive. The - * result is always 0 <= x <= b - 1. - */ - static int mod(int a, int b); - - /** - * Converts given integer `val` to a decimal string representation. - * Should be used in place of std::to_string() (which is available only - * since C++11) in older C++ implementations. - */ - static std::string to_string(int val); - - /** - * Reverses given string `val`, so that the first character becomes the - * last and the last one becomes the first. This should be used to avoid - * the need of local variables at the caller. - */ - static std::string reverse(std::string val); - - /** - * Finds the minimal byte in a byte array, treating bytes as - * unsigned values. - * @param val byte array to scan - * @return minimal byte in byte array as integer - */ - static uint8_t byte_array_min(const std::string val); - - /** - * Finds the maximal byte in a byte array, treating bytes as - * unsigned values. - * @param val byte array to scan - * @return maximal byte in byte array as integer - */ - static uint8_t byte_array_max(const std::string val); - -private: - std::istream* m_io; - std::istringstream m_io_str; - int m_bits_left; - uint64_t m_bits; - - void init(); - void exceptions_enable() const; - - static uint64_t get_mask_ones(int n); - - static const int ZLIB_BUF_SIZE = 128 * 1024; -}; - -} - -#endif diff --git a/third_party/kaitai/kaitaistruct.h b/third_party/kaitai/kaitaistruct.h deleted file mode 100644 index 8172ede6..00000000 --- a/third_party/kaitai/kaitaistruct.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef KAITAI_STRUCT_H -#define KAITAI_STRUCT_H - -#include - -namespace kaitai { - -class kstruct { -public: - kstruct(kstream *_io) { m__io = _io; } - virtual ~kstruct() {} -protected: - kstream *m__io; -public: - kstream *_io() { return m__io; } -}; - -} - -#endif diff --git a/tinygrad_repo/.pre-commit-config.yaml b/tinygrad_repo/.pre-commit-config.yaml index b6c7ef17..d3baff37 100644 --- a/tinygrad_repo/.pre-commit-config.yaml +++ b/tinygrad_repo/.pre-commit-config.yaml @@ -20,12 +20,6 @@ repos: language: system always_run: true pass_filenames: false - - id: devicetests - name: select GPU tests - entry: env GPU=1 PYTHONPATH="." python3 -m pytest test/test_uops.py test/test_search.py - language: system - always_run: true - pass_filenames: false - id: tests name: subset of tests entry: env PYTHONPATH="." python3 -m pytest -n=4 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py diff --git a/tinygrad_repo/.pylintrc b/tinygrad_repo/.pylintrc index 57e05e93..2f1de519 100644 --- a/tinygrad_repo/.pylintrc +++ b/tinygrad_repo/.pylintrc @@ -54,11 +54,12 @@ confidence= # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method +disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401,abstract-method,W0707 # E1101 for function binding # W0221 for Function class # W0105 for comment strings # E0401 for missing imports +# W0707 for not reraising # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/tinygrad_repo/README.md b/tinygrad_repo/README.md index 274f5000..dab378a2 100644 --- a/tinygrad_repo/README.md +++ b/tinygrad_repo/README.md @@ -79,9 +79,8 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers tinygrad already supports numerous accelerators, including: -- [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py) -- [x] [CPU (C Code)](tinygrad/runtime/ops_cpu.py) -- [x] [LLVM](tinygrad/runtime/ops_llvm.py) +- [x] [OpenCL](tinygrad/runtime/ops_cl.py) +- [x] [CPU](tinygrad/runtime/ops_cpu.py) - [x] [METAL](tinygrad/runtime/ops_metal.py) - [x] [CUDA](tinygrad/runtime/ops_cuda.py) - [x] [AMD](tinygrad/runtime/ops_amd.py) diff --git a/tinygrad_repo/docs/developer/layout.md b/tinygrad_repo/docs/developer/layout.md index 2f9a53a4..ab7701fb 100644 --- a/tinygrad_repo/docs/developer/layout.md +++ b/tinygrad_repo/docs/developer/layout.md @@ -22,12 +22,6 @@ Group UOps into kernels. Transforms the ast into an optimized ast. This is where BEAM search and heuristics live. -::: tinygrad.codegen.opt.get_optimized_ast - options: - members: false - show_labels: false - show_source: false - --- ## tinygrad/codegen diff --git a/tinygrad_repo/docs/env_vars.md b/tinygrad_repo/docs/env_vars.md index 74d1351e..f8844ba7 100644 --- a/tinygrad_repo/docs/env_vars.md +++ b/tinygrad_repo/docs/env_vars.md @@ -3,7 +3,7 @@ This is a list of environment variable that control the runtime behavior of tinygrad and its examples. Most of these are self-explanatory, and are usually used to set an option at runtime. -Example: `GPU=1 DEBUG=4 python3 -m pytest` +Example: `CL=1 DEBUG=4 python3 -m pytest` However you can also decorate a function to set a value only inside that function. @@ -31,19 +31,16 @@ These control the behavior of core tinygrad even when used as a library. Variable | Possible Value(s) | Description ---|---|--- DEBUG | [1-7] | enable debugging output (operations, timings, speed, generated code and more) -GPU | [1] | enable the GPU (OpenCL) backend +CL | [1] | enable OpenCL backend CUDA | [1] | enable CUDA backend AMD | [1] | enable AMD backend NV | [1] | enable NV backend METAL | [1] | enable Metal backend (for Mac M1 and after) -CPU | [1] | enable CPU (Clang) backend -LLVM | [1] | enable LLVM backend +CPU | [1] | enable CPU backend BEAM | [#] | number of beams in kernel beam search DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32 IMAGE | [1-2] | enable 2d specific optimizations FLOAT16 | [1] | use float16 for images instead of float32 -PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend. -PROFILE | [1] | enable profiling. This feature is supported in NV, AMD, QCOM and METAL backends. 
VISIBLE_DEVICES | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0). JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled VIZ | [1] | 0=disabled, 1=[viz enabled](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/viz) diff --git a/tinygrad_repo/docs/runtime.md b/tinygrad_repo/docs/runtime.md index bc85d9be..28a7aad0 100644 --- a/tinygrad_repo/docs/runtime.md +++ b/tinygrad_repo/docs/runtime.md @@ -2,17 +2,17 @@ tinygrad supports various runtimes, enabling your code to scale across a wide range of devices. The default runtime can be automatically selected based on the available hardware, or you can force a specific runtime to be default using environment variables (e.g., `CPU=1`). -| Runtime | Description | Requirements | -|---------|-------------|--------------| -| [NV](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_nv.py) | Provides acceleration for NVIDIA GPUs | Ampere/Ada series GPUs | -| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | RDNA2/RDNA3/RDNA4 series GPUs. You can select one of the interfaces for communication by setting `AMD_IFACE=(KFD|PCI)`. See [AMD interfaces](#amd-interfaces) for more details. | -| [QCOM](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_qcom.py) | Provides acceleration for QCOM GPUs | 6xx series GPUs | -| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | M1+ Macs; Metal 3.0+ for `bfloat` support | -| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | NVIDIA GPU with CUDA support | -| [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device | -| [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` | -| [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable | -| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0). | +| Runtime | Description | Compiler Options | Requirements | +|---------|-------------|------------------|--------------|
+| [NV](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_nv.py) | Provides acceleration for NVIDIA GPUs | nvrtc (default) <br> PTX (`NV_PTX=1`) | Ampere/Ada/Blackwell series GPUs. <br> You can select an interface via `NV_IFACE=(NVK\|PCI)`. See [NV interfaces](#nv-interfaces) for details. |
+| [AMD](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_amd.py) | Provides acceleration for AMD GPUs | LLVM (`AMD_LLVM=1`) <br> HIP/COMGR (`AMD_HIP=1`) | RDNA2 or newer GPUs. <br> You can select an interface via `AMD_IFACE=(KFD\|PCI\|USB)`. See [AMD interfaces](#amd-interfaces) for details. |
+| [QCOM](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_qcom.py) | Provides acceleration for QCOM GPUs | - | 6xx series GPUs | +| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | - | M1+ Macs; Metal 3.0+ for `bfloat` support |
+| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | nvrtc (default) <br> PTX (`CUDA_PTX=1`) | NVIDIA GPU with CUDA support | +| [CL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cl.py) | Accelerates computations using OpenCL on GPUs | - | OpenCL 2.0 compatible device |
+| [CPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang or llvm compiler | Clang JIT (default) <br> LLVM IR (`CPU_LLVM=1`) | `clang` compiler in system `PATH` | +| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | - | Dawn library installed and discoverable. Binaries: [pydawn v0.3.0](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0) | + ## Interoperability @@ -70,5 +70,12 @@ AMD backend supports several interfaces for communicating with devices: * `KFD`: uses the amdgpu driver * `PCI`: uses the [AM driver](developer/am.md) +* `USB`: USB3 interface for asm24xx chips. You can force an interface by setting `AMD_IFACE` to one of these values. In the case of `AMD_IFACE=PCI`, this may unbind your GPU from the amdgpu driver. + +## NV Interfaces +NV backend supports several interfaces for communicating with devices: + +* `NVK`: uses the nvidia driver +* `PCI`: uses the [NV driver](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/support/nv/nvdev.py) diff --git a/tinygrad_repo/examples/beautiful_cifar.py b/tinygrad_repo/examples/beautiful_cifar.py index 66f693d9..cea8262f 100644 --- a/tinygrad_repo/examples/beautiful_cifar.py +++ b/tinygrad_repo/examples/beautiful_cifar.py @@ -2,7 +2,6 @@ import time start_tm = time.perf_counter() import math from typing import Tuple, cast -import numpy as np from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes, Device from tinygrad.helpers import partition, trange, getenv, Context from extra.lr_scheduler import OneCycleLR @@ -150,13 +149,12 @@ if __name__ == "__main__": acc.append((out.argmax(-1) == Y).sum() / eval_batchsize) return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean() - np.random.seed(1337) + Tensor.manual_seed(1337) + num_train_samples = X_train.shape[0] + for epoch in range(math.ceil(hyp['misc']['train_epochs'])): - # TODO: move to tinygrad gst = time.perf_counter() - idxs = np.arange(X_train.shape[0]) - np.random.shuffle(idxs) - tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) # NOTE: long doesn't fold + tidxs = Tensor.randperm(num_train_samples, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize) train_loss:float = 0 for epoch_step in (t:=trange(num_steps_per_epoch)): st = time.perf_counter() diff --git a/tinygrad_repo/examples/gpt2.py b/tinygrad_repo/examples/gpt2.py index 6a233327..6670b4e2 100644 --- a/tinygrad_repo/examples/gpt2.py +++ b/tinygrad_repo/examples/gpt2.py @@ -181,6 +181,7 @@ class GPT2: self.tokenizer = tokenizer def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1): + step_times = [] prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"}) toks = [prompt_tokens[:] for _ in range(batch_size)] start_pos = 0 @@ -188,7 +189,7 @@ class GPT2: GlobalCounters.reset() if timing: print("") st = GlobalCounters.time_sum_s - with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ + with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on {Device.DEFAULT}" if DEBUG>=2 else "")+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing): with WallTimeEvent(BenchEvent.STEP): @@ -197,8 +198,13 @@ class GPT2: else: tokens = 
Tensor([x[start_pos:] for x in toks]) tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT-1).bind(start_pos), temperature).tolist() + step_times.append((GlobalCounters.time_sum_s-st)*1e3) start_pos = len(toks[0]) for i,t in enumerate(tok): toks[i].append(t) + + if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): + min_time = min(step_times) + assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" return [self.tokenizer.decode(x) for x in toks] # **** main code **** diff --git a/tinygrad_repo/examples/hlb_cifar10.py b/tinygrad_repo/examples/hlb_cifar10.py index 8f19c8f7..bde89c4c 100644 --- a/tinygrad_repo/examples/hlb_cifar10.py +++ b/tinygrad_repo/examples/hlb_cifar10.py @@ -355,7 +355,7 @@ def train_cifar(): # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june # 136 TFLOPS is the theoretical max w float16 on 3080 Ti - + step_times = [] model_ema: Optional[modelEMA] = None projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps'] i = 0 @@ -413,12 +413,17 @@ def train_cifar(): model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']])) cl = time.monotonic() + step_times.append((cl-st)*1000.0) device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS") st = cl i += 1 + if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): + min_time = min(step_times) + assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" + # verify eval acc if target := getenv("TARGET_EVAL_ACC_PCT", 0.0): if eval_acc_pct >= target: diff --git a/tinygrad_repo/examples/llama.py b/tinygrad_repo/examples/llama.py index 42f9b6e5..6739ca4c 100755 --- a/tinygrad_repo/examples/llama.py +++ b/tinygrad_repo/examples/llama.py @@ -478,7 +478,7 @@ After you are done speaking, output [EOS]. You are not Chad. 
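
The `ASSERT_MIN_STEP_TIME` checks added to gpt2.py above (and to hlb_cifar10.py just below) share one pattern: record per-step wall time, then fail the run if even the best step misses a threshold taken from the environment. Distilled into a standalone sketch (`run_step` is a hypothetical stand-in for a training or inference step):

```python
import os, time

def run_step():
    time.sleep(0.001)  # stand-in for real work

step_times = []
for _ in range(10):
    st = time.perf_counter()
    run_step()
    step_times.append((time.perf_counter() - st) * 1000.0)  # ms

# gating on min() ignores one-off stalls; only a consistently slow run fails
if (assert_time := float(os.getenv("ASSERT_MIN_STEP_TIME", "0"))):
    min_time = min(step_times)
    assert min_time < assert_time, f"Speed regression: min step {min_time:.2f} ms >= {assert_time} ms"
```
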
with Profiling(enabled=args.profile): with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"): with WallTimeEvent(BenchEvent.STEP): - with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ + with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on {Device.DEFAULT}" if DEBUG>=2 else "")+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing): tok_tensor = llama.model(next_tok, start_pos, args.temperature) diff --git a/tinygrad_repo/examples/llama3.py b/tinygrad_repo/examples/llama3.py index 9664f491..d7c7f2c9 100644 --- a/tinygrad_repo/examples/llama3.py +++ b/tinygrad_repo/examples/llama3.py @@ -441,7 +441,7 @@ if __name__ == "__main__": with Profiling(enabled=args.profile): with Timing("total ", on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"): with WallTimeEvent(BenchEvent.STEP): - with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ + with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on {Device.DEFAULT}" if DEBUG>=2 else "")+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None): tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P) @@ -479,7 +479,7 @@ if __name__ == "__main__": st = GlobalCounters.time_sum_s with Profiling(enabled=args.profile): with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"): - with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ + with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on {Device.DEFAULT}" if DEBUG>=2 else "")+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing): diff --git a/tinygrad_repo/examples/mamba.py b/tinygrad_repo/examples/mamba.py index d6093eab..d3098074 100644 --- a/tinygrad_repo/examples/mamba.py +++ b/tinygrad_repo/examples/mamba.py @@ -279,9 +279,15 @@ def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: boo # Loading in the prompt tokens logits = model.forward(Tensor([tks]))[:, -1, :] for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"): - # TODO: topk if sample: - tok_Tens = (logits/temp).softmax().multinomial() + scaled_logits = logits / temp + if top_k is not None: + topk_values, topk_indices = scaled_logits.topk(top_k) + filtered_logits = Tensor.full_like(scaled_logits, -float("inf")) + filtered_logits = filtered_logits.scatter(dim=-1, 
index=topk_indices, src=topk_values) + tok_Tens = filtered_logits.softmax().multinomial() + else: + tok_Tens = scaled_logits.softmax().multinomial() else: tok_Tens = logits.argmax(axis=-1).unsqueeze(0) tok = tok_Tens.item() @@ -298,6 +304,7 @@ if __name__ == "__main__": parser.add_argument("--size", type=str, default="370m", help=f"Size of model to use [{', '.join([k for k in MODELS.keys()])}]") parser.add_argument("--n_tokens", type=int, default=10, help="Number of tokens to generate") + parser.add_argument("--top_k", type=int, help="Limit sampling to the top k most likely tokens") parser.add_argument("--sample", dest="sample", action="store_true", help="Sample flag") parser.add_argument("--temp", type=float, default=1.0, help="Sampling temp has to be <=1.0") args = parser.parse_args() @@ -308,8 +315,9 @@ if __name__ == "__main__": num_toks = args.n_tokens sample = args.sample temp = args.temp + top_k = args.top_k s = time.time() - tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp) + tinyoutput = generate(model, tokenizer, prompt, n_tokens_to_gen=num_toks, sample=sample, temp=temp, top_k=top_k) print(tinyoutput) print('TIME: ', time.time() - s) TORCHOUTPUT = "Why is gravity \nso important?\nBecause it's the only" diff --git a/tinygrad_repo/examples/mlperf/dataloader.py b/tinygrad_repo/examples/mlperf/dataloader.py index c82c4241..09fb1915 100644 --- a/tinygrad_repo/examples/mlperf/dataloader.py +++ b/tinygrad_repo/examples/mlperf/dataloader.py @@ -758,6 +758,27 @@ def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0 batch.append(tokens) yield Tensor.stack(batch, dim=0) +def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True): + if val: + dataset = BlendedGPTDataset([ + base_dir / "c4-validation-91205-samples.en_text_document", + ], [ + 1.0 + ], samples, seqlen, seed, False) + else: + dataset = BlendedGPTDataset([ + base_dir / "c4-train.en_6_text_document", + ], [ + 1.0 + ], samples, seqlen, seed, True) + + for b in range(math.ceil(samples / bs)): + batch = [] + for i in range(bs): + tokens = dataset.get(b * bs + i) + batch.append(tokens) + yield Tensor.stack(batch, dim=0) + if __name__ == "__main__": def load_unet3d(val): assert not val, "validation set is not supported due to different sizes on inputs" diff --git a/tinygrad_repo/examples/mlperf/model_eval.py b/tinygrad_repo/examples/mlperf/model_eval.py index 091f9456..b71c290a 100644 --- a/tinygrad_repo/examples/mlperf/model_eval.py +++ b/tinygrad_repo/examples/mlperf/model_eval.py @@ -243,31 +243,49 @@ def eval_mrcnn(): def eval_llama3(): from extra.models.llama import Transformer - from examples.llama3 import MODEL_PARAMS + from examples.llama3 import MODEL_PARAMS, load, convert_from_huggingface from tinygrad.helpers import tqdm - bs = 4 - sequence_length = 512 + BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/")) + BS = getenv("BS", 4) + SMALL = getenv("SMALL", 0) + SEQLEN = getenv("SEQLEN", 8192) + MODEL_PATH = Path(getenv("MODEL_PATH", "/raid/weights/llama31_8b/")) - model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True) + params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"] + params = params | {"vocab_size": 32000} if not SMALL else params + if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers + model = Transformer(**params, max_context=SEQLEN, jit=False, 
disable_kv_cache=True) + + # load weights + weights = load(str(MODEL_PATH / "model.safetensors.index.json")) + if "model.embed_tokens.weight" in weights: + print("converting from huggingface format") + weights = convert_from_huggingface(weights, params["n_layers"], params["n_heads"], params["n_kv_heads"]) + + load_state_dict(model, weights, strict=False, consume=True) @TinyJit def eval_step(model, tokens): logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan) loss = logits.sparse_categorical_crossentropy(tokens[:, 1:]) - return loss.flatten() + return loss.flatten().float() - from examples.mlperf.dataloader import batch_load_llama3 - iter = batch_load_llama3(bs, 5760, sequence_length, Path(getenv("BASEDIR", "/raid/datasets/c4/")), True) + if SMALL: + from examples.mlperf.dataloader import batch_load_llama3_small + iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True) + else: + from examples.mlperf.dataloader import batch_load_llama3 + iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True) losses = [] - for tokens in tqdm(iter, total=5760//bs): + for tokens in tqdm(iter, total=5760//BS): GlobalCounters.reset() losses += eval_step(model, tokens).tolist() tqdm.write(f"loss: {np.mean(losses)}") - log_perplexity = Tensor(losses).mean() - print(f"Log Perplexity: {log_perplexity.item()}") + log_perplexity = np.mean(losses) + print(f"Log Perplexity: {log_perplexity}") if __name__ == "__main__": # inference only diff --git a/tinygrad_repo/examples/mlperf/model_train.py b/tinygrad_repo/examples/mlperf/model_train.py index f3bab81f..c2e961e9 100644 --- a/tinygrad_repo/examples/mlperf/model_train.py +++ b/tinygrad_repo/examples/mlperf/model_train.py @@ -4,7 +4,7 @@ import multiprocessing from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling -from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save +from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW from extra.lr_scheduler import LRSchedulerGroup @@ -252,6 +252,10 @@ def train_resnet(): print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, " f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}") # if we are doing beam search, run the first eval too + if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): + min_time = min(step_times) + assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" + if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break return if MLLOGGER and RUNMLPERF: @@ -344,6 +348,8 @@ def train_resnet(): print(f"saving ckpt to {fn}") safe_save(get_training_state(model, optimizer_group, scheduler_group), fn) + + def train_retinanet(): from contextlib import redirect_stdout from examples.mlperf.dataloader import batch_load_retinanet @@ -1290,12 +1296,14 @@ def train_llama3(): from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup config = {} + BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/")) BS = config["BS"] = getenv("BS", 16) grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1) GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc SEED = config["SEED"] = getenv("SEED", 5760) SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192) TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = 
getenv("TRAIN_ON_VAL", 0) + SMALL = config["SMALL"] = getenv("SMALL", 0) SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152) EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16) @@ -1311,13 +1319,14 @@ def train_llama3(): opt_gradient_clip_norm = 1.0 opt_learning_rate_warmup_steps = getenv("WARMUP_STEPS", math.ceil(8000 * 1152 / GBS)) - opt_learning_rate_decay_steps = getenv("DECAY_STEPS", math.ceil(1_200_000 * 1152 / GBS) - opt_learning_rate_warmup_steps) + opt_learning_rate_decay_steps = getenv("MAX_STEPS", math.ceil(1_200_000 * 1152 / GBS)) - opt_learning_rate_warmup_steps opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark - opt_end_learning_rate = 8e-7 + opt_end_learning_rate = getenv("END_LR", 8e-7) # TODO: confirm weights are in bf16 # vocab_size from the mixtral tokenizer - params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000} + params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"] + params = params | {"vocab_size": 32000} if not SMALL else params if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True) @@ -1353,6 +1362,15 @@ def train_llama3(): b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay) scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps) + if resume_ckpt := getenv("RESUME_CKPT"): + fn = f"./ckpts/llama3_{resume_ckpt}.safe" + print(f"loading initial checkpoint from {fn}") + load_state_dict(model, safe_load(fn), realize=False) + + fn = f"./ckpts/llama3_{resume_ckpt}_optim.safe" + print(f"loading optim checkpoint from {fn}") + load_state_dict(scheduler, safe_load(fn), realize=False) + @TinyJit @Tensor.train() def train_step(model, tokens:Tensor, grad_acc:int): @@ -1403,43 +1421,55 @@ def train_llama3(): # ** data iters ** def fake_data(bs, samples): for _ in range(samples // bs): - yield Tensor.randint(bs, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT) + yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT) def get_train_iter(): if getenv("FAKEDATA", 0): return fake_data(GBS, SAMPLES) else: - from examples.mlperf.dataloader import batch_load_llama3 - return batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL)) + if SMALL: + from examples.mlperf.dataloader import batch_load_llama3_small + return batch_load_llama3_small(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL)) + else: + from examples.mlperf.dataloader import batch_load_llama3 + return batch_load_llama3(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL)) def get_eval_iter(): if getenv("FAKEDATA", 0): return fake_data(EVAL_BS, 5760) else: - from examples.mlperf.dataloader import batch_load_llama3 - return batch_load_llama3(EVAL_BS, 5760, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=True) + if SMALL: + from examples.mlperf.dataloader import batch_load_llama3_small + return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True) + else: + from examples.mlperf.dataloader import batch_load_llama3 + return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True) iter = 
get_train_iter() - i, sequences_seen = 0, 0 + i, sequences_seen = resume_ckpt, 0 for tokens in tqdm(iter, total=SAMPLES//GBS): t = time.perf_counter() GlobalCounters.reset() loss, lr = train_step(model, tokens, grad_acc) loss = loss.float().item() - # above as tqdm.write f-string + + i += 1 + sequences_seen += tokens.shape[0] + tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s") if (fname:=getenv("LOSS_FILE", "")): with open(fname, "a") as f: f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n") - if getenv("CKPT") and (i % 200 == 0 or i == 10): + if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)): tqdm.write("saving checkpoint") if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir) fn = f"{ckpt_dir}/llama3_{i}.safe" safe_save(get_state_dict(model), fn) - i += 1 - sequences_seen += tokens.shape[0] + tqdm.write("saving optim checkpoint") + fn = f"{ckpt_dir}/llama3_{i}_optim.safe" + safe_save(get_state_dict(scheduler), fn) if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1): tqdm.write(f"evaluating after {sequences_seen} sequences") diff --git a/tinygrad_repo/examples/mlperf/scripts/stable_diffusion_downloads.sh b/tinygrad_repo/examples/mlperf/scripts/stable_diffusion_downloads.sh new file mode 100644 index 00000000..5f6798c1 --- /dev/null +++ b/tinygrad_repo/examples/mlperf/scripts/stable_diffusion_downloads.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +# adapted from https://github.com/mlcommons/training/blob/4bdf5c8ed218ad76565a2ba1ac27c919ccc6d689/stable_diffusion/README.md + +# setup dirs + +DATA=/raid/datasets/stable_diffusion + +LAION=$DATA/laion-400m/webdataset-moments-filtered +COCO=$DATA/coco2014 +mkdir -p $LAION $COCO + +CKPT=/raid/weights/stable_diffusion +mkdir -p $CKPT/clip $CKPT/sd $CKPT/inception + +# download data + +# if rclone isn't installed system-wide / in your PATH, put the executable path in quotes below +#RCLONE="" +RCLONE="rclone" + +## VAE-encoded image latents, from 6.1M image subset of laion-400m +## about 1 TB for whole download +$RCLONE config create mlc-training s3 provider=Cloudflare access_key_id=76ea42eadb867e854061a1806220ee1e secret_access_key=a53625c4d45e3ca8ac0df8a353ea3a41ffc3292aa25259addd8b7dc5a6ce2936 endpoint=c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com +$RCLONE copy mlc-training:mlcommons-training-wg-public/stable_diffusion/datasets/laion-400m/moments-webdataset-filtered/ ${LAION} --include="*.tar" -P +$RCLONE copy mlc-training:mlcommons-training-wg-public/stable_diffusion/datasets/laion-400m/moments-webdataset-filtered/sha512sums.txt ${LAION} -P +cd $LAION && grep -E '\.tar$' sha512sums.txt | sha512sum -c --quiet - && \ + echo "All .tar files verified" || { echo "Checksum failure when validating downloaded Laion moments"; exit 1; } + +## prompts and FID statistics from 30k image subset of coco2014 +## 33 MB +$RCLONE config create mlc-training s3 provider=Cloudflare access_key_id=76ea42eadb867e854061a1806220ee1e secret_access_key=a53625c4d45e3ca8ac0df8a353ea3a41ffc3292aa25259addd8b7dc5a6ce2936 endpoint=c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com +$RCLONE copy mlc-training:mlcommons-training-wg-public/stable_diffusion/datasets/coco2014/val2014_30k.tsv ${COCO} -P + +$RCLONE config create mlc-training s3 provider=Cloudflare access_key_id=76ea42eadb867e854061a1806220ee1e secret_access_key=a53625c4d45e3ca8ac0df8a353ea3a41ffc3292aa25259addd8b7dc5a6ce2936 
endpoint=c2686074cb2caf5cbaf6d134bdba8b47.r2.cloudflarestorage.com +$RCLONE copy mlc-training:mlcommons-training-wg-public/stable_diffusion/datasets/coco2014/val2014_30k_stats.npz ${COCO} -P + +# download checkpoints + +## clip (needed for text and vision encoders for validation) +CLIP_WEIGHTS_URL="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin" +CLIP_WEIGHTS_SHA256="9a78ef8e8c73fd0df621682e7a8e8eb36c6916cb3c16b291a082ecd52ab79cc4" +CLIP_CONFIG_URL="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/raw/main/open_clip_config.json" +wget -N -P ${CKPT}/clip ${CLIP_WEIGHTS_URL} +wget -N -P ${CKPT}/clip ${CLIP_CONFIG_URL} +echo "${CLIP_WEIGHTS_SHA256} ${CKPT}/clip/open_clip_pytorch_model.bin" | sha256sum -c + +## sd (needed for latent->image decoder for validation, also has clip text encoder for training) +SD_WEIGHTS_URL='https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt' +SD_WEIGHTS_SHA256="d635794c1fedfdfa261e065370bea59c651fc9bfa65dc6d67ad29e11869a1824" +wget -N -P ${CKPT}/sd ${SD_WEIGHTS_URL} +echo "${SD_WEIGHTS_SHA256} ${CKPT}/sd/512-base-ema.ckpt" | sha256sum -c + +## inception (needed for validation) +FID_WEIGHTS_URL='https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' +FID_WEIGHTS_SHA1="bd836944fd6db519dfd8d924aa457f5b3c8357ff" +wget -N -P ${CKPT}/inception ${FID_WEIGHTS_URL} +echo "${FID_WEIGHTS_SHA1} ${CKPT}/inception/pt_inception-2015-12-05-6726825d.pth" | sha1sum -c \ No newline at end of file diff --git a/tinygrad_repo/examples/openpilot/compile4.py b/tinygrad_repo/examples/openpilot/compile4.py index c57bd3eb..55fcccbf 100644 --- a/tinygrad_repo/examples/openpilot/compile4.py +++ b/tinygrad_repo/examples/openpilot/compile4.py @@ -6,7 +6,7 @@ from tinygrad.schedule.kernelize import get_kernelize_map from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.engine.realize import run_schedule -# NOLOCALS=1 GPU=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py +# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx" OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl" diff --git a/tinygrad_repo/examples/qwq.py b/tinygrad_repo/examples/qwq.py index fad87695..b3b03065 100644 --- a/tinygrad_repo/examples/qwq.py +++ b/tinygrad_repo/examples/qwq.py @@ -8,7 +8,7 @@ from typing import Dict, Union from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16 from examples.llama3 import load -from tinygrad import nn, Tensor +from tinygrad import nn, Tensor, Device from tinygrad.helpers import fetch, colored, GlobalCounters, Timing, DEBUG from tinygrad.nn.state import load_state_dict, get_parameters @@ -80,7 +80,7 @@ if __name__ == "__main__": st = GlobalCounters.time_sum_s next_tok = Tensor([toks[start_pos:]]) if tok_tensor is None or (len(toks)-start_pos) > 1 else tok_tensor.reshape(1, 1) with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"): - with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "") + + with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on 
{Device.DEFAULT}" if DEBUG>=2 else "") + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB" + (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing): tok_tensor = transformer(next_tok, start_pos, args.temperature) diff --git a/tinygrad_repo/examples/sdxl.py b/tinygrad_repo/examples/sdxl.py index 24a5a6bb..92d85ce0 100644 --- a/tinygrad_repo/examples/sdxl.py +++ b/tinygrad_repo/examples/sdxl.py @@ -6,7 +6,7 @@ from tinygrad import Tensor, TinyJit, dtypes, GlobalCounters from tinygrad.nn import Conv2d, GroupNorm from tinygrad.nn.state import safe_load, load_state_dict -from tinygrad.helpers import fetch, trange, colored, Timing +from tinygrad.helpers import fetch, trange, colored, Timing, getenv from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding from extra.bench_log import BenchEvent, WallTimeEvent @@ -14,7 +14,7 @@ from examples.stable_diffusion import ResnetBlock, Mid import numpy as np from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type -import argparse, tempfile +import argparse, tempfile, time from abc import ABC, abstractmethod from pathlib import Path from PIL import Image @@ -342,11 +342,13 @@ class DPMPP2MSampler: sigmas = self.discretization(num_steps).to(x.device) x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) + step_times = [] old_denoised = None for i in trange(num_sigmas - 1): with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"): GlobalCounters.reset() + st = time.perf_counter_ns() with WallTimeEvent(BenchEvent.STEP): x, old_denoised = self.sampler_step( old_denoised=old_denoised, @@ -358,8 +360,13 @@ class DPMPP2MSampler: c=c, uc=uc, ) + step_times.append(t:=(time.perf_counter_ns() - st)*1e-6) x.realize(old_denoised) + if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): + min_time = min(step_times) + assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" + return x @@ -430,8 +437,8 @@ if __name__ == "__main__": im.show() # validation! 
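The `Timing("enqueue in ", ...)` hunks in llama.py, llama3.py, and qwq.py above all make one change: the per-token report prints `{Device.DEFAULT}` in place of a hard-coded "GPU" label, so the line stays accurate on any backend. A minimal sketch of that reporting pattern, assuming only a working default tinygrad device:

```python
from tinygrad import Tensor, Device, GlobalCounters
from tinygrad.helpers import Timing, DEBUG

st = GlobalCounters.time_sum_s
# on_exit receives elapsed ns and returns a string appended to the report;
# gating on DEBUG>=2 mirrors the hunks above
with Timing("enqueue in ", on_exit=(lambda et:
    f", {(GlobalCounters.time_sum_s - st) * 1e3:.2f} ms on {Device.DEFAULT}") if DEBUG >= 2 else None):
  (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()
```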
- if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \ - and not args.weights: + is_default = args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 + if is_default and not args.weights and not args.fakeweights: ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png"))) distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item() assert distance < 4e-3, colored(f"validation failed with {distance=}", "red") diff --git a/tinygrad_repo/examples/stable_diffusion.py b/tinygrad_repo/examples/stable_diffusion.py index 44dca39e..1dbd74e7 100644 --- a/tinygrad_repo/examples/stable_diffusion.py +++ b/tinygrad_repo/examples/stable_diffusion.py @@ -2,7 +2,7 @@ # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md import tempfile from pathlib import Path -import argparse +import argparse, time from collections import namedtuple from typing import Dict, Any @@ -266,17 +266,23 @@ if __name__ == "__main__": def run(model, *x): return model(*x).realize() # this is diffusion + step_times = [] with Context(BEAM=getenv("LATEBEAM")): for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): GlobalCounters.reset() + st = time.perf_counter_ns() t.set_description("%3d %3d" % (index, timestep)) with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"): with WallTimeEvent(BenchEvent.STEP): tid = Tensor([index]) latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance])) if args.timing: Device[Device.DEFAULT].synchronize() + step_times.append((time.perf_counter_ns() - st)*1e-6) del run + if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): + min_time = min(step_times) + assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" # upsample latent space to image with autoencoder x = model.decode(latent) print(x.shape) diff --git a/tinygrad_repo/examples/whisper.py b/tinygrad_repo/examples/whisper.py index 5cce861a..2df31226 100644 --- a/tinygrad_repo/examples/whisper.py +++ b/tinygrad_repo/examples/whisper.py @@ -109,7 +109,7 @@ class TextDecoder: def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor): seqlen = x.shape[-1] - x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None, None)) + x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None)) for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos) return self.output_tok(x) diff --git a/tinygrad_repo/extra/amdpci/proclogs.py b/tinygrad_repo/extra/amdpci/proclogs.py index 86b8a743..18c616d2 100644 --- a/tinygrad_repo/extra/amdpci/proclogs.py +++ b/tinygrad_repo/extra/amdpci/proclogs.py @@ -37,7 +37,7 @@ def main(): dev = PCIIface(None, 0) for x, y in dev.dev_impl.__dict__.items(): if isinstance(y, AMRegister): - for inst, addr in y.addr.keys(): reg_names[addr] = f"{x}, xcc={inst}" + for inst, addr in y.addr.items(): reg_names[addr] = f"{x}, xcc={inst}" with open(sys.argv[1], 'r') as f: log_content = log_content_them = f.read() diff --git a/tinygrad_repo/extra/archprobe.py b/tinygrad_repo/extra/archprobe.py index 73eb5037..7ba20b2a 100644 --- a/tinygrad_repo/extra/archprobe.py +++ 
b/tinygrad_repo/extra/archprobe.py @@ -1,7 +1,7 @@ # copying the kernels from https://github.com/microsoft/ArchProbe into Python import numpy as np import pickle -from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer +from tinygrad.runtime.ops_cl import CLProgram, CLBuffer from tinygrad import dtypes from tqdm import trange, tqdm from matplotlib import pyplot as plt diff --git a/tinygrad_repo/extra/assembly/assembly_rdna.py b/tinygrad_repo/extra/assembly/assembly_rdna.py index 0f5ab01e..297639d6 100644 --- a/tinygrad_repo/extra/assembly/assembly_rdna.py +++ b/tinygrad_repo/extra/assembly/assembly_rdna.py @@ -4,7 +4,7 @@ from tinygrad import dtypes from tinygrad.codegen.assembly import AssemblyCodegen, Register from tinygrad.codegen.opt.kernel import Ops from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps -from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH +from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH # ugh, is this really needed? from extra.helpers import enable_early_exec diff --git a/tinygrad_repo/extra/assembly/rocm/rdna3/asm.py b/tinygrad_repo/extra/assembly/rocm/rdna3/asm.py index 2f6ad132..9c65fa73 100644 --- a/tinygrad_repo/extra/assembly/rocm/rdna3/asm.py +++ b/tinygrad_repo/extra/assembly/rocm/rdna3/asm.py @@ -5,7 +5,7 @@ from tinygrad.helpers import colored from extra.helpers import enable_early_exec early_exec = enable_early_exec() -from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH +from tinygrad.runtime.ops_cl import CLProgram, CLBuffer, ROCM_LLVM_PATH ENABLE_NON_ASM = False diff --git a/tinygrad_repo/extra/backends/clang_graph.py b/tinygrad_repo/extra/backends/clang_graph.py index 9c0cf3b3..2e946d54 100644 --- a/tinygrad_repo/extra/backends/clang_graph.py +++ b/tinygrad_repo/extra/backends/clang_graph.py @@ -10,13 +10,13 @@ from tinygrad.renderer.cstyle import ClangRenderer render_dtype = ClangRenderer().render_dtype class ClangGraph(GraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): + def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache])) args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)] - args += sorted([f"int {v.expr}" for v in var_vals]) + args += sorted([f"int {v}" for v in var_vals]) code = ["void batched("+','.join(args)+") {"] for ji in jit_cache: args = [] @@ -34,6 +34,6 @@ class ClangGraph(GraphRunner): assert compiler is not None self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers - def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False): + def __call__(self, rawbufs: List[Buffer], var_vals: Dict[str, int], wait=False): return cpu_time_execution( - lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait) \ No newline at end of file + lambda: self._prg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0])]), enable=wait) diff --git a/tinygrad_repo/extra/backends/hsa_graph.py b/tinygrad_repo/extra/backends/hsa_graph.py index aced2c02..b8df5885 100644 --- a/tinygrad_repo/extra/backends/hsa_graph.py +++ b/tinygrad_repo/extra/backends/hsa_graph.py @@ -26,7 
+26,7 @@ class VirtAQLQueue(AQLQueue): self.available_packet_slots -= 1 class HSAGraph(MultiGraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): + def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[str, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) # Check all jit items are compatible. @@ -53,7 +53,7 @@ class HSAGraph(MultiGraphRunner): self.ji_kargs_structs[j] = ji.prg._prg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev]) kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg._prg.args_struct_t), 16) for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf) - for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]]) + for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i].expr]) # Build queues. self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices} @@ -106,7 +106,7 @@ class HSAGraph(MultiGraphRunner): for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0) hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0) - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[str, int], wait=False) -> Optional[float]: # Wait and restore signals hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1) @@ -123,7 +123,7 @@ class HSAGraph(MultiGraphRunner): # Update var_vals for j in self.jc_idx_with_updatable_var_vals: for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): - self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v]) + self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v.expr]) # Update launch dims for j in self.jc_idx_with_updatable_launch_dims: diff --git a/tinygrad_repo/extra/backends/rdna.py b/tinygrad_repo/extra/backends/rdna.py index 32e9b8d9..a5b775b7 100644 --- a/tinygrad_repo/extra/backends/rdna.py +++ b/tinygrad_repo/extra/backends/rdna.py @@ -29,10 +29,10 @@ def uops_to_rdna(function_name:str, uops:UOpGraph) -> str: r: Dict[UOp, str] = {} for u in uops: if u.uop == UOps.SPECIAL: - if u.arg[1].startswith("lidx"): - r[u] = f'v{u.arg[0]}' - elif u.arg[1].startswith("gidx"): - r[u] = f's{2+u.arg[0]}' + if u.arg.startswith("lidx"): + r[u] = f'v{u.src[0].arg}' + elif u.arg.startswith("gidx"): + r[u] = f's{2+u.src[0].arg}' else: raise NotImplementedError elif u.uop == UOps.CONST: diff --git a/tinygrad_repo/extra/export_model.py b/tinygrad_repo/extra/export_model.py index 2d3d3426..e29f8a8d 100644 --- a/tinygrad_repo/extra/export_model.py +++ b/tinygrad_repo/extra/export_model.py @@ -10,7 +10,7 @@ from tinygrad.uop.ops import Ops import json from collections import OrderedDict -EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"] +EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"] def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]: functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 @@ -67,11 +67,12 @@ def export_model_clang(functions:Dict[str,str], 
statements:Dict[str,Tuple[str,in forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs)) if not wasm: + thread_id = 0 # NOTE: export does not support threading, thread_id is always 0 for name,cl in bufs_to_save.items(): weight = ''.join(["\\x%02X"%x for x in bytes(to_mv(cl._buf.va_addr, cl._buf.size))]) cprog.append(f"unsigned char {name}_data[] = \"{weight}\";") cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names] - cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"] + cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)}, {thread_id});" for (name, args, _global_size, _local_size) in statements] + ["}"] return '\n'.join(headers + cprog) else: if bufs_to_save: @@ -239,7 +240,9 @@ export default {model_name}; def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False): assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported" - with Context(JIT=2): run,special_names = jit_model(model, *inputs) + + # NOTE: CPU_COUNT=1, since export does not support threading + with Context(JIT=2, CPU_COUNT=1): run,special_names = jit_model(model, *inputs) functions, statements, bufs, bufs_to_save = compile_net(run, special_names) state = get_state_dict(model) weight_names = {id(x.uop.base.realized): name for name, x in state.items()} diff --git a/tinygrad_repo/extra/gemm/amd_uop_matmul.py b/tinygrad_repo/extra/gemm/amd_uop_matmul.py index 9ad5351b..78dbf81a 100644 --- a/tinygrad_repo/extra/gemm/amd_uop_matmul.py +++ b/tinygrad_repo/extra/gemm/amd_uop_matmul.py @@ -65,7 +65,7 @@ def top_spec_kernel3(): c = a@b sink = c.schedule()[-1].ast L = 16 - sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)}) + sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)}) sink = graph_rewrite(sink, view_left+pm) axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE) return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types)) @@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2) - i = UOp.range(dtypes.int, c_regs.dtype.size, 16) + i = UOp.range(c_regs.dtype.size, 16) init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i) if kernel4: @@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): kId = 0 # load from globals into locals - i = UOp.range(dtypes.int, nbReadsB, 0) + i = UOp.range(nbReadsB, 0) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 1) + i = UOp.range(nbReadsA, 1) index_x = rAIdx + kId index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) # iterate over the middle chunk - kId_range = UOp.range(dtypes.int, N//BK-1, 2) + 
kId_range = UOp.range(N//BK-1, 2) kId = kId_range*BK barrier = UOp.barrier(As_store, Bs_store) # load from globals into registers (next round) - i = UOp.range(dtypes.int, nbReadsB, 3) + i = UOp.range(nbReadsB, 3) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId + BK regB_store = regB[i].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 4) + i = UOp.range(nbReadsA, 4) index_x = rAIdx + kId + BK index_y = BM * blockIdx_y + rAIdy + i * strideReadA regA_store = regA[i].store(a[N * index_y + index_x].load(), i) def inner_loop(first_range, inp_dep=()): # inner unroll - k = UOp.range(dtypes.int, BK, first_range+0) + k = UOp.range(BK, first_range+0) # load from locals into registers - iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1) - i = UOp.range(dtypes.int, TN, first_range+2) + iterWave = UOp.range(nbIterWaveN, first_range+1) + i = UOp.range(TN, first_range+2) index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i) - iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3) - i = UOp.range(dtypes.int, TM, first_range+4) + iterWave = UOp.range(nbIterWaveM, first_range+3) + i = UOp.range(TM, first_range+4) index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i) # do the GEMM math - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5) - yt = UOp.range(dtypes.int, TM, first_range+6) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7) - xt = UOp.range(dtypes.int, TN, first_range+8) + iterWaveM = UOp.range(nbIterWaveM, first_range+5) + yt = UOp.range(TM, first_range+6) + iterWaveN = UOp.range(nbIterWaveN, first_range+7) + xt = UOp.range(TN, first_range+8) x = iterWaveN * TN + xt y = iterWaveM * TM + yt c_regs_idx = c_regs[y * TN * nbIterWaveN + x] @@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier() # load from registers into locals - i = UOp.range(dtypes.int, nbReadsB, 14) + i = UOp.range(nbReadsB, 14) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId + BK Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range) - i = UOp.range(dtypes.int, nbReadsA, 15) + i = UOp.range(nbReadsA, 15) index_x = rAIdx + kId + BK index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range) @@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): # final iteration without the copy sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),)) else: - kId_range = UOp.range(dtypes.int, N//BK, 0) + kId_range = UOp.range(N//BK, 0) kId = kId_range*BK # load from globals into locals - i = UOp.range(dtypes.int, nbReadsB, 1) + i = UOp.range(nbReadsB, 1) index_x = BN * blockIdx_x + rBIdx index_y = rBIdy + i * strideReadB + kId Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) - i = UOp.range(dtypes.int, nbReadsA, 2) + i = UOp.range(nbReadsA, 2) index_x = rAIdx + kId index_y = BM * blockIdx_y + rAIdy + i * strideReadA As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) barrier = UOp.barrier(As_store, Bs_store) - k = UOp.range(dtypes.int, BK, 3) + k = UOp.range(BK, 3) # load 
from locals into registers - iterWave = UOp.range(dtypes.int, nbIterWaveN, 4) - i = UOp.range(dtypes.int, TN, 5) + iterWave = UOp.range(nbIterWaveN, 4) + i = UOp.range(TN, 5) index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i) - iterWave = UOp.range(dtypes.int, nbIterWaveM, 6) - i = UOp.range(dtypes.int, TM, 7) + iterWave = UOp.range(nbIterWaveM, 6) + i = UOp.range(TM, 7) index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i) # do the GEMM math - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8) - yt = UOp.range(dtypes.int, TM, 9) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10) - xt = UOp.range(dtypes.int, TN, 12) + iterWaveM = UOp.range(nbIterWaveM, 8) + yt = UOp.range(TM, 9) + iterWaveN = UOp.range(nbIterWaveN, 10) + xt = UOp.range(TN, 12) x = iterWaveN * TN + xt y = iterWaveM * TM + yt c_regs_idx = c_regs[y * TN * nbIterWaveN + x] @@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): iterWaveM, iterWaveN, yt, xt, k, kId_range) # store c_regs into c - iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000) - yt = UOp.range(dtypes.int, TM, 1001) - iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002) - xt = UOp.range(dtypes.int, TN, 1003) + iterWaveM = UOp.range(nbIterWaveM, 1000) + yt = UOp.range(TM, 1001) + iterWaveN = UOp.range(nbIterWaveN, 1002) + xt = UOp.range(TN, 1003) xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave indexC = N * (yOut + yt) + xOut + xt diff --git a/tinygrad_repo/extra/gemm/intel_xmx.py b/tinygrad_repo/extra/gemm/intel_xmx.py index 8ec478e5..71983047 100644 --- a/tinygrad_repo/extra/gemm/intel_xmx.py +++ b/tinygrad_repo/extra/gemm/intel_xmx.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import numpy as np -from tinygrad.runtime.ops_gpu import CLProgram, CLCompiler +from tinygrad.runtime.ops_cl import CLProgram, CLCompiler from tinygrad import Device, dtypes from tinygrad.device import Buffer from hexdump import hexdump @@ -11,7 +11,7 @@ from hexdump import hexdump # https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_split_matrix_multiply_accumulate.html # https://hc34.hotchips.org/assets/program/conference/day1/GPU%20HPC/Intel_s%20Ponte%20Vecchio%20GPU%20-%20Architecture%20Systems%20and%20Software%20FINAL.pdf -device = Device["GPU"] +device = Device["CL"] # NOTE: only the subgroup type 8 ones work prog = CLProgram(device, "test", CLCompiler(device, "test").compile(f""" @@ -26,9 +26,9 @@ __kernel void test(__global float* data0, const __global int* data1, const __glo """)) #with open("/tmp/test.elf", "wb") as f: f.write(prog.lib) -a = Buffer("GPU", 8, dtypes.float32).allocate() -b = Buffer("GPU", 0x10, dtypes.float16).allocate() -c = Buffer("GPU", 8*0x10, dtypes.float16).allocate() +a = Buffer("CL", 8, dtypes.float32).allocate() +b = Buffer("CL", 0x10, dtypes.float16).allocate() +c = Buffer("CL", 8*0x10, dtypes.float16).allocate() row = np.array([1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8], np.float16) mat = np.random.random((8, 0x10)).astype(np.float16) diff --git a/tinygrad_repo/extra/gemm/max_matmul.py b/tinygrad_repo/extra/gemm/max_matmul.py index c1c376ea..50414978 100644 --- a/tinygrad_repo/extra/gemm/max_matmul.py +++ b/tinygrad_repo/extra/gemm/max_matmul.py @@ -56,7 +56,7 @@ def randoms(): def 
ast_to_cuda_prog(compiler, ast, opts): k = Kernel(ast) k.apply_opts(opts) - p = get_program(k.get_optimized_ast(), k.opts) + p = get_program(k.ast, k.opts, k.applied_opts) return CUDAProgram(device, p.function_name, compiler.compile(p.src)) if __name__ == "__main__": @@ -75,7 +75,7 @@ if __name__ == "__main__": if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float: print("Using CUDA and triton-generated kernel") - # See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py` + # See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py` # this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theo max is 165TFLOPS. # WMMA element size is (M, N, K) = (16, 8, 16) diff --git a/tinygrad_repo/extra/gemm/simple_matmul.py b/tinygrad_repo/extra/gemm/simple_matmul.py index 7b9c0727..0c91005a 100644 --- a/tinygrad_repo/extra/gemm/simple_matmul.py +++ b/tinygrad_repo/extra/gemm/simple_matmul.py @@ -2,7 +2,7 @@ import numpy as np from tinygrad import dtypes, Tensor from tinygrad.helpers import getenv, get_single_element from tinygrad.dtype import _to_np_dtype -from tinygrad.codegen.opt.kernel import OptOps +from tinygrad.codegen.opt import OptOps from tinygrad.engine.realize import lower_schedule dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float diff --git a/tinygrad_repo/extra/gemm/tinygrad_nv_matmul.py b/tinygrad_repo/extra/gemm/tinygrad_nv_matmul.py index 8fff78ea..1ee3e72e 100644 --- a/tinygrad_repo/extra/gemm/tinygrad_nv_matmul.py +++ b/tinygrad_repo/extra/gemm/tinygrad_nv_matmul.py @@ -29,7 +29,7 @@ if __name__ == "__main__": Opt(op=OptOps.LOCAL, axis=0, amt=2), ] k.apply_opts(opts) - prg = get_program(k.get_optimized_ast(), k.opts) + prg = get_program(k.ast, k.opts, k.applied_opts) new_src = prg.src # can mod source here prg = replace(prg, src=new_src) diff --git a/tinygrad_repo/extra/gemm/triton_nv_matmul.py b/tinygrad_repo/extra/gemm/triton_nv_matmul.py index 5f04a340..89e7838b 100644 --- a/tinygrad_repo/extra/gemm/triton_nv_matmul.py +++ b/tinygrad_repo/extra/gemm/triton_nv_matmul.py @@ -43,7 +43,7 @@ def matmul_kernel(c_ptr, a_ptr, b_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] tl.store(c_ptrs, c) -# CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py +# CUDA=1 CUDA_PTX=1 python3 extra/gemm/triton_nv_matmul.py if __name__ == "__main__": BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K = 64, 128, 64 M, N, K = 4096, 4096, 4096 diff --git a/tinygrad_repo/extra/mcts_search.py b/tinygrad_repo/extra/mcts_search.py index 731fab11..825d97c3 100644 --- a/tinygrad_repo/extra/mcts_search.py +++ b/tinygrad_repo/extra/mcts_search.py @@ -88,7 +88,7 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel: return ret rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()} + var_vals = {k.expr:(k.vmax+k.vmin)//2 for k in lin.ast.variables()} dev = Device[lin.opts.device] root = MCTSNode(lin) diff --git a/tinygrad_repo/extra/models/inception.py b/tinygrad_repo/extra/models/inception.py index cf77a63e..15d0b58b 100644 --- a/tinygrad_repo/extra/models/inception.py +++ 
b/tinygrad_repo/extra/models/inception.py @@ -270,8 +270,10 @@ class FidInceptionV3: self.Mixed_7b = inception.Mixed_7b self.Mixed_7c = inception.Mixed_7c - def load_from_pretrained(self): - state_dict = torch_load(str(fetch("https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth", "pt_inception-2015-12-05-6726825d.pth"))) + def load_from_pretrained(self, path=None): + if path is None: + path = fetch("https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth", "pt_inception-2015-12-05-6726825d.pth") + state_dict = torch_load(str(path)) for k,v in state_dict.items(): if k.endswith(".num_batches_tracked"): state_dict[k] = v.reshape(1) diff --git a/tinygrad_repo/extra/models/llama.py b/tinygrad_repo/extra/models/llama.py index 6d2d584d..d8de35af 100644 --- a/tinygrad_repo/extra/models/llama.py +++ b/tinygrad_repo/extra/models/llama.py @@ -249,8 +249,5 @@ def convert_from_gguf(weights:dict[str, Tensor], n_layers:int): return sd def fix_bf16(weights:dict[Any, Tensor]): - if getenv("SUPPORT_BF16", 1): - # TODO: without casting to float16, 70B llama OOM on tinybox. - return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()} - # TODO: check if device supports bf16 - return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()} + # TODO: without casting to float16, 70B llama OOM on tinybox. + return {k:v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()} diff --git a/tinygrad_repo/extra/nv_gpu_driver/nv_ioctl.py b/tinygrad_repo/extra/nv_gpu_driver/nv_ioctl.py index 30a6c183..44a2a11f 100644 --- a/tinygrad_repo/extra/nv_gpu_driver/nv_ioctl.py +++ b/tinygrad_repo/extra/nv_gpu_driver/nv_ioctl.py @@ -272,4 +272,4 @@ def compare_launch_state(states, good_states): return True, "PASS" -# IOCTL=1 PTX=1 CUDA=1 python3 test/test_ops.py TestOps.test_tiny_add \ No newline at end of file +# IOCTL=1 CUDA=1 CUDA_PTX=1 python3 test/test_ops.py TestOps.test_tiny_add \ No newline at end of file diff --git a/tinygrad_repo/extra/optimization/generate_dataset.sh b/tinygrad_repo/extra/optimization/generate_dataset.sh index 6f709169..b843dac7 100755 --- a/tinygrad_repo/extra/optimization/generate_dataset.sh +++ b/tinygrad_repo/extra/optimization/generate_dataset.sh @@ -7,7 +7,7 @@ rm $LOGOPS test/external/process_replay/reset.py CI=1 python3 -m pytest -n=auto test/test_ops.py test/test_nn.py test/test_winograd.py test/models/test_real_world.py --durations=20 -GPU=1 python3 -m pytest test/test_tiny.py +CL=1 python3 -m pytest test/test_tiny.py # extract, sort and uniq extra/optimization/extract_dataset.py diff --git a/tinygrad_repo/extra/optimization/helpers.py b/tinygrad_repo/extra/optimization/helpers.py index 94eacaa8..d8ab1279 100644 --- a/tinygrad_repo/extra/optimization/helpers.py +++ b/tinygrad_repo/extra/optimization/helpers.py @@ -1,6 +1,6 @@ # stuff needed to unpack a kernel from tinygrad import Variable -from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.dtype import dtypes, PtrDType from tinygrad.shape.shapetracker import ShapeTracker @@ -115,7 +115,7 @@ def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_ assert dev.compiler is not None rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals: dict[Variable, 
int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} + var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} p = get_program(lin.get_optimized_ast(), lin.opts) tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) diff --git a/tinygrad_repo/extra/optimization/test_beam_search.py b/tinygrad_repo/extra/optimization/test_beam_search.py index 24c3f943..f493ec48 100644 --- a/tinygrad_repo/extra/optimization/test_beam_search.py +++ b/tinygrad_repo/extra/optimization/test_beam_search.py @@ -16,9 +16,9 @@ class TestBeamSearch(unittest.TestCase): BEAM.value = self.old_beam def test_variable_ast_beam(self): - with Context(IGNORE_OOB=1): - a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) - a = (a+1).realize() + vi = Variable("a", 1, 10).bind(3) + a = rand(10, 3)[:vi] + a = (a+1).realize() def test_big_prime_number(self): a = rand(367, 367) @@ -42,18 +42,16 @@ class TestBeamSearch(unittest.TestCase): def test_variable_big_prime_number(self): v = Variable("v", 1, 400).bind(367) - a = rand(367, 367) - b = rand(367, 367) - with Context(IGNORE_OOB=1): - c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() - np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) + a = rand(367, 400) + b = rand(400, 367) + c = (a[:, :v] @ b[:v, :]).realize() + np.testing.assert_allclose(c.numpy(), a[:, :367].numpy() @ b[:367, :].numpy(), atol=1e-4, rtol=1e-4) def test_variable_shrink_prime_number(self): v = Variable("v", 1, 400).bind(367) a = rand(400, 367) - with Context(IGNORE_OOB=1): - b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() - np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) + b = (a.shrink(((0,v), None))+1)[:367,:367].realize() + np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) def test_no_mutate_rawbuffers(self): a = rand(3, 3).realize() diff --git a/tinygrad_repo/extra/qcom_gpu_driver/qcom_opencl_interop.py b/tinygrad_repo/extra/qcom_gpu_driver/qcom_opencl_interop.py index d595ba34..c2e0741c 100644 --- a/tinygrad_repo/extra/qcom_gpu_driver/qcom_opencl_interop.py +++ b/tinygrad_repo/extra/qcom_gpu_driver/qcom_opencl_interop.py @@ -1,6 +1,6 @@ import ctypes, array from hexdump import hexdump -from tinygrad.runtime.ops_gpu import GPUDevice +from tinygrad.runtime.ops_cl import CLDevice from tinygrad.helpers import getenv, to_mv, mv_address from tinygrad.dtype import dtypes from tinygrad import Tensor, TinyJit @@ -8,7 +8,7 @@ from tinygrad.runtime.autogen import opencl as cl if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import # create raw opencl buffer. 
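The `gdev` lines that follow are part of the same OpenCL backend rename seen in archprobe.py, intel_xmx.py, and generate_dataset.sh above: `tinygrad.runtime.ops_gpu` becomes `tinygrad.runtime.ops_cl`, `GPUDevice` becomes `CLDevice`, and the device string "GPU" becomes "CL" (so `GPU=1` invocations become `CL=1`). A sketch of the device-string side, mirroring the Buffer usage in the intel_xmx hunk and assuming an OpenCL runtime is installed:

```python
from tinygrad import Device, dtypes
from tinygrad.device import Buffer

dev = Device["CL"]                                # was Device["GPU"]
buf = Buffer("CL", 8, dtypes.float32).allocate()  # was Buffer("GPU", ...)
print(dev, buf.nbytes)
```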
-gdev = GPUDevice() +gdev = CLDevice() cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32()) assert status.value == 0 diff --git a/tinygrad_repo/extra/remu/src/thread.rs b/tinygrad_repo/extra/remu/src/thread.rs index 4cb84a56..4662d3fa 100644 --- a/tinygrad_repo/extra/remu/src/thread.rs +++ b/tinygrad_repo/extra/remu/src/thread.rs @@ -673,6 +673,7 @@ impl<'a> Thread<'a> { 39 => f32::log2(s0), 42 => 1.0 / s0, 43 => 1.0 / s0, + 46 => 1.0 / f32::sqrt(s0), 51 => f32::sqrt(s0), _ => todo_instr!(instruction)?, } @@ -1246,7 +1247,7 @@ impl<'a> Thread<'a> { } let ret = match op { - 257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 426 | 531 | 537 | 540 | 551 | 567 | 796 => { + 257 | 259 | 299 | 260 | 261 | 264 | 272 | 392 | 426 | 430 | 531 | 537 | 540 | 551 | 567 | 796 => { let s0 = f32::from_bits(s0).negate(0, neg).absolute(0, abs); let s1 = f32::from_bits(s1).negate(1, neg).absolute(1, abs); let s2 = f32::from_bits(s2).negate(2, neg).absolute(2, abs); @@ -1258,6 +1259,7 @@ impl<'a> Thread<'a> { 272 => f32::max(s0, s1), 299 => f32::mul_add(s0, s1, f32::from_bits(self.vec_reg[vdst])), 426 => s0.recip(), + 430 => 1.0 / f32::sqrt(s0), 531 => f32::mul_add(s0, s1, s2), 537 => f32::min(f32::min(s0, s1), s2), 540 => f32::max(f32::max(s0, s1), s2), @@ -2625,6 +2627,14 @@ mod test_vop1 { assert_eq!(thread.vec_reg[3], 1071644672); } + #[test] + fn test_v_rsq_f32() { + let mut thread = _helper_test_thread(); + thread.vec_reg[0] = f32::to_bits(4.0); + r(&vec![0x7E005D00, END_PRG], &mut thread); + assert_eq!(f32::from_bits(thread.vec_reg[0]), 0.5); + } + #[test] fn test_v_frexp_exp_i32_f64() { [(3573412790272.0, 42), (69.0, 7), (2.0, 2), (f64::NEG_INFINITY, 0)] diff --git a/tinygrad_repo/extra/replay_pkl.py b/tinygrad_repo/extra/replay_pkl.py index a1456e71..e4cb5ed5 100644 --- a/tinygrad_repo/extra/replay_pkl.py +++ b/tinygrad_repo/extra/replay_pkl.py @@ -58,7 +58,7 @@ if __name__ == "__main__": GlobalCounters.kernel_count -= 1 if not getenv("NOOPT"): k.apply_opts(hand_coded_optimizations(k)) - p2 = get_program(k.get_optimized_ast(), k.opts) + p2 = get_program(k.ast, k.opts, k.applied_opts) new_ei = replace(ei, prg=CompiledRunner(p2)) new_ei.run() new_jit.append(new_ei) diff --git a/tinygrad_repo/extra/sqtt/README.md b/tinygrad_repo/extra/sqtt/README.md index 6d739ceb..1d19ae8f 100644 --- a/tinygrad_repo/extra/sqtt/README.md +++ b/tinygrad_repo/extra/sqtt/README.md @@ -4,7 +4,7 @@ Only supported on 7900XTX, requires either AM (`rmmod amdgpu`) or disabling power gating on AMD (`ppfeaturemask=0xffff3fff`, don't forget to rebuild initramfs) -SQTT is implemented on top of normal tinygrad PROFILE=1, `PROFILE=1 SQTT=1` to get profile pickle with sqtt data embedded in it. +SQTT is implemented on top of normal tinygrad profiling, `VIZ=1 SQTT=1` to get profile pickle with sqtt data embedded in it. `SQTT_BUFFER_SIZE=X` to change size of SQTT buffer (per shader engine, 6 SEs on 7900xtx) in megabytes, default 256. 
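The `var_vals` hunks scattered through this diff (clang_graph.py, hsa_graph.py, mcts_search.py, optimization/helpers.py above, to_movement_ops.py below) are one coordinated re-keying: `var_vals` is now `Dict[str, int]`, keyed by each variable's `expr` string instead of by the `Variable` object itself, with the graph runners reading `v.expr` at call time. A small before/after sketch using only the public `Variable` API:

```python
from tinygrad import Variable

# bind a symbolic dimension to a concrete value, then split it back apart
var, val = Variable("start_pos", 1, 128).bind(32).unbind()

old_style = {var: val}       # Dict[Variable, int]: keys were UOp objects
new_style = {var.expr: val}  # Dict[str, int]: keys are plain strings now
assert new_style["start_pos"] == 32
```

String keys make the dicts trivially sortable and serializable, which is what the sorted-argument handling in clang_graph.py above relies on.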
diff --git a/tinygrad_repo/extra/test_hcopt.py b/tinygrad_repo/extra/test_hcopt.py new file mode 100644 index 00000000..36978bf8 --- /dev/null +++ b/tinygrad_repo/extra/test_hcopt.py @@ -0,0 +1,40 @@ +import time +from extra.optimization.helpers import load_worlds, ast_str_to_ast +from tinygrad import Device +from tinygrad.codegen.lowerer import pm_lowerer, get_index +from tinygrad.uop.ops import graph_rewrite +from tinygrad.codegen.opt.kernel import Kernel +from tinygrad.codegen.opt.postrange import Scheduler +from tinygrad.codegen.opt.heuristic import hand_coded_optimizations +from tinygrad.helpers import getenv + +if __name__ == "__main__": + renderer = Device.default.renderer + ast_strs = load_worlds() + if (n:=getenv("N", -1)) != -1: ast_strs = ast_strs[n:n+1] + good = 0 + for i, ast_str in enumerate(ast_strs): + ast = ast_str_to_ast(ast_str) + + st = time.perf_counter() + lin = Kernel(ast, renderer) + opt1 = hand_coded_optimizations(lin) + et_lin = time.perf_counter() - st + + lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast), bottom_up=True) + st = time.perf_counter() + sch = Scheduler(lowered, renderer) + sch.convert_loop_to_global() + sch.simplify_merge_adjacent() + opt2 = hand_coded_optimizations(sch) + et_sch = time.perf_counter() - st + + if opt1 != opt2: + print(f"******* {i:6d}") + print("Kernel: ", lin.colored_shape(), "->", lin.apply_opts(opt1).colored_shape()) + print("Scheduler: ", sch.colored_shape(), "->", sch.apply_opts(opt2).colored_shape()) + print(opt1) + print(opt2) + else: + good += 1 + print(f"******* {i:6d} MATCH {good/(i+1)*100:.2f}% -- {et_lin/et_sch:4.2f}x speedup") diff --git a/tinygrad_repo/extra/test_pyrender.py b/tinygrad_repo/extra/test_pyrender.py new file mode 100644 index 00000000..8954dde4 --- /dev/null +++ b/tinygrad_repo/extra/test_pyrender.py @@ -0,0 +1,20 @@ +from extra.optimization.helpers import load_worlds, ast_str_to_ast +from tinygrad.helpers import tqdm +from tinygrad.uop.ops import pyrender, UOp, Ops +from tinygrad import dtypes +from tinygrad.shape.shapetracker import ShapeTracker, View +inf, nan = float('inf'), float('nan') + +if __name__ == "__main__": + ast_strs = load_worlds() + for i, ast_str in enumerate(tqdm(ast_strs)): + good_ast = ast_str_to_ast(ast_str) + code = '\n'.join(pyrender(good_ast)) + print("\n***************\n\n"+code) + exec(code) + if str(good_ast) != str(ast): + print(code) + print("MISMATCH") + print(good_ast) + print(ast) + break \ No newline at end of file diff --git a/tinygrad_repo/extra/thneed.py b/tinygrad_repo/extra/thneed.py index c59f6368..ca89bfa6 100644 --- a/tinygrad_repo/extra/thneed.py +++ b/tinygrad_repo/extra/thneed.py @@ -4,13 +4,13 @@ import struct import json import traceback import numpy as np -from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu +from tinygrad.runtime.ops_cl import CLProgram, compile_gpu from tinygrad.device import Device from tinygrad.helpers import DEBUG, getenv from collections import defaultdict import pyopencl as cl -from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO -CL = Device["GPU"] +from tinygrad.runtime.ops_cl import OSX_TIMING_RATIO +CL = Device["CL"] DEBUGCL = getenv("DEBUGCL", 0) FLOAT16 = getenv("FLOAT16", 0) @@ -110,7 +110,7 @@ class Thneed: prgs = {} for o in jdat['binaries']: nptr = ptr + o['length'] - prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr]) + prgs[o['name']] = CLProgram(Device["CL"], o['name'], weights[ptr:nptr]) ptr = nptr # populate the cl_cache @@ -267,7 +267,7 @@ class Thneed: for prg, args in 
self.cl_cache: events.append(prg.clprg(CL.queue, *args)) mt = time.monotonic() - Device["GPU"].synchronize() + Device["CL"].synchronize() et = time.monotonic() - st print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms") diff --git a/tinygrad_repo/extra/to_movement_ops.py b/tinygrad_repo/extra/to_movement_ops.py index 68d4ef3d..3170cd8c 100644 --- a/tinygrad_repo/extra/to_movement_ops.py +++ b/tinygrad_repo/extra/to_movement_ops.py @@ -2,7 +2,6 @@ import itertools from enum import Enum, auto from collections import defaultdict from typing import List, Tuple, DefaultDict -from extra.optimization.helpers import load_worlds, ast_str_to_ast from tinygrad.helpers import prod, tqdm from tinygrad.uop.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker @@ -36,7 +35,7 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: to_apply:List[Tuple[MovementOps, Tuple]] = [] for i, v in enumerate(st.views): real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape - offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) + offset = (v.offset or 0) + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) real_real_shape = [s for s,st in zip(real_shape, v.strides) if st] strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] @@ -121,7 +120,7 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker): if i > 1000: print("WARNING: did not search all possible combinations") break - var_vals = {k:v for k,v in zip(vs, ranges)} + var_vals = {k.expr:v for k,v in zip(vs, ranges)} r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0 r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0 if r1 != r2: return False @@ -147,6 +146,7 @@ def test_rebuild_bufferop_st(ast:UOp): for src in ast.src: test_rebuild_bufferop_st(src) if __name__ == "__main__": + from extra.optimization.helpers import load_worlds, ast_str_to_ast ast_strs = load_worlds(False, False, True)[:2000] for ast_str in tqdm(ast_strs): test_rebuild_bufferop_st(ast_str_to_ast(ast_str)) diff --git a/tinygrad_repo/extra/torch_backend/backend.py b/tinygrad_repo/extra/torch_backend/backend.py index c312badf..d5993f64 100644 --- a/tinygrad_repo/extra/torch_backend/backend.py +++ b/tinygrad_repo/extra/torch_backend/backend.py @@ -177,22 +177,28 @@ def cached_to_movement_ops(shape, st) -> list: from tinygrad.shape.shapetracker import ShapeTracker, View from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps + +@wrap_view_op +def _as_strided(tensor:Tensor, size, stride, storage_offset=None): + # multiple as_strided do not compound + base = canonical_base(tensor) + # TODO: this is heavyweight + st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),)) + ret = base + if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st) + if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size) + for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo) + return ret + @torch.library.impl("aten::as_strided", "privateuseone") def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None): storage_offset = storage_offset or tensor.storage_offset() - @wrap_view_op - def _as_strided(tensor:Tensor, size, stride, storage_offset=None): - # multiple as_strided do not compound - base = canonical_base(tensor) - # TODO: 
this is heavyweight - st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),)) - ret = base - if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st) - if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size) - for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo) - return ret return _as_strided(tensor, size, stride, storage_offset) +@torch.library.impl("aten::_reshape_alias", "privateuseone") +def _reshape_alias(tensor:torch.Tensor, size, stride): + return _as_strided(tensor, size, stride) + @torch.library.impl("aten::empty_strided", "privateuseone") def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False): if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}") diff --git a/tinygrad_repo/pytest.ini b/tinygrad_repo/pytest.ini index cccc62e4..b9c3f606 100644 --- a/tinygrad_repo/pytest.ini +++ b/tinygrad_repo/pytest.ini @@ -1,2 +1,6 @@ [pytest] norecursedirs = extra +timeout = 180 +timeout_method = thread +timeout_func_only = true +testpaths = test diff --git a/tinygrad_repo/setup.py b/tinygrad_repo/setup.py index 2b510424..f90a52b5 100644 --- a/tinygrad_repo/setup.py +++ b/tinygrad_repo/setup.py @@ -9,12 +9,12 @@ with open(directory / 'README.md', encoding='utf-8') as f: testing_minimal = [ "numpy", - "torch==2.7.1", + "torch==2.8.0", "pytest", "pytest-xdist", + "pytest-timeout", "hypothesis", "z3-solver", - "ml_dtypes" ] setup(name='tinygrad', @@ -59,11 +59,12 @@ setup(name='tinygrad', 'triton': ["triton-nightly>=2.1.0.dev20231014192330"], 'linting': [ "pylint", - "mypy==1.13.0", + "mypy==1.18.1", "typing-extensions", "pre-commit", "ruff", "numpy", + "typeguard", ], #'mlperf': ["mlperf-logging @ git+https://github.com/mlperf/logging.git@5.0.0-rc3"], 'testing_minimal': testing_minimal, @@ -86,6 +87,7 @@ setup(name='tinygrad', "tiktoken", "blobfile", "librosa", + "numba>=0.55", # librosa needs numba but uv ignores python upper bounds and some numba versions require tuple[str, Any]: # (error msg, run state) if rawbufs is None: rawbufs = bufs_from_lin(lin) - if var_vals is None: var_vals = {v: v.min for v in lin.vars} + if var_vals is None: var_vals = {v.expr: v.min for v in lin.vars} # TODO: images needs required_optimization try: @@ -129,7 +129,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No if var_vals is None: # TODO: handle symbolic max case - var_vals = {v: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()} + var_vals = {v.expr: random.randint(v.vmin, v.vmax) for v in lin.ast.variables()} if ground_truth is None and not has_bf16: unoptimized = Kernel(lin.ast) @@ -302,7 +302,7 @@ if __name__ == "__main__": for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]): if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue if getenv("FUZZ_IMAGEONLY") and "dtypes.image" not in ast: continue - if "dtypes.image" in ast and Device.DEFAULT not in {"GPU", "QCOM"}: continue # IMAGE is only for GPU + if "dtypes.image" in ast and Device.DEFAULT not in {"CL", "QCOM"}: continue # IMAGE is only for CL if ast in seen_ast_strs: continue seen_ast_strs.add(ast) diff --git a/tinygrad_repo/test/external/fuzz_symbolic.py b/tinygrad_repo/test/external/fuzz_symbolic.py index 41970a93..0e87883d 100644 --- a/tinygrad_repo/test/external/fuzz_symbolic.py +++ b/tinygrad_repo/test/external/fuzz_symbolic.py @@ -1,8 +1,8 @@ import random, operator import z3 from tinygrad import 
Variable, dtypes -from tinygrad.uop.ops import UOp, graph_rewrite -from tinygrad.uop.spec import z3_renderer +from tinygrad.uop.ops import UOp +from tinygrad.uop.spec import uops_to_z3 from tinygrad.helpers import DEBUG, Context seed = random.randint(0, 100) @@ -57,8 +57,7 @@ if __name__ == "__main__": solver = z3.Solver() solver.set(timeout=5000) # some expressions take very long verify, but its very unlikely they actually return sat - z3_sink = graph_rewrite(expr.sink(simplified_expr, u1, u2, u3), z3_renderer, ctx=(solver, {})) - z3_expr, z3_simplified_expr = z3_sink.src[0].arg, z3_sink.src[1].arg + z3_expr, z3_simplified_expr, v1, v2, v3 = uops_to_z3(solver, expr, simplified_expr, u1, u2, u3) check = solver.check(z3_simplified_expr != z3_expr) if check == z3.unknown and DEBUG>=1: skipped += 1 @@ -69,7 +68,6 @@ if __name__ == "__main__": f"expr = {expr.render(simplify=False)}\n") elif check == z3.sat: m = solver.model() - v1, v2, v3 = z3_sink.src[2].arg, z3_sink.src[3].arg, z3_sink.src[4].arg n1, n2, n3 = m[v1], m[v2], m[v3] u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long()) with Context(CORRECT_DIVMOD_FOLDING=1): diff --git a/tinygrad_repo/test/external/process_replay/process_replay.py b/tinygrad_repo/test/external/process_replay/process_replay.py index 0a82ac31..5baf9f20 100755 --- a/tinygrad_repo/test/external/process_replay/process_replay.py +++ b/tinygrad_repo/test/external/process_replay/process_replay.py @@ -12,7 +12,7 @@ try: from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.engine.realize import get_program from tinygrad.uop.ops import UOp, Ops, KernelInfo - from tinygrad.codegen.opt.kernel import Opt + from tinygrad.codegen.opt import Opt from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.device import Device except ImportError as e: diff --git a/tinygrad_repo/test/external/speed_beam_v_hcopt.py b/tinygrad_repo/test/external/speed_beam_v_hcopt.py deleted file mode 100644 index b241eb1b..00000000 --- a/tinygrad_repo/test/external/speed_beam_v_hcopt.py +++ /dev/null @@ -1,41 +0,0 @@ -from tinygrad import Device -from tinygrad.helpers import getenv, DEBUG, BEAM -from tinygrad.codegen.opt.search import beam_search, bufs_from_lin -from tinygrad.codegen.opt.heuristic import hand_coded_optimizations -from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer - -if __name__ == "__main__": - filter_reduce = bool(getenv("FILTER_REDUCE")) - ast_strs = load_worlds(filter_reduce=filter_reduce, filter_novariable=True) - dev = Device[Device.DEFAULT] - - test_n = getenv("TEST_N", 10) - single = getenv("NUM", -1) - if single != -1: ast_strs = ast_strs[single:single+1] - - beam_won, tested = 0, 0 - - for num, ast in enumerate(ast_strs[:test_n]): - def new_lin(): return ast_str_to_lin(ast, opts=dev.renderer) - - k = new_lin() - - if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.apply_opts(hand_coded_optimizations(k)) - - assert BEAM > 0 - - lins = [(("tc" if used_tensor_cores else "hc"), k)] - if used_tensor_cores: - lins.append(("hc", new_lin())) - lins[-1][1].apply_opts(hand_coded_optimizations(lins[-1][1])) - kb = new_lin() - test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization - lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))))) - timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, 
clear_l2=True)) for nm, tk in lins], key=lambda x: x[2]) - if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed)) - - tested += 1 - if timed[0][0].startswith("beam"): - beam_won += 1 - - print(f"{beam_won=} / {tested=} = {beam_won/tested:.3f}") \ No newline at end of file diff --git a/tinygrad_repo/test/helpers.py b/tinygrad_repo/test/helpers.py index 4833f425..cee64595 100644 --- a/tinygrad_repo/test/helpers.py +++ b/tinygrad_repo/test/helpers.py @@ -57,8 +57,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None): return out_buf.cast(uop.dtype.fmt).tolist()[0] def not_support_multi_device(): - # GPU and CUDA don't support multi device if in CI - return CI and REAL_DEV in ("GPU", "CUDA") + # CL and CUDA don't support multi device if in CI + return CI and REAL_DEV in ("CL", "CUDA") # NOTE: This will open REMOTE if it's the default device REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device) diff --git a/tinygrad_repo/test/mockgpu/amd/amdgpu.py b/tinygrad_repo/test/mockgpu/amd/amdgpu.py index e1f2245a..152ee2e9 100644 --- a/tinygrad_repo/test/mockgpu/amd/amdgpu.py +++ b/tinygrad_repo/test/mockgpu/amd/amdgpu.py @@ -1,5 +1,6 @@ import ctypes, time from test.mockgpu.gpu import VirtGPU +from test.mockgpu.helpers import _try_dlopen_remu from tinygrad.helpers import getbits, to_mv, init_c_struct_t import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4 @@ -24,19 +25,6 @@ WAIT_REG_MEM_FUNCTION_EQ = 3 # == WAIT_REG_MEM_FUNCTION_NEQ = 4 # != WAIT_REG_MEM_FUNCTION_GEQ = 5 # >= -REMU_PATHS = ["extra/remu/target/release/libremu.so", "libremu.so", "/usr/local/lib/libremu.so", - "extra/remu/target/release/libremu.dylib", "libremu.dylib", "/usr/local/lib/libremu.dylib", "/opt/homebrew/lib/libremu.dylib"] -def _try_dlopen_remu(): - for path in REMU_PATHS: - try: - remu = ctypes.CDLL(path) - remu.run_asm.restype = ctypes.c_int32 - remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, - ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p] - except OSError: pass - else: return remu - print("Could not find libremu.so") - return None remu = _try_dlopen_remu() def create_sdma_packets(): diff --git a/tinygrad_repo/test/mockgpu/cuda/cuda.py b/tinygrad_repo/test/mockgpu/cuda/cuda.py index a5dd5fe0..daf8db6f 100644 --- a/tinygrad_repo/test/mockgpu/cuda/cuda.py +++ b/tinygrad_repo/test/mockgpu/cuda/cuda.py @@ -2,16 +2,14 @@ from __future__ import annotations from typing import Any import ctypes, time from tinygrad.runtime.autogen import cuda as orig_cuda +from test.mockgpu.helpers import _try_dlopen_gpuocelot from tinygrad.helpers import mv_address for attr in dir(orig_cuda): if not attr.startswith('__'): globals()[attr] = getattr(orig_cuda, attr) -try: - gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot")) - gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501 -except Exception: pass +gpuocelot_lib = _try_dlopen_gpuocelot() # Global state class CUDAState: @@ -130,7 +128,10 @@ def cuModuleUnload(hmod) -> int: def cuLaunchKernel(f, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, sharedMemBytes: int, hStream: Any, kernelParams: Any, extra: Any) -> int: cargs = [ctypes.cast(getattr(extra, field[0]), 
ctypes.c_void_p) for field in extra._fields_] - gpuocelot_lib.ptx_run(ctypes.cast(f.value, ctypes.c_char_p), len(cargs), (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0) + try: gpuocelot_lib.ptx_run(ctypes.cast(f.value, ctypes.c_char_p), len(cargs), (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0) + except Exception as e: + print("Error in cuLaunchKernel:", e) + return orig_cuda.CUDA_ERROR_LAUNCH_FAILED return orig_cuda.CUDA_SUCCESS def cuDeviceComputeCapability(major, minor, dev: int) -> int: diff --git a/tinygrad_repo/test/mockgpu/helpers.py b/tinygrad_repo/test/mockgpu/helpers.py new file mode 100644 index 00000000..c91672fe --- /dev/null +++ b/tinygrad_repo/test/mockgpu/helpers.py @@ -0,0 +1,29 @@ +import ctypes, ctypes.util + +def _try_dlopen_gpuocelot(): + GPUOCELOT_PATHS = [ctypes.util.find_library("gpuocelot")] if ctypes.util.find_library("gpuocelot") is not None else [] + GPUOCELOT_PATHS += ["libgpuocelot.so", "/usr/local/lib/libgpuocelot.so", + "libgpuocelot.dylib", "/usr/local/lib/libgpuocelot.dylib", "/opt/homebrew/lib/libgpuocelot.dylib"] + for path in GPUOCELOT_PATHS: + try: + gpuocelot_lib = ctypes.CDLL(path) + gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] + except OSError: pass + else: return gpuocelot_lib + print("Could not find libgpuocelot.so") + return None + +def _try_dlopen_remu(): + REMU_PATHS = ["extra/remu/target/release/libremu.so", "libremu.so", "/usr/local/lib/libremu.so", + "extra/remu/target/release/libremu.dylib", "libremu.dylib", "/usr/local/lib/libremu.dylib", "/opt/homebrew/lib/libremu.dylib"] + for path in REMU_PATHS: + try: + remu = ctypes.CDLL(path) + remu.run_asm.restype = ctypes.c_int32 + remu.run_asm.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, + ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_void_p] + except OSError: pass + else: return remu + print("Could not find libremu.so") + return None diff --git a/tinygrad_repo/test/mockgpu/nv/nvgpu.py b/tinygrad_repo/test/mockgpu/nv/nvgpu.py index 76239c15..deff54bd 100644 --- a/tinygrad_repo/test/mockgpu/nv/nvgpu.py +++ b/tinygrad_repo/test/mockgpu/nv/nvgpu.py @@ -2,6 +2,7 @@ import ctypes, ctypes.util, time import tinygrad.runtime.autogen.nv_gpu as nv_gpu from enum import Enum, auto from test.mockgpu.gpu import VirtGPU +from test.mockgpu.helpers import _try_dlopen_gpuocelot from tinygrad.helpers import to_mv, init_c_struct_t def make_qmd_struct_type(): @@ -16,10 +17,7 @@ def make_qmd_struct_type(): qmd_struct_t = make_qmd_struct_type() assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4 -try: - gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot")) - gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501 -except Exception: pass +gpuocelot_lib = _try_dlopen_gpuocelot() class SchedResult(Enum): CONT = auto(); YIELD = auto() # noqa: E702 @@ -99,7 +97,10 @@ class GPFIFO: cargs = [ctypes.cast(args[i], ctypes.c_void_p) for i in range(args_cnt)] + [ctypes.cast(vals[i], ctypes.c_void_p) for i in range(vals_cnt)] gx, gy, gz = qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth lx, ly, lz = qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 - 
gpuocelot_lib.ptx_run(ctypes.cast(prg_addr, ctypes.c_char_p), args_cnt+vals_cnt, (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0) + try: + gpuocelot_lib.ptx_run(ctypes.cast(prg_addr, ctypes.c_char_p), args_cnt+vals_cnt, + (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0) + except Exception as e: print("failed to execute:", e) if qmd.release0_enable: rel0 = to_mv(qmd.release0_address_lower + (qmd.release0_address_upper << 32), 0x10).cast('Q') rel0[0] = qmd.release0_payload_lower + (qmd.release0_payload_upper << 32) diff --git a/tinygrad_repo/test/models/test_bert.py b/tinygrad_repo/test/models/test_bert.py index 0e42ff9d..78e33673 100644 --- a/tinygrad_repo/test/models/test_bert.py +++ b/tinygrad_repo/test/models/test_bert.py @@ -1,14 +1,14 @@ #!/usr/bin/env python import unittest +from tinygrad import Tensor import numpy as np -from tinygrad.tensor import Tensor import torch def get_question_samp(bsz, seq_len, vocab_size, seed): np.random.seed(seed) in_ids= np.random.randint(vocab_size, size=(bsz, seq_len)) mask = np.random.choice([True, False], size=(bsz, seq_len)) - seg_ids = np.random.randint(1, size=(bsz, seq_len)) + seg_ids = np.random.randint(2, size=(bsz, seq_len)) # type_vocab_size return in_ids, mask, seg_ids def set_equal_weights(mdl, torch_mdl): @@ -45,7 +45,7 @@ class TestBert(unittest.TestCase): seeds = (1337, 3141) bsz, seq_len = 1, 16 - for _, seed in enumerate(seeds): + for seed in seeds: in_ids, mask, seg_ids = get_question_samp(bsz, seq_len, config['vocab_size'], seed) out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids)) torch_out = torch_mdl.forward(torch.from_numpy(in_ids).long(), torch.from_numpy(mask), torch.from_numpy(seg_ids).long())[:2] diff --git a/tinygrad_repo/test/models/test_efficientnet.py b/tinygrad_repo/test/models/test_efficientnet.py index 28921965..8e434ba8 100644 --- a/tinygrad_repo/test/models/test_efficientnet.py +++ b/tinygrad_repo/test/models/test_efficientnet.py @@ -1,12 +1,10 @@ -import ast -import pathlib -import unittest +import ast, pathlib, unittest import numpy as np from PIL import Image -from tinygrad.helpers import getenv -from tinygrad.tensor import Tensor +from tinygrad import Tensor +from tinygrad.helpers import getenv, CI from extra.models.efficientnet import EfficientNet from extra.models.vit import ViT from extra.models.resnet import ResNet50 @@ -40,19 +38,13 @@ def preprocess(img, new=False): img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) return img +def _infer(model: EfficientNet, img): + with Tensor.train(False): + out = model.forward(Tensor(img)).argmax(axis=-1) + return out.tolist() -def _infer(model: EfficientNet, img, bs=1): - old_training = Tensor.training - Tensor.training = False - img = preprocess(img) - # run the net - if bs > 1: img = img.repeat(bs, axis=0) - out = model.forward(Tensor(img)) - Tensor.training = old_training - return _LABELS[np.argmax(out.numpy()[0])] - -chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg') -car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg') +chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')) +car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')) class TestEfficientNet(unittest.TestCase): @classmethod @@ -64,17 +56,20 @@ class TestEfficientNet(unittest.TestCase): def tearDownClass(cls): del cls.model + @unittest.skipIf(CI, "covered by test_chicken_car") def test_chicken(self): - label = 
_infer(self.model, chicken_img) - self.assertEqual(label, "hen") - - def test_chicken_bigbatch(self): - label = _infer(self.model, chicken_img, 2) - self.assertEqual(label, "hen") + labels = _infer(self.model, chicken_img) + self.assertEqual(_LABELS[labels[0]], "hen") + @unittest.skipIf(CI, "covered by test_chicken_car") def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "sports car, sport car") + labels = _infer(self.model, car_img) + self.assertEqual(_LABELS[labels[0]], "sports car, sport car") + + def test_chicken_car(self): + labels = _infer(self.model, np.concat([chicken_img, car_img], axis=0)) + self.assertEqual(_LABELS[labels[0]], "hen") + self.assertEqual(_LABELS[labels[1]], "sports car, sport car") class TestViT(unittest.TestCase): @classmethod @@ -87,12 +82,12 @@ class TestViT(unittest.TestCase): del cls.model def test_chicken(self): - label = _infer(self.model, chicken_img) - self.assertEqual(label, "cock") + labels = _infer(self.model, chicken_img) + self.assertEqual(_LABELS[labels[0]], "cock") def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "racer, race car, racing car") + labels = _infer(self.model, car_img) + self.assertEqual(_LABELS[labels[0]], "racer, race car, racing car") class TestResNet(unittest.TestCase): @classmethod @@ -105,12 +100,12 @@ class TestResNet(unittest.TestCase): del cls.model def test_chicken(self): - label = _infer(self.model, chicken_img) - self.assertEqual(label, "hen") + labels = _infer(self.model, chicken_img) + self.assertEqual(_LABELS[labels[0]], "hen") def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "sports car, sport car") + labels = _infer(self.model, car_img) + self.assertEqual(_LABELS[labels[0]], "sports car, sport car") if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/models/test_onnx.py b/tinygrad_repo/test/models/test_onnx.py index 79ce2046..34e5a132 100644 --- a/tinygrad_repo/test/models/test_onnx.py +++ b/tinygrad_repo/test/models/test_onnx.py @@ -1,23 +1,12 @@ #!/usr/bin/env python -import os -import time import unittest import numpy as np -try: - import onnx -except ModuleNotFoundError: - raise unittest.SkipTest("onnx not installed, skipping onnx test") from tinygrad.frontend.onnx import OnnxRunner -from tinygrad.tensor import Tensor from tinygrad.device import Device -from tinygrad.helpers import CI, fetch, temp, Context +from tinygrad.helpers import fetch, Context -try: - from extra.onnx_helpers import validate - from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry - HUGGINGFACE_AVAILABLE = True -except ModuleNotFoundError: - HUGGINGFACE_AVAILABLE = False +from extra.onnx_helpers import validate +from extra.huggingface_onnx.huggingface_manager import DOWNLOADS_DIR, snapshot_download_with_retry def run_onnx_torch(onnx_model, inputs): import torch @@ -27,86 +16,9 @@ def run_onnx_torch(onnx_model, inputs): torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()]) return torch_out -OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" - np.random.seed(1337) class TestOnnxModel(unittest.TestCase): - @unittest.skip("this isn't a test, it can't fail") - def test_benchmark_openpilot_model(self): - onnx_model = fetch(OPENPILOT_MODEL) - run_onnx = OnnxRunner(onnx_model) - def get_inputs(): - np_inputs = { - "input_imgs": np.random.randn(*(1, 12, 128, 256)), - "big_input_imgs": np.random.randn(*(1, 12, 128, 
256)), - "desire": np.zeros((1, 100, 8)), - "traffic_convention": np.array([[1., 0.]]), - "nav_features": np.zeros((1, 256)), - "features_buffer": np.zeros((1, 99, 128)), - } - inputs = {k:Tensor(v.astype(np.float32), requires_grad=False) for k,v in np_inputs.items()} - return inputs - - for _ in range(7): - inputs = get_inputs() - st = time.monotonic() - tinygrad_out = run_onnx(inputs)['outputs'] - mt = time.monotonic() - tinygrad_out.realize() - mt2 = time.monotonic() - tinygrad_out = tinygrad_out.numpy() - et = time.monotonic() - if not CI: - print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") - - if not CI: - import cProfile - import pstats - inputs = get_inputs() - pr = cProfile.Profile(timer=time.perf_counter_ns, timeunit=1e-6) - pr.enable() - tinygrad_out = run_onnx(inputs)['outputs'] - tinygrad_out.realize() - tinygrad_out = tinygrad_out.numpy() - if not CI: - pr.disable() - stats = pstats.Stats(pr) - stats.dump_stats(temp("net.prof")) - os.system(f"flameprof {temp('net.prof')} > {temp('prof.svg')}") - ps = stats.sort_stats(pstats.SortKey.TIME) - ps.print_stats(30) - - def test_openpilot_model(self): - onnx_model = fetch(OPENPILOT_MODEL) - run_onnx = OnnxRunner(onnx_model) - print("got run_onnx") - inputs = { - "input_imgs": np.random.randn(*(1, 12, 128, 256)), - "big_input_imgs": np.random.randn(*(1, 12, 128, 256)), - "desire": np.zeros((1, 100, 8)), - "traffic_convention": np.array([[1., 0.]]), - "nav_features": np.zeros((1, 256)), - "features_buffer": np.zeros((1, 99, 128)), - } - inputs = {k:v.astype(np.float32) for k,v in inputs.items()} - - st = time.monotonic() - print("****** run onnx ******") - tinygrad_out = run_onnx(inputs)['outputs'] - mt = time.monotonic() - print("****** realize ******") - tinygrad_out.realize() - mt2 = time.monotonic() - tinygrad_out = tinygrad_out.numpy() - et = time.monotonic() - print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") - - onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - torch_out = run_onnx_torch(onnx_model, inputs).numpy() - print(tinygrad_out, torch_out) - np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2) - @unittest.skip("slow") def test_efficientnet(self): input_name, input_new = "images:0", True @@ -146,7 +58,7 @@ class TestOnnxModel(unittest.TestCase): print(cls, _LABELS[cls]) assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible" -@unittest.skipUnless(HUGGINGFACE_AVAILABLE and Device.DEFAULT == "METAL", "only run on METAL") +@unittest.skipUnless(Device.DEFAULT == "METAL", "only run on METAL") class TestHuggingFaceOnnxModels(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/tinygrad_repo/test/models/test_real_world.py b/tinygrad_repo/test/models/test_real_world.py index 66f41663..aa734582 100644 --- a/tinygrad_repo/test/models/test_real_world.py +++ b/tinygrad_repo/test/models/test_real_world.py @@ -53,8 +53,8 @@ class TestRealWorld(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16") def test_stable_diffusion(self): params = unet_params - params["model_ch"] = 16 - params["ctx_dim"] = 16 + params["model_ch"] = 8 + params["ctx_dim"] = 8 params["num_res_blocks"] = 1 params["n_heads"] = 2 model = UNetModel(**params) @@ -114,7 +114,7 @@ class TestRealWorld(unittest.TestCase): helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93) - 
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU", "LLVM"}, "slow") + @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow") def test_train_cifar(self): with Tensor.train(): model = SpeedyResNet(Tensor.ones((12,3,2,2))) @@ -144,6 +144,7 @@ class TestRealWorld(unittest.TestCase): final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4) assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half" + @unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow") def test_bert(self): with Tensor.train(): args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2, @@ -167,9 +168,5 @@ class TestRealWorld(unittest.TestCase): helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \ data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 347) - def test_bert_fuse_arange(self): - with Context(FUSE_ARANGE=1): - self.test_bert() - if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/models/test_rnnt.py b/tinygrad_repo/test/models/test_rnnt.py index f9d5e2c9..c321b16d 100644 --- a/tinygrad_repo/test/models/test_rnnt.py +++ b/tinygrad_repo/test/models/test_rnnt.py @@ -1,8 +1,8 @@ #!/usr/bin/env python import unittest -import numpy as np -from tinygrad.tensor import Tensor +from tinygrad import Tensor from extra.models.rnnt import LSTM +import numpy as np import torch class TestRNNT(unittest.TestCase): diff --git a/tinygrad_repo/test/models/test_train.py b/tinygrad_repo/test/models/test_train.py index 605e6f6d..fe511474 100644 --- a/tinygrad_repo/test/models/test_train.py +++ b/tinygrad_repo/test/models/test_train.py @@ -1,9 +1,8 @@ -import unittest -import time +import unittest, time import numpy as np +from tinygrad import Device from tinygrad.nn.state import get_parameters from tinygrad.nn import optim -from tinygrad.tensor import Device from tinygrad.helpers import getenv, CI from extra.training import train from extra.models.convnext import ConvNeXt @@ -27,7 +26,7 @@ def train_one_step(model,X,Y): print("done in %.2f ms" % (et*1000.)) def check_gc(): - if Device.DEFAULT == "GPU": + if Device.DEFAULT == "CL": from extra.introspection import print_objects assert print_objects() == 0 @@ -40,7 +39,6 @@ class TestTrain(unittest.TestCase): check_gc() @unittest.skipIf(CI, "slow") - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_efficientnet(self): model = EfficientNet(0) X = np.zeros((BS,3,224,224), dtype=np.float32) @@ -49,7 +47,6 @@ class TestTrain(unittest.TestCase): check_gc() @unittest.skipIf(CI, "slow") - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_vit(self): model = ViT() X = np.zeros((BS,3,224,224), dtype=np.float32) @@ -57,7 +54,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") + @unittest.skipIf(CI, "slow") def test_transformer(self): # this should be small GPT-2, but the param count is wrong # (real ff_dim is 768*4) diff --git a/tinygrad_repo/test/models/test_whisper.py b/tinygrad_repo/test/models/test_whisper.py index 056615b2..c18fe0cb 100644 --- a/tinygrad_repo/test/models/test_whisper.py +++ b/tinygrad_repo/test/models/test_whisper.py @@ -1,7 +1,7 @@ import unittest import pathlib from examples.whisper 
import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform -from tinygrad.helpers import CI, fetch +from tinygrad.helpers import CI, fetch, CPU_LLVM from tinygrad import Device, dtypes from tinygrad.device import is_dtype_supported @@ -16,7 +16,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3' TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501 -@unittest.skipIf(Device.DEFAULT in ["CPU", "LLVM"], "slow") +@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow") @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support") class TestWhisper(unittest.TestCase): @classmethod @@ -33,11 +33,11 @@ class TestWhisper(unittest.TestCase): def test_transcribe_file1(self): self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1) - @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI") + @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too many tests for CI") def test_transcribe_file2(self): self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2) - @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI") + @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too many tests for CI") def test_transcribe_batch12(self): waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)] transcriptions = transcribe_waveform(self.model, self.enc, waveforms) @@ -52,13 +52,13 @@ class TestWhisper(unittest.TestCase): self.assertEqual(TRANSCRIPTION_2, transcriptions[0]) self.assertEqual(TRANSCRIPTION_1, transcriptions[1]) - @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too long for CI") + @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too long for CI") def test_transcribe_long(self): waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))] transcription = transcribe_waveform(self.model, self.enc, waveform) self.assertEqual(TRANSCRIPTION_3, transcription) - @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too long for CI") + @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too long for CI") def test_transcribe_long_no_batch(self): waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)] diff --git a/tinygrad_repo/test/opt/test_gen_float4.py b/tinygrad_repo/test/opt/test_gen_float4.py new file mode 100644 index 00000000..1b72514b --- /dev/null +++ b/tinygrad_repo/test/opt/test_gen_float4.py @@ -0,0 +1,180 @@ +import unittest +from tinygrad import Device, Tensor, dtypes +from tinygrad.uop.ops import UOp, Ops +from tinygrad.codegen.opt import Opt, OptOps +from tinygrad.shape.shapetracker import ShapeTracker, View +from tinygrad.engine.realize import get_program +from tinygrad.helpers 
import AMX + +@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") +class TestFloat4(unittest.TestCase): + @staticmethod + def count_float4(uops: list[UOp], n=4): + return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]), + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)])) + @staticmethod + def count_half4(uops: list[UOp]): + return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]), + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)])) + + def test_float4_basic(self): + a = Tensor.empty(2, 8).realize() + b = Tensor.empty(2, 8).realize() + c = a + b + + s = c.schedule()[0] + realized_ast = s.ast + opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + + assert TestFloat4.count_float4(program.uops) == (2, 1) + + @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16") + def test_float4_multidim(self): + a = Tensor.empty(2, 8).realize() + b = Tensor.empty(2, 8).realize() + c = a + b + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops + assert TestFloat4.count_float4(uops) == (4, 2) + + @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16") + def test_float4_multidim_amx(self): + def kernel_for_shape(size, shift): + a = Tensor.empty(2, size).realize() + b = Tensor.empty(2, size).realize() + c = a + b + + s = c.schedule()[0] + return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops + + sizes = [12, 8, 16] + shifts = [3, 2, 4] + expected_upcast_size = [4, 8, 16] + expected_output = [(6,3), (2,1), (2,1)] + + for i in range(len(sizes)): + assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] + + def test_float4_unaligned_load(self): + a = Tensor.empty(9).realize().shrink(((1, 9),)) + b = Tensor.empty(9).realize().shrink(((1, 9),)) + c = a + b + + s = c.schedule()[0] + realized_ast = s.ast + opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + + assert TestFloat4.count_float4(program.uops) == (0, 1) + + @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16") + def test_float4_multidim_unaligned_load(self): + a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) + b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) + c = a + b + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops + + assert TestFloat4.count_float4(uops) == (0, 2) + + @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16") + def test_float4_multidim_unaligned_load_amx(self): + def kernel_for_shape(size, shift): + a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) + b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) + c = a + b + + s = c.schedule()[0] + return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops + + sizes = [13, 9, 17] + shifts = [3, 2, 4] + 
expected_upcast_size = [4, 8, 16] + expected_output = [(0,3), (0,1), (0,1)] + + for i in range(len(sizes)): + assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] + + def test_float4_sometimes_unaligned(self): + a = Tensor.empty(1, 1, 8).realize() + b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) + c = a.conv2d(b) + # only the first and last conv dot products are aligned in a, and b is never aligned, so no + # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops + + assert TestFloat4.count_float4(uops) == (0, 0) + + def test_float4_multidim_sometimes_unaligned(self): + a = Tensor.empty(1, 1, 7).realize() + b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) + c = a.conv2d(b) + # the first conv dot product is aligned in a. If we upcast the output and reduce + # dimension, then we could do float4 for only that one set of loads, but we currently + # don't. + # UPDATE: now we do this fusion + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops + + assert TestFloat4.count_float4(uops) in {(0,1), (1,1)} + + def test_float4_expand(self): + a = Tensor.empty(9).realize().shrink(((1, 9),)) + b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,)) + c = a + b + + # we will upcast the top axis of sz 4. they should not be coalesced into float4, + # since the top axis is not contiguous. + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops + + assert TestFloat4.count_float4(uops) == (0, 1) + + def test_float4_heterogeneous(self): + a = Tensor.empty(8).realize() + b = Tensor.empty(9).realize().shrink(((1, 9),)) + c = a + b + + # should float4 b but not a + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops + + assert TestFloat4.count_float4(uops) == (1, 1) + + def test_half4_load_unrolled(self): + # from llama 7B shard 4 gpus + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),)) + + # TODO: fix this, expected might change but should be positive + for expected, opts in [ + ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, 
arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ]: + program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts) + + count = TestFloat4.count_half4(program.uops) + assert count == expected, f"{count=}, {expected=}" + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad_repo/test/opt/test_kernel_opts.py b/tinygrad_repo/test/opt/test_kernel_opts.py new file mode 100644 index 00000000..c0c88651 --- /dev/null +++ b/tinygrad_repo/test/opt/test_kernel_opts.py @@ -0,0 +1,355 @@ +import unittest +from tinygrad import Device, Tensor, dtypes +from tinygrad.helpers import CI +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError + +# TODO: write a clean version of this +from test.test_linearizer import helper_linearizer_opt + +class TestKernelOpts(unittest.TestCase): + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_local_and_grouped_reduce(self): + N = 128 + Tensor.manual_seed(1882) + a = Tensor.rand(4, 4, N, N) + b = Tensor.rand(4, 4, N) + r = (b.sqrt() + ((a+1).sum(axis=3).exp())) + helper_linearizer_opt(r, [ + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 0, 8)], + [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], + [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], + # Checking how it works with locals + grouped reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)], + # Checking how it works with locals + grouped reduce + upcasts + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], + # many local + many group + [Opt(OptOps.GROUP, 0, 2)] * 4, + [Opt(OptOps.LOCAL, 0, 2)] * 4, + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4, + ]) + + def test_upcasts(self): + N = 16 + Tensor.manual_seed(1772) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = (a+b).sqrt() * ((a+1).exp()) + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts + ]) + + def test_full_upcast(self): + Tensor.manual_seed(1772) + a = Tensor.rand(4) + b = Tensor.rand(4) + r = (a+b).sqrt() * ((a+1).exp()) + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_matmul(self): + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a@b + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 1, 32)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], + [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with 
grouped_reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce + # Checking all together + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 2)], + # Full global upcast + local + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_double_reduce(self): + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(8, N, 8, N) + r = a.sum(axis=(1,3)) + helper_linearizer_opt(r, [ + # openCL / CL=1 is 256 max threads + [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce. + [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], + [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], + [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces. + [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)], + [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts. + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)], + # Checking how it works with 2 grouped_reduces + upcasts + locals. + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), + Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), + Opt(OptOps.UPCAST, 0, 2)], # No globals + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), + "test requires tensor cores with accumulation in half") # testing with half suffices. 
+ def test_tensor_core_opts(self): + N = 128 + Tensor.manual_seed(1552) + a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) + r = a.matmul(b, dtype=dtypes.half) + atol, rtol = 0.25, 0.01 + helper_linearizer_opt(r, [ + [], + [Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 1, 4)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts + [Opt(OptOps.UNROLL, 0, 2)], # check unroll + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)], + [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations + [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)], + [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], + ], apply_tc=True, atol=atol, rtol=rtol) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), + "test requires tensor cores with accumulation in half") # testing with half suffices. + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + def test_tensor_core_opts_locals(self): + N = 128 + Tensor.manual_seed(1552) + a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) + r = a.matmul(b, dtype=dtypes.half) + atol, rtol = 0.25, 0.01 + helper_linearizer_opt(r, [ + [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals + [Opt(OptOps.LOCAL, 0, 4)], # check local + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], + ], apply_tc=True, atol=atol, rtol=rtol) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") + @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), + "test requires tensor cores with accumulation in half") # testing with half suffices. + # NOTE: the METAL test is broken, likely due to a compiler bug. 
passes on CI with -O0 and with default opt level locally on M3 + @unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL") + @unittest.skip("feature was removed") + def test_tensor_core_opts_group(self): + N = 128 + Tensor.manual_seed(1552) + a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) + r = a.matmul(b, dtype=dtypes.half) + atol, rtol = 0.25, 0.01 + helper_linearizer_opt(r, [ + [Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 4)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)], + ], apply_tc=True, atol=atol, rtol=rtol) + + def test_padto_matmul(self): + if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]): + self.skipTest("super slow on CUDA and AMD because of the big grid dims") + N = 17 * 17 + Tensor.manual_seed(289) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + helper_linearizer_opt(a@b, [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 1, 32)], + [Opt(OptOps.PADTO, 2, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)], + # can optimize further post PADTO + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),], + ]) + + def test_padto_upcasted_not_ok(self): + N = 4 + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + helper_linearizer_opt(a@b, [ + [Opt(OptOps.UPCAST, 0, 0)], + [Opt(OptOps.UPCAST, 1, 0)], + [Opt(OptOps.UNROLL, 0, 0)], + [Opt(OptOps.PADTO, 0, 8)], + [Opt(OptOps.PADTO, 1, 8)], + [Opt(OptOps.PADTO, 2, 8)], + ]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) + + def test_padto_sum_ok(self): + N = 18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension + a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100 + b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17))) + + helper_linearizer_opt(a.sum(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + helper_linearizer_opt(a.sum(1), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + # can pad sum reduce axis if there's no unsafe ops prior to sum + for axis in (0, 1): + helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + # TODO: why? 
+ if Device.DEFAULT != "WEBGPU": + helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + + # having unsafe ops after sum is fine + helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],]) + helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_sum_not_ok(self): + N = 18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension + a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp() + # exp is not safe to pad + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + b = a < 1 + # lt is not safe to pad + with self.assertRaises(KernelOptError): + helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_max(self): + N = 18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one axis + a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100 + + helper_linearizer_opt(a.max(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + helper_linearizer_opt(a.max(1), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + # cannot pad max kernel on reduce + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_where(self): + Tensor.manual_seed(0) + N = 17 * 17 + a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0) + helper_linearizer_opt(a.max(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + def test_padto_where_multioutput(self): + Tensor.manual_seed(0) + N = 17 * 17 + r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1 + a0 = r.where(1, 0) + a1 = r.where(2, 0) + helper_linearizer_opt([a0.max(0), a1.max(0)], [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_color_shapes_with_local(self): + N = 32 + Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a@b + opts_shapes = [ + ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]), + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]), + # check to ensure local_dims are stable for full UNROLL of the first reduce + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + # check behavior for full UNROLL on an existing GROUP + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]), + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.GROUP, 0, 
0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]), + ] + helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") + def test_arange_opts(self): + a = Tensor.arange(128) + # NOTE: arange no longer has reduce ops available for opt + helper_linearizer_opt(a, [ + #[Opt(OptOps.GROUP, 0, 32)], + #[Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(op=OptOps.LOCAL, axis=0, arg=8)], + [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + #[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)], + #[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501 + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_threads, "test requires threads") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.global_max is not None and + Device[Device.DEFAULT].renderer.global_max[0] > 1, "test requires multicore") + def test_thread_opts(self): + a = Tensor.rand(4, 4, 4, 4) + b = Tensor.rand(4, 4, 4) + r = (b.sqrt() + ((a+1).sum(axis=3).exp())) + helper_linearizer_opt(r, [ + [Opt(OptOps.THREAD, 0, 2)], + [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2)], + [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.THREAD, 0, 2), Opt(OptOps.UNROLL, 0, 2)], + ] + [[Opt(OptOps.THREAD, 0, 4)] if Device[Device.DEFAULT].renderer.global_max[0] >= 4 else []] + + [[Opt(OptOps.THREAD, 0, 8)] if Device[Device.DEFAULT].renderer.global_max[0] >= 8 else []]) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad_repo/test/opt/test_tensor_cores.py b/tinygrad_repo/test/opt/test_tensor_cores.py new file mode 100644 index 00000000..0639bf58 --- /dev/null +++ b/tinygrad_repo/test/opt/test_tensor_cores.py @@ -0,0 +1,191 @@ +import numpy as np +import unittest +from dataclasses import replace + +from tinygrad import Device, Tensor, dtypes +from tinygrad.tensor import _to_np_dtype +from tinygrad.uop.ops import Ops +from tinygrad.dtype import DType +from tinygrad.device import is_dtype_supported +from tinygrad.helpers import AMX, CI, AMD_LLVM, CPU_LLVM +from tinygrad.engine.realize import CompiledRunner, get_program +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError + +# TODO: write a clean version of this +from test.test_linearizer import helper_realized_ast, helper_linearizer_opt + +def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, + ensure_triggered:bool=True): + a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in) + r = a.matmul(b, dtype=dtype_out) + sched = r.schedule() + realized_ast = sched[-1].ast + opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))] + + if ensure_triggered: + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA]) + tcs = len([x for x in program.applied_opts if x.op is OptOps.TC]) + assert wmmas > 0, "tensor core not 
triggered" + assert tcs == 1, "tensor core opt not included" + else: + try: + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + assert False, "OptOps.TC triggered, expected KernelOptError" + except KernelOptError: pass + +def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1): + a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in) + np_a, np_b = a.numpy(), b.numpy() + r = a.matmul(b, dtype=dtype_out) + if dtype_in == dtypes.bfloat16: r = r.float() + realized_ast, bufs = helper_realized_ast(r) + opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))] + prg = CompiledRunner(replace(get_program(realized_ast, opts=opts), device=Device.DEFAULT)) + if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered" + assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" + prg.exec(bufs) + if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 + elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2 + else: tc_atol, tc_rtol = 5e-3, 1e-4 + c = bufs[0].numpy().reshape((M,N)) + np.testing.assert_allclose(c, np_a @ np_b, atol=tc_atol, rtol=tc_rtol) + +class TestTensorCores(unittest.TestCase): + # TODO: don't skip bf16 for real device (METAL, AMD) + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue + # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered + helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) + + @unittest.skipIf(Device.DEFAULT == "PYTHON", "not generated on EMULATED device") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores_codegen(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue + n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2] + a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in) + r = a.matmul(b, dtype=tc.dtype_out) + prg = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))]) + if Device.DEFAULT == "CPU" and CPU_LLVM: + assert "0x201000" in prg.src + elif Device.DEFAULT == "AMD" and AMD_LLVM: + assert "@llvm.amdgcn.wmma" in prg.src + elif Device[Device.DEFAULT].renderer.suffix == "PTX": + assert "mma.sync.aligned" in prg.src + else: + assert "__WMMA_" in prg.src + + @unittest.skipIf((Device.DEFAULT == "AMD") or (Device.DEFAULT == "PYTHON" and Device.default.renderer.device == "AMD"), "broken for AMD") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores_padded(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue + helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) + + # AMD compiler bug: AMD miscompiles non-zero padded tc kernels with -O3, producing wrong results, nans or hang (see #9606) + # Internal bug: 
zero-stride dimensions combined with a mask may produce wrong index/valid for pad == 1 on AMD + @unittest.skipUnless((Device.DEFAULT == "AMD") or (Device.DEFAULT == "PYTHON" and Device.default.renderer.device == "AMD"), "test for AMD's tc") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skip("warp elements not duplicated properly across lanes") + def test_tensor_cores_padded_amd(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue + helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores_padded_uops(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + pad = 1 + + # check that TC is triggered for TC_OPT=2 + helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, + tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True) + + # check that TC is not triggered for TC_OPT<2 + helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, + tc.dtype_in, tc.dtype_out, tc_opt=1, ensure_triggered=False) + helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, + tc.dtype_in, tc.dtype_out, tc_opt=0, ensure_triggered=False) + + # check excessive padding doesn't trigger padded TC in TC_OPT=2 + helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) + helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) + if not AMX: # AMX tc.dims[2] == 1 + helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//8, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) + + @unittest.skipIf(Device.DEFAULT == "PYTHON", "not generated on EMULATED device") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores_multi_reduce(self): + for tc in Device[Device.DEFAULT].renderer.tensor_cores: + if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue + if tc.dtype_in is dtypes.bfloat16: continue # <-- broken with numpy + # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes + golden_result = None + for axis in range(9): + a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize() + b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize() + c = a.conv2d(b, padding=1, dtype=tc.dtype_out) + realized_ast, real_bufs = helper_realized_ast(c) + + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, axis, (-1, 2, 1))]) + assert len([uop for uop in program.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered" + assert len([x for x in program.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" + + prg = CompiledRunner(program) + # TODO: support this even if numpy doesn't + if _to_np_dtype(real_bufs[0].dtype) is None: continue + real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled + prg.exec(real_bufs) + result = 
np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) + + # ensure the results for each choice of axis matches + if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) + np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.2) + + @unittest.skipIf(Device.DEFAULT == "PYTHON", "slow on EMULATED device") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_tensor_cores_unroll_phi(self): + tc = Device[Device.DEFAULT].renderer.tensor_cores[0] + x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) + r = x.matmul(y, dtype=tc.dtype_out) + opts = [Opt(OptOps.UNROLL, 0, 4)] + ast = helper_linearizer_opt(r, [opts], apply_tc=True, atol=3e-2, rtol=1e-3) + for u in get_program(ast, opts=opts).uops: + if u.op is Ops.WMMA: + assert u.src[-1].src[0].op != Ops.STORE + + @unittest.skipIf(Device.DEFAULT == "PYTHON", "slow on EMULATED device") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipIf(Device.DEFAULT in {"CPU"}, "CPU does not support using a different type for accumulation") + def test_tensor_cores_unroll_casted_phi(self): + tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] + x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) + r = x.matmul(y, dtype=tc.dtype_out) + opts = [Opt(OptOps.UNROLL, 0, 4)] + ast = helper_linearizer_opt(r, [opts], apply_tc=True, atol=3e-2, rtol=1e-3) + for u in get_program(ast, opts=opts).uops: + if u.op is Ops.WMMA: + #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) + assert u.src[-1].src[0].op != Ops.STORE + + @unittest.skipIf(Device.DEFAULT == "PYTHON", "slow on EMULATED device") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipIf(Device.DEFAULT in {"CPU"}, "CPU does not support using a different type for accumulation") + def test_tensor_cores_unroll_casted_phi_with_children(self): + # all STORE children are outside the loop + tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] + x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) + r = x.matmul(y, dtype=tc.dtype_out).relu() + opts = [Opt(OptOps.UNROLL, 0, 4)] + ast = helper_linearizer_opt(r, [opts], apply_tc=True, atol=3e-2, rtol=1e-3) + for u in get_program(ast, opts=opts).uops: + if u.op is Ops.WMMA: + #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) + assert u.src[-1].src[0].op != Ops.STORE + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad_repo/test/speed/external_test_copy_speed.py b/tinygrad_repo/test/speed/external_test_copy_speed.py index 391a4da0..359a2499 100644 --- a/tinygrad_repo/test/speed/external_test_copy_speed.py +++ b/tinygrad_repo/test/speed/external_test_copy_speed.py @@ -1,9 +1,9 @@ import unittest, numpy as np from tinygrad import Tensor, Device, TinyJit -from tinygrad.helpers import Timing, CI, OSX +from tinygrad.helpers import Timing, CI, OSX, getenv import multiprocessing.shared_memory as shared_memory -N = 256 +N = getenv("NSZ", 256) class TestCopySpeed(unittest.TestCase): @classmethod def setUpClass(cls): Device[Device.DEFAULT].synchronize() @@ -54,30 +54,32 @@ class TestCopySpeed(unittest.TestCase): @TinyJit def _do_copy(t): return 
t.to('CPU').realize() - t = Tensor.randn(N, N, 4).contiguous().realize() + t = Tensor.randn(N, N).contiguous().realize() + Device[Device.DEFAULT].synchronize() for _ in range(5): - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing(f"copy {Device.DEFAULT} -> CPU {t.nbytes()/(1024**2)}M: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): x = _do_copy(t) Device[Device.DEFAULT].synchronize() np.testing.assert_equal(t.numpy(), x.numpy()) - def testCopytoCPUtoDefaultJit(self): + def testCopyCPUtoDefaultJit(self): if Device.DEFAULT == "CPU": return unittest.skip("CPU to CPU copy is a no-op") @TinyJit - def _do_copy(x): return t.to(Device.DEFAULT).realize() + def _do_copy(x): return x.to(Device.DEFAULT).realize() for _ in range(5): - t = Tensor.randn(N, N, 4, device="CPU").contiguous().realize() - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + t = Tensor.randn(N, N, device="CPU").contiguous().realize() + Device["CPU"].synchronize() + with Timing(f"copy CPU -> {Device.DEFAULT} {t.nbytes()/(1024**2)}M: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): x = _do_copy(t) Device[Device.DEFAULT].synchronize() np.testing.assert_equal(t.numpy(), x.numpy()) @unittest.skipIf(CI, "CI doesn't have 6 GPUs") - @unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU") + @unittest.skipIf(Device.DEFAULT != "CL", "only test this on CL") def testCopyCPUto6GPUs(self): - from tinygrad.runtime.ops_gpu import CLDevice + from tinygrad.runtime.ops_cl import CLDevice if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs") t = Tensor.ones(N, N, device="CPU").contiguous().realize() print(f"buffer: {t.nbytes()*1e-9:.2f} GB") @@ -85,8 +87,8 @@ class TestCopySpeed(unittest.TestCase): with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"): with Timing("queue: "): for g in range(6): - t.to(f"gpu:{g}").realize() - Device["gpu"].synchronize() + t.to(f"CL:{g}").realize() + Device["CL"].synchronize() if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/test_arange.py b/tinygrad_repo/test/test_arange.py index e680ffbd..a46b38a0 100644 --- a/tinygrad_repo/test/test_arange.py +++ b/tinygrad_repo/test/test_arange.py @@ -1,75 +1,29 @@ -import unittest, contextlib +import unittest import numpy as np from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable -from tinygrad.helpers import CI, Context, getenv +from tinygrad.helpers import CI, Context, getenv, RANGEIFY from tinygrad.engine.realize import run_schedule -from tinygrad.codegen.opt.kernel import Opt, OptOps, Kernel, KernelOptError from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program -from tinygrad.codegen.opt.search import get_kernel_actions from tinygrad.uop.ops import Ops -from tinygrad.codegen import apply_rewrites, rewrites_for_views class TestArange(unittest.TestCase): - def _get_flops(self, N, opts=None): + def _get_flops(self, N): GlobalCounters.reset() tt = Tensor.arange(N) sched = tt.schedule() self.assertEqual(len(sched), 1) - p = get_program(sched[-1].ast, opts=opts) - print(p.name) - #print(p.src) + p = get_program(sched[-1].ast) ExecItem(CompiledRunner(p), [tt.uop.buffer]).run() np.testing.assert_equal(tt.numpy(), np.arange(N)) return p.estimates.ops - def test_complexity(self, opts=None, limit=None): - f1 = self._get_flops(256, opts) - f2 = self._get_flops(2560, opts) - print(f"{f1=}, {f2=}") - # add 1 to avoid divide by 0. 
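# For context on the heuristic this hunk deletes (its assert is just below):
# the old test passed if flops stayed small in absolute terms, or grew less
# than 16x while the input grew 10x. A runnable sketch, hypothetical name:
def complexity_ok(f1: int, f2: int) -> bool:
    # the +1 on both sides guards the division, since arange is 0 flops now
    return (f1 < 6000 and f2 < 6000) or (f2 + 1) / (f1 + 1) < 16
assert complexity_ok(0, 0)  # the exact case the new test pins down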
arange is 0 flops now! - assert (f1 < 6000 and f2 < 6000) or ((f2+1) / (f1+1) < 16), f"bad complexity, flops {(f2+1) / (f1+1):.1f}X while inputs 10X" - if limit is not None and not getenv("PTX"): - # PTX counts index ALU in flops - assert f1 <= limit, f"{f1=}, {limit=}" + def test_complexity(self): + self.assertEqual(self._get_flops(256), 0) + self.assertEqual(self._get_flops(2560), 0) - def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0) - def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0) - def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0) - def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0) - def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0) - - if Device.default.renderer.has_local: - # TODO: fix limit - def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920) - def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496) - - def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0) - @unittest.skip("doesn't work yet. TODO: this absolutely should work") - def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0) - @unittest.skip("doesn't work yet") - def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)]) - - def test_all_opts(self, opts=None, exclude=None): - k = Kernel(apply_rewrites(Tensor.arange(256).schedule()[-1].ast, rewrites_for_views)) - if opts is not None: - for o in opts: k.apply_opt(o) - all_opts_256 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()] - k = Kernel(apply_rewrites(Tensor.arange(2560).schedule()[-1].ast, rewrites_for_views)) - if opts is not None: - for o in opts: k.apply_opt(o) - all_opts_2560 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()] - all_opts = [x for x in all_opts_256 if x in all_opts_2560] - for opts in all_opts: - if exclude is not None and opts[-1] in exclude: continue - print(opts) - self.test_complexity(opts) - def test_all_opts_w_local(self): - with contextlib.suppress(KernelOptError): - return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)]) - def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)]) - def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) - def test_all_opts_w_upcast_and_unroll(self): - return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) + def test_arange_cat(self): + t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3]) + self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4]) class TestRand(unittest.TestCase): def test_fused_rand_less_ops(self, noopt=1): @@ -102,7 +56,6 @@ class TestIndexing(unittest.TestCase): run_schedule(sched) self.assertEqual(out.item(), 1337) - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_manual_index(self): dataset = Tensor.rand(DSET, DDIM).realize() idxs = Tensor([0,3,5,6]).realize() @@ -158,7 +111,7 @@ class TestIndexing(unittest.TestCase): X = 
dataset[idxs] assert X.shape == (4,DDIM) sched = X.schedule() - self.assertEqual(len(sched), 2) + self.assertEqual(len(sched), 1 if RANGEIFY else 2) run_schedule(sched) assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops} != {4*DSET}" np.testing.assert_allclose(real_index, X.numpy()) @@ -172,7 +125,6 @@ class TestIndexing(unittest.TestCase): X = dataset[idxs] np.testing.assert_equal(X.numpy(), 0) - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_index_mnist(self, noopt=1, op_limit=512*784*13, split_reduceop=0): # WEBGPU generates more ops due to bitpacking of < 4-byte dtypes if Device.DEFAULT == "WEBGPU": op_limit *= 15 @@ -191,7 +143,6 @@ class TestIndexing(unittest.TestCase): def test_index_mnist_split(self): self.test_index_mnist(1, split_reduceop=1) def test_index_mnist_opt_split(self): self.test_index_mnist(0, split_reduceop=1) - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_llama_embedding(self, noopt=1, op_limit=65536): # llama3 is 128256 vocab_size, embed_size = (10, 3) if CI else (32000, 4096) diff --git a/tinygrad_repo/test/test_compile_failures.py b/tinygrad_repo/test/test_compile_failures.py index 16559e90..cc25b90d 100644 --- a/tinygrad_repo/test/test_compile_failures.py +++ b/tinygrad_repo/test/test_compile_failures.py @@ -1,7 +1,7 @@ import unittest, io from contextlib import redirect_stdout from tinygrad import Tensor, dtypes, Device -from tinygrad.helpers import OSX +from tinygrad.helpers import OSX, CPU_LLVM from tinygrad.engine.realize import lower_schedule from tinygrad.device import is_dtype_supported from tinygrad.engine.realize import get_program @@ -19,7 +19,7 @@ class TestCompileFailures(unittest.TestCase): class TestDisassembly(unittest.TestCase): # TODO: fails on llvm. 
llvm.LLVMGetHostCPUName() returns "generic" - @unittest.skipUnless(Device.DEFAULT in ("CPU",) and OSX, "m series cpus support fp16 arithmetic") + @unittest.skipUnless(Device.DEFAULT in ("CPU",) and not CPU_LLVM and OSX, "m series cpus support fp16 arithmetic") def test_float16_alu(self): c = Tensor([1], dtype=dtypes.float16) + Tensor([1], dtype=dtypes.float16) s = c.schedule()[-1] diff --git a/tinygrad_repo/test/test_const_folding.py b/tinygrad_repo/test/test_const_folding.py index becc9904..d4a79fbb 100644 --- a/tinygrad_repo/test/test_const_folding.py +++ b/tinygrad_repo/test/test_const_folding.py @@ -1,11 +1,10 @@ import unittest, itertools, math -from typing import Any from tinygrad import Tensor, Device, dtypes -from tinygrad.dtype import DType +from tinygrad.dtype import DType, ConstType from tinygrad.uop.ops import Ops, UOp from tinygrad.codegen import full_rewrite_to_sink -import numpy as np from tinygrad.device import is_dtype_supported +import numpy as np from test.helpers import not_support_multi_device def _check_ast_count(desired_count:int, t:Tensor): @@ -25,7 +24,7 @@ class TestUnaryOpsConstFolding(unittest.TestCase): _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16)) _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16)) - @unittest.expectedFailure # no two level fold at lazybuffer + @unittest.expectedFailure # no two level fold def test_neg_folding(self): _check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg()) _check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1)) @@ -104,7 +103,7 @@ class TestBinaryOpsConstFolding(unittest.TestCase): class TestBitcastConstFolding(unittest.TestCase): def test_scalar_bitcast(self): - def t(cases: dict[DType, Any]): + def t(cases: dict[DType, ConstType]): for (from_dt, from_v), (to_dt, to_v) in itertools.product(cases.items(), cases.items()): if not math.isnan(from_v): r = full_rewrite_to_sink(UOp.const(from_dt, from_v).bitcast(to_dt).sink()).src[0] @@ -165,7 +164,6 @@ class TestMovedConstFolding(unittest.TestCase): _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),))) def test_cast_padded(self): - # NOTE: this is folded due to CAST_BEFORE_VIEW if is_dtype_supported(dtypes.int16): _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16)) np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0]) diff --git a/tinygrad_repo/test/test_dtype.py b/tinygrad_repo/test/test_dtype.py index 6a3d7d5a..6cb1aea1 100644 --- a/tinygrad_repo/test/test_dtype.py +++ b/tinygrad_repo/test/test_dtype.py @@ -4,13 +4,12 @@ import torch from typing import Any, List from tinygrad.device import is_dtype_supported from tinygrad.helpers import getenv, DEBUG, CI -from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8 +from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate +from tinygrad.renderer.ptx import PTXRenderer from tinygrad import Device, Tensor, dtypes -from tinygrad.tensor import _to_np_dtype from hypothesis import assume, given, settings, strategies as strat from test.helpers import rand_for_dtype from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX -import ml_dtypes import pytest pytestmark = pytest.mark.filterwarnings("ignore") @@ -24,6 +23,11 @@ def get_available_cast_dtypes(dtype: DType) -> List[DType]: # dont cast internal dtypes return [v for k, v in DTYPES_DICT.items() if v != dtype and 
is_dtype_supported(v) and not k.startswith("_")] +def _to_torch_storage_type(dtype:DType): + if dtype == dtypes.bfloat16: return torch.float32 + if dtype in dtypes.fp8s: return torch.float32 + return _to_torch_dtype(dtype) + def _test_to_np(a:Tensor, np_dtype, target): if DEBUG >= 2: print(a) na = a.numpy() @@ -44,12 +48,15 @@ def _test_cast(a:Tensor, target_dtype:DType): # TODO: struct.pack cannot pack value > 65504 (max of half) into e format a = (a > 65504).where(65504, a) - _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype)))) + expected = list(a.numpy().astype(_to_np_dtype(target_dtype))) + if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected)) + _test_op(lambda: a.cast(target_dtype), target_dtype, expected) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): - if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet") - if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize: + if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize: raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX") - _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist()) + expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist() + if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected)) + _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected) class TestDType(unittest.TestCase): DTYPE: Any = None @@ -97,7 +104,7 @@ class TestDType(unittest.TestCase): )) @unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now") - @unittest.skipIf(getenv("PTX"), "skip for now") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "skip for now") def test_uint_overflow(self): if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned") v = dtypes.max(self.DTYPE) @@ -125,11 +132,10 @@ class TestDType(unittest.TestCase): np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3) def test_finfo(self): - if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return + if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return info = np.finfo(_to_np_dtype(self.DTYPE)) - assert info.bits == self.DTYPE.itemsize*8 - assert info.nexp == dtypes.finfo(self.DTYPE)[0] - assert info.nmant == dtypes.finfo(self.DTYPE)[1] + self.assertEqual(info.bits, self.DTYPE.itemsize*8) + self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE)) def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype) @@ -147,36 +153,40 @@ class TestFp8s(unittest.TestCase): class TestFp8sConversions(unittest.TestCase): @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX)) - def test_float_to_fp8e4m3(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), ml_dtypes.float8_e4m3fn(x).tobytes()[0]) + def test_float_to_fp8e4m3(self, x): + np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.float8_e4m3fn).view(torch.uint8).item()) def test_float_to_fp8e4m3_extreme_values(self): 
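# Byte encodings asserted in the extreme-value tests below (OFP8 bit-layout
# facts, stated as background; not tinygrad-specific):
#   E4M3: 0x7E=126 is the max finite (+448); 0x7F=127 is NaN. E4M3 has no
#         infinity, so +-inf saturate to the NaN encodings 127 and 255.
#   E5M2: 0x7B=123 is the max finite (+57344); 0x7C=124 / 0xFC=252 encode
#         +inf / -inf IEEE-style, and 0x7E=126 / 0xFE=254 are NaNs.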
np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126) np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126) - np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 126) + np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 127) np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254) np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254) - np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 254) + np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 255) np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127) np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255) @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX)) - def test_float_to_fp8e5m2(self, x): np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), ml_dtypes.float8_e5m2(x).tobytes()[0]) + def test_float_to_fp8e5m2(self, x): + np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.float8_e5m2).view(torch.uint8).item()) def test_float_to_fp8e5m2_extreme_values(self): np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123) np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123) - np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 123) + np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 124) np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251) np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251) - np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 251) + np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 252) np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126) np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254) @given(strat.integers(min_value=0, max_value=255)) - def test_fp8e4m3_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), np.uint8(x).view(ml_dtypes.float8_e4m3fn).item()) + def test_fp8e4m3_to_float(self, x): + np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e4m3fn).float().item()) @given(strat.integers(min_value=0, max_value=255)) - def test_fp8e5m2_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), np.uint8(x).view(ml_dtypes.float8_e5m2).item()) + def test_fp8e5m2_to_float(self, x): + np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2).float().item()) @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported") class TestBFloat16(unittest.TestCase): @@ -252,7 +262,8 @@ class TestFloatDType(TestDType): class TestDoubleDType(TestDType): DTYPE = dtypes.double - @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CI CUDA and PTX") # TODO: why not? + @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \ + isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "conversion not supported on CI CUDA and PTX") # TODO: why not? 
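# A pattern used throughout this diff: instead of reading the PTX=1 env var,
# skips now detect the backend directly via
#   isinstance(Device[Device.DEFAULT].renderer, PTXRenderer)
# so they fire whenever the PTX renderer is active, however it was selected.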
def test_float64_increased_precision(self): for func in [ lambda t: t.exp(), @@ -276,33 +287,34 @@ class TestDoubleDType(TestDType): class TestInt8DType(TestDType): DTYPE = dtypes.int8 - @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") + @unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently") def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) def test_int8_to_uint16_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4]) - @unittest.skipIf(getenv("PTX"), "broken in ptx") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken in ptx") def test_bitcast_alt(self): a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short) self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152]) class TestUint8DType(TestDType): DTYPE = dtypes.uint8 - @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") + @unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently") def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) -@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL") class TestBitCast(unittest.TestCase): @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats)) def test_shape_change_bitcast(self, dt1, dt2): # NOTE: this has to be assume to prevent hypothesis from skipping all samples - assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet - assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX + assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX data = rand_for_dtype(dt1, 32).reshape(2, 2, 8) - _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist()) + expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2)) + if dt2 in dtypes.fp8s: + expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected) + _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist()) def test_shape_change_bitcast_exceptions(self): with self.assertRaises(RuntimeError): @@ -342,6 +354,11 @@ class TestUint64DType(TestDType): class TestBoolDType(TestDType): DTYPE = dtypes.bool +class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16 + +class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3 +class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2 + class TestPtrDType(unittest.TestCase): def test_vec_double(self): dt1 = dtypes.float.vec(4).ptr().vec(4) @@ -418,9 +435,14 @@ class TestDtypeUsage(unittest.TestCase): class TestOpsBFloat16(unittest.TestCase): def test_cast(self): # TODO: helper_test_op breaks in unrelated part - # TODO: wrong output with GPU=1 / PYTHON=1 on mac + # TODO: wrong output with CL=1 on mac data = [60000.0, 70000.0, 80000.0] np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy()) + def 
test_no_approximation(self): + data = [326.0, 339.0, 10603200512.0] + expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy() + np.testing.assert_allclose(Tensor(data, dtype=dtypes.bfloat16).sqrt().numpy(), expected) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/test_dtype_alu.py b/tinygrad_repo/test/test_dtype_alu.py index a255aec3..19debe43 100644 --- a/tinygrad_repo/test/test_dtype_alu.py +++ b/tinygrad_repo/test/test_dtype_alu.py @@ -1,12 +1,14 @@ import unittest, operator, math from tinygrad import Tensor, dtypes, Device -from tinygrad.dtype import DType +from tinygrad.dtype import DType, truncate from tinygrad.helpers import CI, getenv from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported +from tinygrad.runtime.ops_python import from_storage_scalar +from tinygrad.renderer.ptx import PTXRenderer import numpy as np import pytest -from hypothesis import given, strategies as strat, settings, HealthCheck +from hypothesis import assume, given, strategies as strat, settings, HealthCheck pytestmark = pytest.mark.filterwarnings("ignore") @@ -19,29 +21,23 @@ dtypes_int = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint dtypes_bool = (dtypes.bool,) binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq] -# TODO: LLVM comparing with nan is incorrect -if Device.DEFAULT == "LLVM" or getenv("AMD_LLVM", 0): - binary_operations.remove(operator.lt) - integer_binary_operations = binary_operations + [(Tensor.bitwise_xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and), - (Tensor.bitwise_or, np.bitwise_or), operator.mod] + (Tensor.bitwise_or, np.bitwise_or), (Tensor.maximum, np.maximum), operator.mod] unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin), - (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)] + (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal), (Tensor.cos, np.cos)] # TODO: enable this (this is a dtype issue) #binary_operations.append(operator.truediv) -# TODO: (a+b)/2 in tensor.py's maximum can overflow. 
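# On the TODO above: for fixed-width ints the naive midpoint (a+b)//2 can
# wrap, e.g. int8 values 100 and 120 sum to 220 > 127. A standard
# overflow-safe formulation (illustrative; not necessarily tinygrad's fix):
def safe_midpoint(a: int, b: int) -> int:
    return a + (b - a) // 2  # never forms a+b, so no intermediate overflow
assert safe_midpoint(100, 120) == 110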
This requires a new implementation of maximum that can be backpropagated -#binary_operations += [(Tensor.maximum, np.maximum)] - # TODO: CI CUDA segfaults on sin, WEBGPU sin is not precise enough for large numbers -if (getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}) or Device.DEFAULT == "WEBGPU": unary_operations.remove((Tensor.sin, np.sin)) +if (getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}) or Device.DEFAULT == "WEBGPU": + unary_operations.remove((Tensor.sin, np.sin)) + unary_operations.remove((Tensor.cos, np.cos)) class ht: float64 = strat.floats(width=64, allow_subnormal=False) float32 = strat.floats(width=32, allow_subnormal=False) float16 = strat.floats(width=16, allow_subnormal=False) - bfloat16 = strat.floats(width=16, allow_subnormal=False) uint8 = strat.integers(0, 255) uint16 = strat.integers(0, 65535) uint32 = strat.integers(0, 2**32-1) @@ -51,28 +47,37 @@ class ht: int32 = strat.integers(-2147483648, 2147483647) int64 = strat.integers(-9223372036854775808, 9223372036854775807) bool = strat.booleans() +ht.bfloat16 = ht.uint16 +ht.fp8e4m3 = ht.uint8 +ht.fp8e5m2 = ht.uint8 def universal_test(a, b, dtype, op): - # The 'nan' cases only fail with Vulkan WebGPU backend (CI) - if (math.isnan(a) or math.isnan(b)) and Device.DEFAULT == "WEBGPU" and CI: return if not isinstance(op, tuple): op = (op, op) if op[0] == operator.mod and b == 0: return + # lt and max with nan is undefined in tinygrad + if op[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype) tensor_value = (op[0](ta, tb)).numpy() numpy_value = op[1](ta.numpy(), tb.numpy()) + if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value) if dtype in dtypes.floats: - atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2)}.get(dtype, (1e-10, 1e-7)) + atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7)) np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol) else: np.testing.assert_equal(tensor_value, numpy_value) def universal_test_unary(a, dtype, op): if not isinstance(op, tuple): op = (op, op) ta = Tensor([a], dtype=dtype) + # TODO: cos does not match for large input + if op[0] == Tensor.cos and abs(a) > 30: return + if op[0] == Tensor.log and a <= 0: return out: Tensor = op[0](ta) tensor_value = out.numpy() numpy_value = op[1](ta.numpy()) + if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value) if dtype in dtypes.floats: - atol, rtol = {dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 1e-2)}.get(dtype, (1e-6, 1e-5)) + atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2), + dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5)) np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol) else: np.testing.assert_equal(tensor_value, numpy_value) @@ -85,11 +90,14 @@ def universal_test_cast(a, in_dtype, dtype): def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType): if not isinstance(op1, tuple): op1 = (op1, op1) if not isinstance(op2, tuple): op2 = (op2, op2) + # lt and max with nan is undefined in tinygrad + if op1[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return + if op2[0] in (operator.lt, Tensor.maximum) and math.isnan(c): return at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2) an, bn, cn = np.array([a]).astype(_to_np_dtype(d1)), np.array([b]).astype(_to_np_dtype(d1)), 
np.array([c]).astype(_to_np_dtype(d2)) tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy() numpy_value = op2[1](op1[1](an, bn).astype(_to_np_dtype(d2)), cn) - np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7) + np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) else 1e-7) class TestDTypeALU(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.float64), f"no float64 on {Device.DEFAULT}") @@ -105,7 +113,18 @@ class TestDTypeALU(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}") @given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations)) - def test_bfloat16(self, a, b, op): universal_test(a, b, dtypes.bfloat16, op) + def test_bfloat16(self, a, b, op): + universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(b, dtypes.bfloat16), dtypes.bfloat16, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}") + @given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations)) + def test_fp8e4m3(self, a, b, op): + universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}") + @given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations)) + def test_fp8e5m2(self, a, b, op): + universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op) @given(ht.float32, strat.sampled_from(unary_operations)) def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op) @@ -116,7 +135,19 @@ class TestDTypeALU(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}") @given(ht.bfloat16, strat.sampled_from(unary_operations)) - def test_bfloat16_unary(self, a, op): universal_test_unary(a, dtypes.bfloat16, op) + def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}") + @given(ht.fp8e4m3, strat.sampled_from(unary_operations)) + def test_fp8e4m3_unary(self, a, op): + if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0) + universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op) + + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}") + @given(ht.fp8e5m2, strat.sampled_from(unary_operations)) + def test_fp8e5m2_unary(self, a, op): + if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0) + universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op) @given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations)) def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op) diff --git a/tinygrad_repo/test/test_edgecases.py b/tinygrad_repo/test/test_edgecases.py index a38b38f3..026ec2fb 100644 --- a/tinygrad_repo/test/test_edgecases.py +++ b/tinygrad_repo/test/test_edgecases.py @@ -26,6 +26,10 @@ import unittest import numpy as np import torch from tinygrad import Tensor, dtypes, nn +from tinygrad.device import is_dtype_supported +from tinygrad.helpers import getenv + +MOCKGPU = getenv("MOCKGPU") class TestNaNEdgeCases(unittest.TestCase): # we
don't need more of these. it's unclear if torch's behavior is desired here @@ -167,34 +171,6 @@ class TestZeroFolding(unittest.TestCase): with self.assertRaises(RuntimeError): (x % x).numpy() -class TestArangeUOpValidationIssue(unittest.TestCase): - # these fail with UOp verification error. - # we don't need more of these involving arange - - @unittest.expectedFailure - def test_large_arange_sum(self): - # Summing a huge arange should either succeed or raise a MemoryError. - n = 2**31 + 3 - expected = (n - 1) * n // 2 - out = Tensor.arange(n).sum().item() - self.assertEqual(out, expected) - - @unittest.expectedFailure - def test_large_arange_index(self): - # Indexing a huge arange should return the correct value instead of failing - # with a UOp verification error. - n = 2**31 + 3 - out = Tensor.arange(n)[0].item() - self.assertEqual(out, 0) - - @unittest.expectedFailure - def test_large_arange_permute(self): - # Permuting a huge tensor should not trigger UOp verification failures. - n = 2**31 + 3 - out = Tensor.arange(n).reshape(n, 1).permute(1, 0) - self.assertEqual(out.shape, (1, n)) - out.realize() - class TestAssignIssues(unittest.TestCase): # these are good failures. i'm not sure we need more, but we need to fix these. @@ -230,10 +206,8 @@ class TestUOpValidationIssue(unittest.TestCase): # these fail with UOp verification error. # we want more of these with diverse errors! - @unittest.expectedFailure + @unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU, "hangs gpuocelot") def test_tensor_index_overflow(self): - # Advanced indexing on tensors expanded past int32 should not error, but - # tinygrad fails with a UOp verification error. val = Tensor([1]) big = val.expand(2**31 + 3) idx = Tensor([0, 2**31 + 2]) @@ -273,4 +247,4 @@ class TestEdgeCases(unittest.TestCase): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tinygrad_repo/test/test_image_dtype.py b/tinygrad_repo/test/test_image_dtype.py index 08d2c04c..6adab73e 100644 --- a/tinygrad_repo/test/test_image_dtype.py +++ b/tinygrad_repo/test/test_image_dtype.py @@ -7,7 +7,7 @@ from tinygrad.engine.realize import lower_schedule from tinygrad.helpers import prod, unwrap from test.helpers import REAL_DEV -IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU") +IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL") @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageCopy(unittest.TestCase): diff --git a/tinygrad_repo/test/test_jit.py b/tinygrad_repo/test/test_jit.py index b6d0b0e2..11657e33 100644 --- a/tinygrad_repo/test/test_jit.py +++ b/tinygrad_repo/test/test_jit.py @@ -23,7 +23,7 @@ def _simple_test(add, extract=lambda x: x, N=10): class TestJit(unittest.TestCase): @settings(deadline=2e4) - @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}") + @unittest.skipUnless(REAL_DEV in ["CPU"], f"no support on {REAL_DEV}") @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin])) def test_approx_jit_timeout(self, op): with Context(TRANSCENDENTAL=2): @@ -609,21 +609,22 @@ class TestJitFree(unittest.TestCase): ext_tensor = Tensor([1,24,23,45,1]) @TinyJit def fxn(x:Tensor): - out = (x*2+ext_tensor).reshape(5,1).expand(5, 100).contiguous() - return out.sum() + t1 = (x * 2).contiguous().realize() + t2 = (t1 + ext_tensor).contiguous().realize() + out = (t2.sum()).contiguous().realize() + return out for i in range(5): - out = fxn(Tensor([i,1,2,3,4])) - self.assertEqual(out.item(), 11400+200*i) + out = 
fxn(inp:=Tensor([i,1,2,3,4])) + self.assertEqual(out.item(), 114+2*i) pre_free = GlobalCounters.mem_used fxn.captured.free_intermediates() savings_after_free = pre_free - GlobalCounters.mem_used - # Different allocator implementations have different savings. - expected_savings = 8196 if hasattr(Device[Device.DEFAULT].allocator, '_offset') else 2024 + expected_savings = (len(inp) * inp.dtype.itemsize * 2) + dtypes.float32.itemsize # (t1 and t2) + out self.assertEqual(savings_after_free, expected_savings) out = fxn(Tensor([11,1,2,3,4])) - self.assertEqual(out.item(), 13600) + self.assertEqual(out.item(), 136) # Try one more time... pre_free = GlobalCounters.mem_used @@ -633,7 +634,7 @@ class TestJitFree(unittest.TestCase): self.assertEqual(savings_after_free, expected_savings) out = fxn(Tensor([11,1,2,3,4])) - self.assertEqual(out.item(), 13600) + self.assertEqual(out.item(), 136) def test_updated_not_freed(self): x = Tensor([1]).realize() @@ -791,7 +792,7 @@ class TestJitGraphSplit(unittest.TestCase): hcqgraph=[self.ji_graph(5)]) def test_jit_multidev_xfer(self): - if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)") + if Device.DEFAULT in {"CPU"}: raise unittest.SkipTest("CPU is not a valid default device for this test (zero-copies)") if Device.DEFAULT == "METAL" or REAL_DEV == "METAL": raise unittest.SkipTest("Metal is flaky, with multidevice (same as metal llama 4gpu?)") try: Device[f"{Device.DEFAULT}:1"] @@ -816,7 +817,7 @@ class TestJitGraphSplit(unittest.TestCase): @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies") def test_jit_multidev_copy(self): - if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)") + if Device.DEFAULT in {"CPU"}: raise unittest.SkipTest("CPU is not a valid default device for this test (zero-copies)") @TinyJit def f(inp): diff --git a/tinygrad_repo/test/test_kernel_cache.py b/tinygrad_repo/test/test_kernel_cache.py index 164b501a..a4f0f219 100644 --- a/tinygrad_repo/test/test_kernel_cache.py +++ b/tinygrad_repo/test/test_kernel_cache.py @@ -16,14 +16,14 @@ class TestKernelCache(unittest.TestCase): a1 = Tensor.rand(4,4).realize() b1 = Tensor.rand(4,4).realize() - orig_compile_func = Device['CPU'].compiler - Device['CPU'].compiler = None # making it not callable + orig_compile_func = Device['CPU'].compiler.compile_cached + Device['CPU'].compiler.compile_cached = None # making it not callable try: x1 = a1 + b1 + unique_const x1.realize() # Same kernel should be from cache.
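# If the cache were cold, this realize() would fail: compile_cached was set
# to None above, so any real compile attempt would error. Reaching the
# finally block proves the program was served from the kernel cache.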
finally: - Device['CPU'].compiler = orig_compile_func + Device['CPU'].compiler.compile_cached = orig_compile_func if __name__ == "__main__": unittest.main() diff --git a/tinygrad_repo/test/test_linearizer.py b/tinygrad_repo/test/test_linearizer.py index bdcd2c6c..1726e9aa 100644 --- a/tinygrad_repo/test/test_linearizer.py +++ b/tinygrad_repo/test/test_linearizer.py @@ -2,69 +2,18 @@ import numpy as np import unittest from dataclasses import replace -from tinygrad.codegen.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.gpudims import get_grouped_dims -from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo +from tinygrad.uop.ops import UOp, Ops, GroupOp from tinygrad.device import Device, Buffer, is_dtype_supported from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program -from tinygrad.codegen.opt.heuristic import hand_coded_optimizations -from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX, AMD_LLVM -from tinygrad.dtype import DType, dtypes, AddrSpace +from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT +from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace from tinygrad.codegen import apply_rewrites, rewrites_for_views - -def push_views(ast): return apply_rewrites(ast, rewrites_for_views) - -def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]: - if isinstance(r, Tensor): r = [r] - s = Tensor.schedule(*r) - run_schedule(s[:-1]) # run all kernels except the last one - assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}" - # now all input buffers in s[-1] should be realized - # create fresh buffers for the outputs - bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)] - return push_views(s[-1].ast), bufs - -def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1): - a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in) - np_a, np_b = a.numpy(), b.numpy() - r = a.matmul(b, dtype=dtype_out) - if dtype_in == dtypes.bfloat16: r = r.float() - realized_ast, bufs = helper_realized_ast(r) - k = Kernel(realized_ast) - k.apply_tensor_cores(use_tensor_cores, axis=axis, tc_select=tc_select, tc_opt=tc_opt) - prg = CompiledRunner(replace(get_program(k.get_optimized_ast(), k.opts), device=Device.DEFAULT)) - if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered" - assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" - prg.exec(bufs) - if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 - elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 1e-2 - else: tc_atol, tc_rtol = 5e-3, 1e-4 - c = bufs[0].numpy().reshape((M,N)) - np.testing.assert_allclose(c, np_a @ np_b, atol=tc_atol, rtol=tc_rtol) - -def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, - ensure_triggered:bool=True): - a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in) - r = a.matmul(b, dtype=dtype_out) - sched = r.schedule() - realized_ast = sched[-1].ast - opts_to_apply = [Opt(OptOps.TC, 
axis, (tc_select, tc_opt, 1))] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - - if ensure_triggered: - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA]) - tcs = len([x for x in program.applied_opts if x.op is OptOps.TC]) - assert wmmas > 0, "tensor core not triggered" - assert tcs == 1, "tensor core opt not included" - else: - try: - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - assert False, "OptOps.TC triggered, expected KernelOptError" - except KernelOptError: pass +from tinygrad.renderer.ptx import PTXRenderer class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): @@ -101,9 +50,8 @@ class TestLinearizer(unittest.TestCase): a_t = Tensor.full(st.shape, 2).contiguous().realize() b_t = Tensor.full(st.shape, 3).contiguous().realize() - lin = helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops - + helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()]) + uops = get_program(sink, opts=[]).uops stores = [u for u in uops if u.op is Ops.STORE] mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL] for u in stores])) assert len(mutable_bufs) == len(stores) == 2 @@ -117,45 +65,29 @@ class TestLinearizer(unittest.TestCase): if skip and i in skip: continue assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" - @unittest.skip("broken. should not depends on push_views and implementation details of getitem") - @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow") - def test_indexing_multireduce(self): - dataset = Tensor.rand(16384, 256).realize() - idxs = Tensor([0,3,5,6]).realize() - with Context(FUSE_ARANGE=1): - sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast - real_index = dataset.numpy()[idxs.numpy()].reshape(4, 256, 1, 1) - helper_linearizer_ast(push_views(sink), [dataset, idxs], wanna_output=[real_index]) - def test_two_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() - lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops + ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()]) + uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - # RANGE -> LOAD -> RANGE -> ASSIGN - #assert any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]]) def test_three_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum() - lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops + ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()]) + uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - # RANGE -> RANGE -> LOAD -> RANGE -> ASSIGN - # NOTE: nothing should toposort between the first two ranges - #assert ranges[0]+1 == 
ranges[1] - #assert any(x.op is Ops.LOAD for x in uops[ranges[1]:ranges[2]]) def test_two_nested_range_alt_indexing(self): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum() - lin = helper_linearizer_opt(out, wanna_output=[24])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops + ast = helper_linearizer_opt(out, wanna_output=[24]) + uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] - # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN + # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]]) assert not any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]]) assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in uops[ranges[1]:]) @@ -164,52 +96,20 @@ class TestLinearizer(unittest.TestCase): a = Tensor.randn(4, 1).realize() b = Tensor.randn(1, 1).realize() out = (a + b[0]).sum() + b[0] - lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops + ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()]) + uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] - # LOAD -> RANGE -> LOAD -> ASSIGN + # LOAD -> RANGE -> LOAD -> STORE assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1 def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() b = Tensor.randn(1, 1).realize() out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0] - lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops + ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()]) + uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - #if getenv("PTX"): - # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> ASSIGN - # assert uops[ranges[0]-2].op is Ops.LOAD - # assert ranges[1] == ranges[0]+6 - # assert [x.op for x in uops[ranges[1]-2:ranges[1]]] == [Ops.LOAD, Ops.ALU] - # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> ASSIGN - #else: - # assert uops[ranges[0]-2].op is Ops.LOAD - # assert ranges[1] == ranges[0]+3 - # assert [x.op for x in uops[ranges[1]-2:ranges[1]]] == [Ops.LOAD, Ops.ALU] - - @unittest.skip("fragile crap") - def test_range_outer_op_after_phi(self): - a = Tensor.randn(4, 1).realize() - out = a.sum() * a.sum() - lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops - # RANGE -> LOAD -> ASSIGN -> ALU - end = max(i for i,u in enumerate(uops) if u.op is Ops.ENDRANGE) - # the INDEX can be first - assert uops[end+1].op in GroupOp.ALU or uops[end+2].op in GroupOp.ALU - - @unittest.skip("fragile crap") - def test_range_outer_op_after_phi_nested_range(self): - a = Tensor.randn(2, ).realize() - out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum() - lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] - uops = get_program(lin.get_optimized_ast(), lin.opts).uops - # RANGE -> LOAD -> ASSIGN -> ALU - end = max(i for i,u in enumerate(uops) if u.op is Ops.ENDRANGE) - # the INDEX 
can be first - assert uops[end+1].op in GroupOp.ALU or uops[end+2].op in GroupOp.ALU def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. @@ -256,7 +156,7 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason") def test_upcast_with_locals(self): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() @@ -313,166 +213,34 @@ class TestLinearizer(unittest.TestCase): d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype) helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype) - # TODO: don't skip bf16 for real device (METAL, AMD) - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue - # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered - helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores_codegen(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue - n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2] - a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in) - r = a.matmul(b, dtype=tc.dtype_out) - sched = r.schedule() - realized_ast = push_views(sched[-1].ast) - kernel = Kernel(realized_ast) - kernel.apply_tensor_cores(1, axis=0, tc_select=-1, tc_opt=2) - prg = get_program(kernel.get_optimized_ast(), kernel.opts) - if Device.DEFAULT == "LLVM": - assert "0x201000" in prg.src - elif Device.DEFAULT == "AMD" and AMD_LLVM: - assert "@llvm.amdgcn.wmma" in prg.src - elif Device[Device.DEFAULT].renderer.suffix == "PTX": - assert "mma.sync.aligned" in prg.src - else: - assert "__WMMA_" in prg.src - - @unittest.skipIf((Device.DEFAULT == "AMD") or (Device.DEFAULT == "PYTHON" and getenv("EMULATE_AMD")), "broken for AMD") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores_padded(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue - helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) - - # AMD compiler bug: AMD miscompiles non-zero padded tc kernels with -O3, producing wrong results, nans or hang (see #9606) - # Internal bug: zero-stride dimensions combined with a mask may produce wrong index/valid for pad == 1 on AMD - @unittest.skipUnless((Device.DEFAULT == "AMD") or (Device.DEFAULT == "PYTHON" and getenv("EMULATE_AMD")), "test for AMD's tc") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.expectedFailure - def 
test_tensor_cores_padded_amd(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue - helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores_padded_uops(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - pad = 1 - - # check that TC is triggered for TC_OPT=2 - helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, - tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True) - - # check that TC is not triggered for TC_OPT<2 - helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, - tc.dtype_in, tc.dtype_out, tc_opt=1, ensure_triggered=False) - helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, - tc.dtype_in, tc.dtype_out, tc_opt=0, ensure_triggered=False) - - # check excessive padding doesn't trigger padded TC in TC_OPT=2 - helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) - helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) - if not AMX: # AMX tc.dims[2] == 1 - helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) - - @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores_multi_reduce(self): - for tc in Device[Device.DEFAULT].renderer.tensor_cores: - if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue - # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes - golden_result = None - for axis in range(9): - a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize() - b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize() - c = a.conv2d(b, padding=1, dtype=tc.dtype_out) - realized_ast, real_bufs = helper_realized_ast(c) - - opts_to_apply = [Opt(OptOps.TC, axis, (-1, 2, 1))] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - assert len([uop for uop in program.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered" - assert len([x for x in program.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" - - prg = CompiledRunner(program) - # TODO: support this even if numpy doesn't - if _to_np_dtype(real_bufs[0].dtype) is None: continue - real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled - prg.exec(real_bufs) - result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) - - # ensure the results for each choice of axis matches - if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) - np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.2) - - # check that get_kernel_actions produces all 9 options - from tinygrad.codegen.opt.search import get_kernel_actions - tc_actions = [k for i, k in 
get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC] - - available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out]) - assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}" - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_tensor_cores_unroll_phi(self): - tc = Device[Device.DEFAULT].renderer.tensor_cores[0] - x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) - r = x.matmul(y, dtype=tc.dtype_out) - k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] - for u in get_program(k.get_optimized_ast(), k.opts).uops: - if u.op is Ops.WMMA: - assert u.src[-1].src[0].op != Ops.ASSIGN - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation") - def test_tensor_cores_unroll_casted_phi(self): - tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] - x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) - r = x.matmul(y, dtype=tc.dtype_out) - k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] - for u in get_program(k.get_optimized_ast(), k.opts).uops: - if u.op is Ops.WMMA: - #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.src[-1].src[0].op != Ops.ASSIGN - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation") - def test_tensor_cores_unroll_casted_phi_with_children(self): - # all ASSIGN children are outside the loop - tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] - x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) - r = x.matmul(y, dtype=tc.dtype_out).relu() - k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] - for u in get_program(k.get_optimized_ast(), k.opts).uops: - if u.op is Ops.WMMA: - #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.src[-1].src[0].op != Ops.ASSIGN - @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_simple_unroll_no_between_phi_dependencies(self): x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) r = (x@y).relu() - k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1] - # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE - uops = get_program(k.get_optimized_ast(), k.opts).uops + opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)] + ast = helper_linearizer_opt(r, [opt]) + # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE + uops = get_program(ast, opts=opt).uops + begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1] + end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0] + for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype) for u in uops: - if u.op is 
Ops.ASSIGN: - assert u.src[1].op in GroupOp.ALU - # children of ASSIGN are placed after ENDRANGE - if any(x.op is Ops.ASSIGN for x in u.src): - end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0] + if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG: + if uops.index(u) < begin_range: + assert u.src[1].op is Ops.CONST + else: + assert u.src[1].op in GroupOp.ALU + assert begin_range < uops.index(u) < end_range + # children of STORE are placed after ENDRANGE + if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src): assert end_range < uops.index(u) def test_grouped_dims(self): def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True): idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims) loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])) - loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0]) - sizes = [x.arg[1] for x in loop_idxs] + loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg) + sizes = [x.src[0].arg for x in loop_idxs] assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}" if assert_same_length: assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}" @@ -542,13 +310,13 @@ class TestLinearizer(unittest.TestCase): def test_default_global_reversed(self): # shrink so that the dims do not collapse t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6))) - k = helper_linearizer_opt(t+1)[0] - uops = get_program(k.get_optimized_ast(), k.opts).uops + ast = helper_linearizer_opt(t+1) + uops = get_program(ast, opts=[]).uops idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL]) - idxs = sorted(idxs, key=lambda uop: uop.arg[0]) - assert idxs[0].arg == ('gidx0', 6), idxs[0].arg - assert idxs[1].arg == ('gidx1', 5), idxs[1].arg - assert idxs[2].arg == ('gidx2', 4), idxs[2].arg + idxs = sorted(idxs, key=lambda uop: uop.arg) + assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0] + assert (idxs[1].arg, idxs[1].src[0].arg) == ('gidx1', 5), idxs[1].arg + assert (idxs[2].arg, idxs[2].src[0].arg) == ('gidx2', 4), idxs[2].arg def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() @@ -575,20 +343,19 @@ class TestLinearizer(unittest.TestCase): sched_copy = sched[:] run_schedule(sched) np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) - realized_ast = sched_copy[-1].ast - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple())) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) + program = get_program(sched_copy[-1].ast, opts=()) assert not any(u.op == Ops.WHERE for u in program.uops), "found where where where should be folded" def test_phi_simplification(self): def helper(t, max_ops=0): - k = helper_linearizer_opt(t)[-1] - uops = get_program(k.get_optimized_ast(), k.opts).uops + ast = helper_linearizer_opt(t) + uops = get_program(ast).uops # ignore kernel optimized IF statements for now if if_op:=next((u for u in uops if u.op is Ops.IF), None): uops = uops[:uops.index(if_op)] assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both" - assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified" + reg_stores = [u for u in uops if u.op is Ops.STORE 
and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace == AddrSpace.REG] + assert len(reg_stores) == 0, "STORE to reg should have been simplified" # TODO: once uops track min/max this will be fixed #assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops" @@ -600,7 +367,7 @@ class TestLinearizer(unittest.TestCase): helper(Tensor.arange(255), max_ops=2) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason") def test_grouped_store_phis(self): """ float4 acc0 = float4(0.0,0.0,0.0,0.0); @@ -613,33 +380,20 @@ class TestLinearizer(unittest.TestCase): """ x, y = Tensor.randn(64,64), Tensor.randn(64,64) out = x.matmul(y) - k = helper_linearizer_opt(out)[-1] - uops = get_program(k.get_optimized_ast(), k.opts).uops + with Context(TC=0): + ast = helper_linearizer_opt(out) + uops = get_program(ast).uops # check that the float4 cast collapses store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG] for val in store_vals: assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") - def test_arange_opts(self): - a = Tensor.arange(128) - helper_linearizer_opt(a, [ - [Opt(OptOps.GROUP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501 - ]) - @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_values(self): x = Tensor.randn((4,3,6,6)).realize() out = x.flip((0,1)).contiguous() - k = helper_linearizer_opt(out)[-1] - store_val = [u.src[1] for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] + ast = helper_linearizer_opt(out) + store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0] assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -650,9 +404,9 @@ class TestLinearizer(unittest.TestCase): out = x@y opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces - k = helper_linearizer_opt(out, opts=[opt])[-1] + ast = helper_linearizer_opt(out, opts=[opt]) def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src]) - uops = get_program(k.get_optimized_ast(), k.opts).uops + uops = get_program(ast, opts=opt).uops local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))] global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))] barrier = [u 
for u in uops if u.op is Ops.BARRIER][0] @@ -667,12 +421,12 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") - @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason") def test_grouped_store_local_only(self): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() - k = helper_linearizer_opt(r)[-1] - uops = get_program(k.get_optimized_ast(), k.opts).uops + ast = helper_linearizer_opt(r) + uops = get_program(ast).uops stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG] # the float4 value stores directly in lds and we skip upcast @@ -686,739 +440,96 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_skip_unmatching_upcasts(self): Tensor.manual_seed(0) - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=()),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=()),)),)),)),)) - opt = [ - Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), - Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2) - ] - k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1] - out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] + c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=()) + c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))) + c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=()) + c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))) + c4 = c3.load() + c5 = c1.store(c4) + ast = c5.sink() + opt = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), + Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)] + helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt]) + out = [u for u in get_program(ast, opts=opt).uops if u.op is Ops.STORE][0] assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_skip_unmatching_upcasts_with_gep(self): Tensor.manual_seed(0) - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - 
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=()),)),)),)),)) + c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()) + c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))) + c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=()) + c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))) + c4 = c3.load() + c5 = c1.store(c4) + ast = c5.sink() opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] - k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1] - out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0] + helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt]) + out = [u for u in get_program(ast).uops if u.op is Ops.STORE][0] assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1 -@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") -class TestFloat4(unittest.TestCase): - @staticmethod - def count_float4(uops: list[UOp], n=4): - return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)])) - @staticmethod - def count_half4(uops: list[UOp]): - return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)])) +# *** helpers *** - def test_float4_basic(self): - a = Tensor.empty(2, 8).realize() - b = Tensor.empty(2, 8).realize() - c = a + b +def push_views(ast): return apply_rewrites(ast, rewrites_for_views) - s = c.schedule()[0] - realized_ast = s.ast - opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - - assert TestFloat4.count_float4(program.uops) == (2, 1) - - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") - def test_float4_multidim(self): - a = Tensor.empty(2, 8).realize() - b = Tensor.empty(2, 8).realize() - c = a + b - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops - assert TestFloat4.count_float4(uops) == (4, 2) - - @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16") - def test_float4_multidim_amx(self): - def kernel_for_shape(size, shift): - a = Tensor.empty(2, size).realize() - b = Tensor.empty(2, size).realize() - c = a + b - - s = c.schedule()[0] - return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops - - sizes = [12, 8, 16] - shifts = [3, 2, 4] - 
expected_upcast_size = [4, 8, 16] - expected_output = [(6,3), (2,1), (2,1)] - - for i in range(len(sizes)): - assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] - - def test_float4_unaligned_load(self): - a = Tensor.empty(9).realize().shrink(((1, 9),)) - b = Tensor.empty(9).realize().shrink(((1, 9),)) - c = a + b - - s = c.schedule()[0] - realized_ast = s.ast - opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - - assert TestFloat4.count_float4(program.uops) == (0, 1) - - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") - def test_float4_multidim_unaligned_load(self): - a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) - b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) - c = a + b - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops - - assert TestFloat4.count_float4(uops) == (0, 2) - - @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16") - def test_float4_multidim_unaligned_load_amx(self): - def kernel_for_shape(size, shift): - a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) - b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) - c = a + b - - s = c.schedule()[0] - return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops - - sizes = [13, 9, 17] - shifts = [3, 2, 4] - expected_upcast_size = [4, 8, 16] - expected_output = [(0,3), (0,1), (0,1)] - - for i in range(len(sizes)): - assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] - - def test_float4_sometimes_unaligned(self): - a = Tensor.empty(1, 1, 8).realize() - b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # only the first and last conv dot products are aligned in a, and b is never aligned, so no - # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (0, 0) - - def test_float4_multidim_sometimes_unaligned(self): - a = Tensor.empty(1, 1, 7).realize() - b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # the first conv dot product is aligned in a. If we upcast the output and reduce - # dimension, then we could do float4 for only that one set of loads, but we currently - # don't. - # UPDATE: now we do this fusion - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops - - assert TestFloat4.count_float4(uops) in {(0,1), (1,1)} - - def test_float4_expand(self): - a = Tensor.empty(9).realize().shrink(((1, 9),)) - b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,)) - c = a + b - - # we will upcast the top axis of sz 4. they should not be coalesced into float4, - # since the top axis is not contiguous. 
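-    # (so count_float4 below is expected to report (0, 1): no vectorized loads, one vectorized store)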
- - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (0, 1) - - def test_float4_heterogeneous(self): - a = Tensor.empty(8).realize() - b = Tensor.empty(9).realize().shrink(((1, 9),)) - c = a + b - - # should float4 b but not a - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (1, 1) - - def test_half4_load_unrolled(self): - # from llama 7B shard 4 gpus - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( - UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.MUL, dtypes.half, arg=None, src=( - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),)) - - # TODO: fix this, expected might change but should be positive - for expected, opts in [ - ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - - count = TestFloat4.count_half4(program.uops) - assert count == expected, f"{count=}, {expected=}" - - @unittest.skip("this doesn't happen anymore") - def test_float4_acc(self): - # from float32 stable diffusion red tinybox - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)), - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( - UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(294912), 
arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),)) - - for expected, opts in [ - (1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]), - (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(4)]) - assert count == expected, f"{count=}, {expected=}" - - @unittest.skip("this doesn't happen anymore") - def test_float2_acc(self): - # from resnet - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)), - UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( - UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),)) - for expected, opts in [ - (16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501 - (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(2)]) - assert count == expected, f"{count=}, {expected=}" - -class TestHandCodedOpts(unittest.TestCase): - def test_masked_upcast(self): - layer_1 = Tensor.cat(*[Tensor.empty(5) for _ in range(4)]) - layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.empty(6, 20)) - - s = layer_2.schedule()[-1] - k = Kernel(push_views(s.ast)) - k.apply_opts(hand_coded_optimizations(k)) - assert len(k.bufs) == 6 # make sure all ops are 
done in one kernel - # masked upcast should upcast masked axis of size 7 - # masked upcast should not upcast large (20) last axis - # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous - assert k.upcasted == 1 and k.full_shape[-1] == 7 - - @unittest.skipIf(Device.DEFAULT in {"METAL", "WEBGPU"}, "METAL/WEBGPU split this kernel since it has 37 buffers") - def test_masked_upcast_wino(self): - monster = Tensor.stack(*[Tensor.stack(*[Tensor.empty(16) for _ in range(6)]) for _ in range(6)]) - - s = monster.schedule()[-1] - k = Kernel(push_views(s.ast)) - k.apply_opts(hand_coded_optimizations(k)) - assert len(k.bufs) == 37 # make sure all ops are done in one kernel - # should upcast the two Tensor.stacks - assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 - - def test_masked_upcast_wino_full(self): - with Context(WINO=1): - x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() - out = Tensor.conv2d(x,w, padding=1) - out.mean().backward() - - upcasts = [] - wino_schedule = out.schedule() - # collect upcasts of tile transform kernels - for i, si in enumerate(wino_schedule): - k = Kernel(push_views(si.ast)) - k.apply_opts(hand_coded_optimizations(k)) - if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel) - if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end) - upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len])) - assert len(upcasts) == 3 # 3 transformation matrices - assert len(wino_schedule) <= 4 # 4 kernels - # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess - assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1 - - backward_schedule = Tensor.schedule(x.grad, w.grad) - for si in backward_schedule: - k = Kernel(push_views(si.ast)) - k.apply_opts(hand_coded_optimizations(k)) - if len(k.bufs) < 20: continue # not a tile transform kernel - # heuristic number to make sure that at least some upcasts but not too many upcasts are being done - assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216 - assert len(backward_schedule) <= 13 # just the current number, but it could be better - - def test_masked_upcast_many(self): - layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4)) - layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4)) - layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4)) - - k = helper_linearizer_opt(layer_3)[-1] - assert len(k.bufs) == 5 # make sure all ops are done in one kernel - # check that we don't do too many upcasts - assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49 - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - def test_matvec(self): - N = 128 - a = Tensor.rand(1, N).realize() - b = Tensor.rand(N, N).realize() - c = a @ b - - k = helper_linearizer_opt(c)[-1] - - assert k.group_for_reduces == 1 - assert k.axis_types.count(AxisType.LOCAL) == 1 - assert k.upcasted == 1 +def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]: + if isinstance(r, Tensor): r = [r] + s = Tensor.schedule(*r) + run_schedule(s[:-1]) # run all kernels except the last one + assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}" + # now all input buffers in s[-1] should be realized + # create fresh buffers for the 
outputs
+  bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
+  return push_views(s[-1].ast), bufs

 def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
   assert isinstance(ast, UOp), "ast must be UOp"
   inbufs = [x.uop.base.buffer for x in inputs]
-  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() \
-    for out in ast.src]
-  return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
+  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
+  _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)

 def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
   realized_ast, real_bufs = helper_realized_ast(r)
-  return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
+  _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
+  return realized_ast

-def copyout_outputs(lin:Kernel, outbufs:list[Buffer]) -> list[np.ndarray]:
-  ret = []
-  for i,x in enumerate(outbufs):
-    shape: tuple[int, ...] = lin.ast.src[i].st_arg.shape
-    ret.append(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)).reshape(shape))
-  return ret
+def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
+  return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]

 def reset_bufs(bufs:list[Buffer]):
   for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled

 def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
-                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
-  lins: list[Kernel] = []
-  outbufs = [real_bufs[x.src[0].base.arg] for x in realized_ast.src]
+                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
+  outbufs = real_bufs[:len(realized_ast.src)]
   device = real_bufs[0].device
+  wanna_output = [np.array(x).flatten() for x in wanna_output]

-  def get_prg(k:Kernel): return CompiledRunner(replace(get_program(k.get_optimized_ast(), k.opts), device=device))
+  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))

-  def check_opt(opts, create_k, expected_color_size):
-    k = create_k()
-    lins.append(k)
-    if apply_tc:
-      assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
-    else:
-      k.apply_opts(opts)
-    if expected_color_size is not None:
-      cs = list(zip(k.colors(), k.full_shape))
-      assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
-    prg = get_prg(k)
+  def check_opt(opts):
+    prg = get_prg(opts=opts)
     reset_bufs(outbufs)
     prg.exec(real_bufs)
-
-    for x,want in zip(copyout_outputs(k, outbufs), wanna_output): np.testing.assert_allclose(x, want, atol=atol, rtol=rtol)
+    for x,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(x, want, atol=atol, rtol=rtol)

   # Get the baseline output (computed with no optimizations) if one was not provided.
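   # Rough flow of the checks below: the unoptimized run (opts=()) establishes or
   # verifies the baseline, the hand-coded-heuristics run (opts=None) is compared
   # against it, and then each caller-supplied opt list is compiled and compared,
   # with a TC opt prepended when apply_tc is set. The optimized runs first zero
   # the output buffers via reset_bufs so stale values cannot mask a wrong kernel.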
-  k = Kernel(realized_ast)
-  lins.append(k)
-  prg = get_prg(k)
+  prg = get_prg(opts=())
   prg.exec(real_bufs)
-  if len(wanna_output) == 0: wanna_output = copyout_outputs(k, outbufs)
+  if len(wanna_output) == 0: wanna_output = copyout_outputs(outbufs)
   else:
-    for buf,want in zip(copyout_outputs(k, outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
+    for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)

   # Check correctness of handcoded optimizations.
-  k = Kernel(realized_ast)
-  k.apply_opts(hand_coded_optimizations(k))
-  lins.append(k)
-  prg = get_prg(k)
+  prg = get_prg(opts=None)
   reset_bufs(outbufs)
   prg.exec(real_bufs)
-  for buf,want in zip(copyout_outputs(k, outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
-  for i,x in enumerate(opts): # Check custom transformations if any.
-    check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
-  return lins
-
-class TestKernelOpts(unittest.TestCase):
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
-  def test_local_and_grouped_reduce(self):
-    N = 128
-    Tensor.manual_seed(1882)
-    a = Tensor.rand(4, 4, N, N)
-    b = Tensor.rand(4, 4, N)
-    r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.LOCAL, 0, 2)],
-      [Opt(OptOps.LOCAL, 0, 8)],
-      [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals
-      [Opt(OptOps.GROUPTOP, 0, 2)],
-      [Opt(OptOps.GROUPTOP, 0, 32)],
-      [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)],
-      [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)],
-      [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)],
-      # Checking how it works with locals + grouped reduce
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)],
-      # Checking how it works with locals + grouped reduce + upcasts
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)],
-      # many local + many group
-      [Opt(OptOps.GROUP, 0, 2)] * 4,
-      [Opt(OptOps.LOCAL, 0, 2)] * 4,
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4,
-    ])
-
-  def test_upcasts(self):
-    N = 16
-    Tensor.manual_seed(1772)
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    r = (a+b).sqrt() * ((a+1).exp())
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.UPCAST, 0, 2)],
-      [Opt(OptOps.UPCAST, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts
-    ])
-
-  def test_full_upcast(self):
-    Tensor.manual_seed(1772)
-    a = Tensor.rand(4)
-    b = Tensor.rand(4)
-    r = (a+b).sqrt() * ((a+1).exp())
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts
-    ])
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
-  def test_matmul(self):
-    N = 128
-    Tensor.manual_seed(1552)
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    r = a@b
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.UPCAST, 0, 2)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts
-      [Opt(OptOps.LOCAL, 0, 2)],
-      [Opt(OptOps.LOCAL, 1, 32)],
-      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
-      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
-      [Opt(OptOps.LOCAL, 0, 16), 
Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals - [Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce - # Checking all together - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), - Opt(OptOps.UPCAST, 1, 2)], - # Full global upcast + local - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], - ]) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - def test_double_reduce(self): - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(8, N, 8, N) - r = a.sum(axis=(1,3)) - helper_linearizer_opt(r, [ - # openCL / GPU=1 is 256 max threads - [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce. - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces. - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)], - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)], - # Checking how it works with 2 grouped_reduces + upcasts + locals. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), - Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. 
-      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
-       Opt(OptOps.UPCAST, 0, 2)], # No globals
-    ])
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  def test_invalid_tensor_core_extra_opts(self):
-    N = 128
-    Tensor.manual_seed(1552)
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    realized_ast, _ = helper_realized_ast(a@b)
-    invalid_opts = [
-      [Opt(OptOps.LOCAL, 2, 2)],
-      [Opt(OptOps.UPCAST, 2, 2)],
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
-    ]
-    for x in invalid_opts:
-      k = Kernel(realized_ast)
-      with self.assertRaises(AssertionError):
-        assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
-                       "test requires tensor cores with accumulation in half") # testing with half suffices.
-  def test_tensor_core_opts(self):
-    N = 128
-    Tensor.manual_seed(1552)
-    a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
-    r = a.matmul(b, dtype=dtypes.half)
-    atol, rtol = 0.25, 0.01
-    helper_linearizer_opt(r, [
-      [],
-      [Opt(OptOps.UPCAST, 0, 4)],
-      [Opt(OptOps.UPCAST, 1, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
-      [Opt(OptOps.UNROLL, 0, 2)], # check unroll
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
-      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
-      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
-    ], apply_tc=True, atol=atol, rtol=rtol)
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
-                       "test requires tensor cores with accumulation in half") # testing with half suffices.
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  def test_tensor_core_opts_locals(self):
-    N = 128
-    Tensor.manual_seed(1552)
-    a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
-    r = a.matmul(b, dtype=dtypes.half)
-    atol, rtol = 0.25, 0.01
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
-      [Opt(OptOps.LOCAL, 0, 4)], # check local
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
-    ], apply_tc=True, atol=atol, rtol=rtol)
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
-  @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
-                       "test requires tensor cores with accumulation in half") # testing with half suffices.
-  # NOTE: the METAL test is broken, likely due to a compiler bug. passes on CI with -O0 and with default opt level locally on M3
-  @unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL")
-  @unittest.skip("feature was removed")
-  def test_tensor_core_opts_group(self):
-    N = 128
-    Tensor.manual_seed(1552)
-    a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
-    r = a.matmul(b, dtype=dtypes.half)
-    atol, rtol = 0.25, 0.01
-    helper_linearizer_opt(r, [
-      [Opt(OptOps.GROUP, 0, 2)],
-      [Opt(OptOps.GROUPTOP, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)],
-      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
-      [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
-      [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)],
-      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)],
-    ], apply_tc=True, atol=atol, rtol=rtol)
-
-  def test_padto_matmul(self):
-    if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]):
-      self.skipTest("super slow on CUDA and AMD because of the big grid dims")
-    N = 17 * 17
-    Tensor.manual_seed(289)
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    helper_linearizer_opt(a@b, [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 1, 32)],
-      [Opt(OptOps.PADTO, 2, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
-      # can optimize further post PADTO
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
-    ])
-
-  def test_padto_upcasted_not_ok(self):
-    N = 4
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    helper_linearizer_opt(a@b, [
-      [Opt(OptOps.UPCAST, 0, 0)],
-      [Opt(OptOps.UPCAST, 1, 0)],
-      [Opt(OptOps.UNROLL, 0, 0)],
-      [Opt(OptOps.PADTO, 0, 8)],
-      [Opt(OptOps.PADTO, 1, 8)],
-      [Opt(OptOps.PADTO, 2, 8)],
-    ])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
-
-  def test_padto_sum_ok(self):
-    N = 18 * 18
-    # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
-    a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100
-    b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
-
-    helper_linearizer_opt(a.sum(0), [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-    helper_linearizer_opt(a.sum(1), [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-
-    # can pad sum reduce axis if there's no unsafe ops prior to sum
-    for axis in (0, 1):
-      helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
-      helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
-      helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
-      helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
-      helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
-      # TODO: why?
-      if Device.DEFAULT != "WEBGPU":
-        helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
-        helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
-
-    # having unsafe ops after sum is fine
-    helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
-    helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
-
-  def test_padto_sum_not_ok(self):
-    N = 18 * 18
-    # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
-    a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
-    # exp is not safe to pad
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
-
-    b = a < 1
-    # lt is not safe to pad
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
-
-  def test_padto_max(self):
-    N = 18 * 18
-    # NOTE: this setup prevents 17 * 17 contiguous merged into one axis
-    a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
-
-    helper_linearizer_opt(a.max(0), [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-    helper_linearizer_opt(a.max(1), [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-
-    # cannot pad max kernel on reduce
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
-    with self.assertRaises(KernelOptError):
-      helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
-
-  def test_padto_where(self):
-    Tensor.manual_seed(0)
-    N = 17 * 17
-    a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0)
-    helper_linearizer_opt(a.max(0), [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-
-  def test_padto_where_multioutput(self):
-    Tensor.manual_seed(0)
-    N = 17 * 17
-    r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1
-    a0 = r.where(1, 0)
-    a1 = r.where(2, 0)
-    helper_linearizer_opt([a0.max(0), a1.max(0)], [
-      [Opt(OptOps.PADTO, 0, 32)],
-      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
-    ])
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
-  def test_color_shapes_with_local(self):
-    N = 32
-    Tensor.manual_seed(1552)
-    a = Tensor.rand(N, N)
-    b = Tensor.rand(N, N)
-    r = a@b
-    opts_shapes = [
-      ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
-      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
-      # check to ensure local_dims are stable for full UNROLL of the first reduce
-      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
-      ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
-      # check behavior for full UNROLL on an existing GROUP
-      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
-      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
-      ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
-      ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
-    ]
-    helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
+      for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
+  for x in opts: # Check custom transformations if any.
+    check_opt(([Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, 1))] if apply_tc else [])+x)

 if __name__ == '__main__':
   unittest.main()
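The PADTO tests removed above encode an invariant worth keeping in mind: zero padding is harmless under a plain sum, but an "unsafe" op applied before the reduce (exp, comparisons) maps the padding to nonzero values, which is why those cases must raise KernelOptError. A small numpy illustration of why (illustrative only, not tinygrad code):

```python
import numpy as np

a = np.random.rand(17).astype(np.float32)
padded = np.pad(a, (0, 15))                    # PADTO-style zero padding up to 32

assert np.isclose(padded.sum(), a.sum())       # sum ignores the zeros: safe to pad
# exp(0) == 1, so padding applied before an exp corrupts the sum: unsafe to pad
assert not np.isclose(np.exp(padded).sum(), np.exp(a).sum())
```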
diff --git a/tinygrad_repo/test/test_linearizer_dumb.py b/tinygrad_repo/test/test_linearizer_dumb.py
index c8ebbbb3..4adca10d 100644
--- a/tinygrad_repo/test/test_linearizer_dumb.py
+++ b/tinygrad_repo/test/test_linearizer_dumb.py
@@ -5,193 +5,110 @@ import unittest
 from tinygrad import Device, dtypes
 from tinygrad.device import is_dtype_supported
-from tinygrad.uop.ops import UOp, Ops
-from tinygrad.helpers import getenv
+from tinygrad.uop.ops import UOp, Ops, AxisType, KernelInfo
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.codegen.opt.search import Opt, OptOps
-from tinygrad.codegen.opt.kernel import Kernel
 from tinygrad.engine.realize import get_program
+from tinygrad.renderer.ptx import PTXRenderer
+
+class TestLinearizerFailure(unittest.TestCase):
+  @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
+  def test_failure_beam_mnist(self):
+    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(4014080), arg=0, src=())
+    c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
+    c2 = UOp.range(UOp.const(dtypes.int, 784), 1, AxisType.GLOBAL)
+    c3 = UOp.range(UOp.const(dtypes.int, 10), 3, AxisType.GLOBAL)
+    c4 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
+    c5 = c4.index(c1, UOp.const(dtypes.bool, True)).load()
+    c6 = UOp.range(UOp.const(dtypes.int, 6000), 1004, AxisType.REDUCE)
+    c7 = UOp.range(UOp.const(dtypes.int, 3750), 2006, AxisType.REDUCE)
+    c8 = UOp.range(UOp.const(dtypes.int, 16), 2007, AxisType.GROUP_REDUCE)
+    c9 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(47040000), arg=2, src=())
+    c10 = c9.index((((c3*UOp.const(dtypes.int, 4704000))+c2)+(c6*UOp.const(dtypes.int, 784))), UOp.const(dtypes.bool, True)).load()
+    c11 = c5.alu(Ops.CMPNE, ((((c3*UOp.const(dtypes.int, 6000))+c6)+((c7*UOp.const(dtypes.int, 16))+c8)).alu(Ops.CMPLT, UOp.const(dtypes.int, 59999)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)).reduce(c7, c8, arg=Ops.ADD)+UOp.const(dtypes.int, -1))).where(UOp.const(dtypes.uchar, 0), c10).reduce(c6, arg=Ops.ADD)
+    c12 = c0.index((((c1*UOp.const(dtypes.int, 7840))+(c2*UOp.const(dtypes.int, 10)))+c3), UOp.const(dtypes.bool, True)).store(c11, c1, c2, c3)
+    ast = c12.sink(arg=KernelInfo(name='test', axis_types=(), dont_use_locals=False, applied_opts=(Opt(op=OptOps.GROUP, axis=1, arg=16),), opts_to_apply=None))
+    _ = get_program(ast, Device["METAL"].renderer)

 class TestLinearizerDumb(unittest.TestCase):
-  @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
-  def test_unmerged_ifs(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.half.ptr(1605632), arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=0, src=()),)),
-        UOp(Ops.MAX, dtypes.half, arg=None, src=(
-          UOp(Ops.MUL, dtypes.half, arg=None, src=(
-            UOp(Ops.CAST, dtypes.half, arg=None, src=(
-              UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
-                UOp(Ops.CAST, dtypes.float, arg=None, src=(
-                  UOp(Ops.MUL, dtypes.half, arg=None, src=(
-                    UOp(Ops.LOAD, dtypes.half, arg=None, src=(
-                      UOp(Ops.VIEW, dtypes.half.ptr(1605632), arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=(
-                        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=1, src=()),)),)),
-                    UOp(Ops.LOAD, dtypes.half, arg=None, src=(
-                      UOp(Ops.VIEW, dtypes.half.ptr(2359296), arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=(
-                        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(2359296), arg=2, src=()),)),)),)),)),)),)),
-            UOp(Ops.CONST, dtypes.half, arg=0.9999950000374996, src=(
-              x16:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-          UOp(Ops.CONST, dtypes.half, arg=0.0, src=(
-            x16,)),)),)),))
-    opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
-    k = Kernel(ast, opts=Device["METAL"].renderer)
-    k.apply_opts(opts)
-    prg = get_program(k.get_optimized_ast(), k.opts)
-    print(prg.src)
-    Device[Device.DEFAULT].compiler.compile_cached(prg.src)
-    gate_count = len([x for x in prg.src.splitlines() if "if" in x])
-    assert gate_count == 1, f"must have only one gate {gate_count} != 1"
-    assert len([u for u in prg.uops if u.op is Ops.IF]) == 1, "must have a single IF"
-
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
   def test_max_simplify_and_cancel(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.int.ptr(1000), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=()),)),
-        UOp(Ops.MUL, dtypes.int, arg=None, src=(
-          UOp(Ops.CAST, dtypes.int, arg=None, src=(
-            UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-              UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.float.ptr(1000), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
-                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1000), arg=1, src=()),)),)),
-                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.float.ptr(1), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=2, src=()),)),)),)),
-              UOp(Ops.CONST, dtypes.bool, arg=True, src=(
-                x14:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-          UOp(Ops.ADD, dtypes.int, arg=None, src=(
-            UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
-              UOp(Ops.WHERE, dtypes.int, arg=None, src=(
-                UOp(Ops.VALID, dtypes.bool, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),
-                UOp(Ops.CONST, dtypes.int, arg=-1, src=(
-                  x21:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-                UOp(Ops.CONST, dtypes.int, arg=0, src=(
-                  x21,)),)),)),
-            UOp(Ops.CONST, dtypes.int, arg=1000, src=(
-              x14,)),)),)),)),))
-    opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
-    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    k.apply_opts(opts)
-    prg = get_program(k.get_optimized_ast(), k.opts)
+    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=())
+    c1 = c0.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
+    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1000), arg=1, src=())
+    c3 = c2.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
+    c4 = c3.load()
+    c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=2, src=())
+    c6 = c5.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))
+    c7 = c6.load()
+    c8 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
+    c9 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=())
+    c10 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
+    c11 = c1.store((c4.alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c8)).cast(dtypes.int)*(c9.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, -1, src=c10), UOp.const(dtypes.int, 0, src=c10)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,)))+UOp.const(dtypes.int, 1000, src=c8))))
+    ast = c11.sink()
+    #opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
+    opts = [Opt(op=OptOps.LOCAL, axis=0, arg=8)]
+    prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
     print(prg.src)
     assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"

-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
-  @unittest.skip("not applicable")
-  def test_expander_new_srcs(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.float.ptr(25), arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=0, src=()),)),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
-          UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-            UOp(Ops.VIEW, dtypes.float.ptr(25), arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=1, src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
-    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    k.apply_opts(opts)
-    prg = get_program(k.get_optimized_ast(), k.opts)
-    print(prg.src)
-    if_uops = [u for u in prg.uops if u.op is Ops.IF]
-    self.assertIn(len(if_uops), {1,2,3})
-    conditions = if_uops[0].src[0].toposort()
-    self.assertLessEqual(len(conditions), 9)
-
   # this was a bug in embedding, someday we should fold this anyway
   @unittest.skipUnless(is_dtype_supported(dtypes.half), f"half dtype not supported on {Device.DEFAULT}")
   def test_llama_embedding(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.half.ptr(4096), arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=()),)),
-        UOp(Ops.CAST, dtypes.half, arg=None, src=(
-          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
-            UOp(Ops.CAST, dtypes.float, arg=None, src=(
-              UOp(Ops.MUL, dtypes.half, arg=None, src=(
-                UOp(Ops.CAST, dtypes.half, arg=None, src=(
-                  UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-                    UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-                      UOp(Ops.ADD, dtypes.int, arg=None, src=(
-                        UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
-                          UOp(Ops.WHERE, dtypes.int, arg=None, src=(
-                            UOp(Ops.VALID, dtypes.bool, arg=None, src=(
-                              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),
-                            UOp(Ops.CONST, dtypes.int, arg=1, src=(
-                              x16:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 32000), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-                            UOp(Ops.CONST, dtypes.int, arg=0, src=(
-                              x16,)),)),)),
-                        UOp(Ops.CONST, dtypes.int, arg=-1, src=(
-                          x19:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-                      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-                        UOp(Ops.VIEW, dtypes.int.ptr(1), arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                          UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), arg=1, src=()),)),)),)),
-                    UOp(Ops.CONST, dtypes.bool, arg=True, src=(
-                      x19,)),)),)),
-                UOp(Ops.LOAD, dtypes.half, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.half.ptr(131072000), arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=(
-                    UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(131072000), arg=2, src=()),)),)),)),)),)),)),)),))
-    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    prg = get_program(k.get_optimized_ast(), k.opts)
+    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=())
+    c1 = c0.view(ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)))
+    c2 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=())
+    c3 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 32000), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=())
+    c4 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=())
+    c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), arg=1, src=())
+    c6 = c5.view(ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)))
+    c7 = c6.load()
+    c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(131072000), arg=2, src=())
+    c9 = c8.view(ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)))
+    c10 = c9.load()
+    c11 = c1.store(((c2.f(Ops.VALID, dtype=dtypes.bool).where(UOp.const(dtypes.int, 1, src=c3), UOp.const(dtypes.int, 0, src=c3)).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (2,)))+UOp.const(dtypes.int, -1, src=c4)).alu(Ops.CMPNE, c7).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c4)).cast(dtypes.half)*c10).cast(dtypes.float).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (1,))).cast(dtypes.half))
+    ast = c11.sink()
+    prg = get_program(ast, Device[Device.DEFAULT].renderer)
     print(prg.src)

   @unittest.expectedFailure
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
   def test_unrolled_float4_align(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.float.ptr(1), arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0, src=()),)),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
-          UOp(Ops.WHERE, dtypes.float, arg=None, src=(
-            UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-              UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
-                UOp(Ops.LOAD, dtypes.long, arg=None, src=(
-                  UOp(Ops.VIEW, dtypes.long.ptr(18), arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=(
-                    UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(18), arg=1, src=()),)),)),
-                UOp(Ops.CONST, dtypes.long, arg=-1, src=(
-                  x11:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-              UOp(Ops.CONST, dtypes.bool, arg=True, src=(
-                x11,)),)),
-            UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
-              x11,)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.VIEW, dtypes.float.ptr(18), arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=(
-                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18), arg=2, src=()),)),)),)),)),))
+    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0, src=())
+    c1 = c0.view(ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)))
+    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(18), arg=1, src=())
+    c3 = c2.view(ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)))
+    c4 = c3.load()
+    c5 = UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=())
+    c6 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18), arg=2, src=())
+    c7 = c6.view(ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)))
+    c8 = c7.load()
+    c9 = c1.store(c4.alu(Ops.CMPNE, UOp.const(dtypes.long, -1, src=c5)).alu(Ops.CMPNE, UOp.const(dtypes.bool, True, src=c5)).where(UOp.const(dtypes.float, 0.0, src=c5), c8).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (0, 1))))
+    ast = c9.sink()
     opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
-    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    k.apply_opts(opts)
-    prg = get_program(k.get_optimized_ast(), k.opts)
+    prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
     print(prg.src)
-    load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
+    load_idxs = [x.src[1] for x in prg.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
     assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"

   @unittest.expectedFailure
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
-  @unittest.skipIf(getenv("PTX"), "this is somehow correct in PTX")
+  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "this is somehow correct in PTX")
   def test_upcasted_stores_out_of_order(self):
-    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.VIEW, dtypes.float.ptr(9360), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=(
-          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9360), arg=0, src=()),)),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.VIEW, dtypes.float.ptr(144), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(144), arg=1, src=()),)),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.VIEW, dtypes.float.ptr(1040), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1040), arg=2, src=()),)),)),)),)),)),))
+    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9360), arg=0, src=())
+    c1 = c0.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))
+    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(144), arg=1, src=())
+    c3 = c2.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)))
+    c4 = c3.load()
+    c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1040), arg=2, src=())
+    c6 = c5.view(ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))
+    c7 = c6.load()
+    c8 = c1.store((c4*c7).f(Ops.REDUCE_AXIS, arg=(Ops.ADD, (6,))))
+    ast = c8.sink()
     opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
-    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    k.apply_opts(opts)
-    prg = get_program(k.get_optimized_ast(), k.opts)
+    prg = get_program(ast, Device[Device.DEFAULT].renderer, opts)
     print(prg.src)
-    store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]
+    store_idxs = [x.src[1] for x in prg.uops if x.op is Ops.STORE]
     for i in range(len(store_idxs) - 1):
       first_bounds = store_idxs[i].vmin+store_idxs[i].vmax
       next_bounds = store_idxs[i+1].vmin+store_idxs[i+1].vmax
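For readers following the rewrite above: the deeply nested `UOp(...)` literals are replaced by named intermediates built with chained helpers (`.view(...)`, `.load()`, `.store(...)`, `.sink()`). A minimal copy-kernel sketch in that style, with signatures inferred from this diff rather than checked against the library:

```python
from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.realize import get_program

st = ShapeTracker(views=(View(shape=(16, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16), arg=0, src=())  # output buffer
c1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16), arg=1, src=())  # input buffer
c2 = c1.view(st).load()              # load the input through a view
ast = c0.view(st).store(c2).sink()   # store into the output and sink to an AST
prg = get_program(ast, Device[Device.DEFAULT].renderer)
print(prg.src)
```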
diff --git a/tinygrad_repo/test/test_linearizer_overflows.py b/tinygrad_repo/test/test_linearizer_overflows.py
deleted file mode 100644
index b45565b2..00000000
--- a/tinygrad_repo/test/test_linearizer_overflows.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# ruff: noqa: E501
-import unittest
-from tinygrad import dtypes
-from tinygrad.codegen.opt.kernel import Kernel
-from tinygrad.codegen.opt.search import Opt, OptOps, bufs_from_lin
-from extra.optimization.helpers import time_linearizer
-
-# stuff needed to unpack a kernel
-from tinygrad.uop.ops import UOp, Ops
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import View
-
-def _test_overflow(ast, opts):
-  lin = Kernel(ast)
-  lin.apply_opts(opts)
-  bufs = bufs_from_lin(lin)
-  print(bufs)
-  time_linearizer(lin, bufs)
-
-# NOTE: if you want these to trigger, set launch bounds on HIP kernels
-@unittest.skip("unneeded without launch bounds")
-class TestLinearizerOverflow(unittest.TestCase):
-  def test_overflow_1(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(51380224), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.MAX, dtypes.float, arg=None, src=(
-          UOp(Ops.ADD, dtypes.float, arg=None, src=(
-            UOp(Ops.MUL, dtypes.float, arg=None, src=(
-              UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                UOp(Ops.ADD, dtypes.float, arg=None, src=(
-                  UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-                    UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9633792), arg=1, src=()),
-                        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
-                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9408), arg=2, src=()),
-                        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-                  x16:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
-                    x17:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3, src=()),
-                  x20:=UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-              UOp(Ops.SQRT, dtypes.float, arg=None, src=(
-                UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                  x23:=UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
-                    x17,)),
-                  UOp(Ops.RECIP, dtypes.float, arg=None, src=(
-                    UOp(Ops.ADD, dtypes.float, arg=None, src=(
-                      x23,
-                      UOp(Ops.CONST, dtypes.float, arg=1e-05, src=(
-                        x17,)),)),)),)),)),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=4, src=()),
-              x20,)),)),
-          x16,)),)),))
-    opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
-    _test_overflow(ast, opts)
-
-  # From BEAM on hlb_cifar.py
-  def test_overflow_2(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(16777216), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18432), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)]
-    _test_overflow(ast, opts)
-
-  # from BEAM on default simple_conv.py (which is quite large):
-  def test_overflow_3(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(147456), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)]
-    _test_overflow(ast, opts)
-
-  # from BEAM on BS=4 simple_conv.py:
-  def test_overflow_4(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8388608), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(8388608), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(147456), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)]
-    _test_overflow(ast, opts)
-
-  # from BEAM on BS=2 simple_conv.py:
-  def test_overflow_5(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(4194304), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(4194304), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(147456), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)]
-    _test_overflow(ast, opts)
-
-  # from BEAM on BS=3 simple_conv.py:
-  def test_overflow_6(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6291456), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6291456), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(147456), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)]
-    _test_overflow(ast, opts)
-
-  # from BEAM on BS=3 simple_conv.py: (alt)
-  def test_overflow_7(self):
-    ast = UOp(Ops.SINK, None, arg=None, src=(
-      UOp(Ops.STORE, None, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6291456), arg=0, src=()),
-        UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=(
-          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(6291456), arg=1, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
-            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
-              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(147456), arg=2, src=()),
-              UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
-    opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)]
-    _test_overflow(ast, opts)
-
-if __name__ == '__main__':
-  unittest.main()
diff --git a/tinygrad_repo/test/test_memory_planner.py b/tinygrad_repo/test/test_memory_planner.py
index 0ba3f904..6aa03c6b 100644
--- a/tinygrad_repo/test/test_memory_planner.py
+++ b/tinygrad_repo/test/test_memory_planner.py
@@ -120,5 +120,19 @@ class TestMemoryPlanner(unittest.TestCase):
     ]
     check_assign(bs)

+  def test_very_small_buffers(self):
+    bs = [
+      [b(0, pin=True), b(1, size=32)],
+      [b(3, size=4), b(4, size=6)],
+    ]
+    check_assign(bs)
+
+  def test_very_big_buffers(self):
+    bs = [
+      [b(0, pin=True), b(1, size=34359738368000)],
+      [b(3, size=1 << 128), b(4, size=1 << 64)],
+    ]
+    check_assign(bs)
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad_repo/test/test_method_cache.py b/tinygrad_repo/test/test_method_cache.py
index 497b4069..ce413e77 100644
--- a/tinygrad_repo/test/test_method_cache.py
+++ b/tinygrad_repo/test/test_method_cache.py
@@ -5,9 +5,9 @@ from tinygrad.nn.state import get_state_dict

 class TestMethodCache(unittest.TestCase):
   def setUp(self):
-    self.backup_compiler = Device[Device.DEFAULT].compiler
+    self.backup_compiler = Device[Device.DEFAULT].compiler.compile_cached
   def tearDown(self):
-    Device[Device.DEFAULT].compiler = self.backup_compiler
+    Device[Device.DEFAULT].compiler.compile_cached = self.backup_compiler

   def test_simple_methodcache(self):
     a = Tensor([1])
@@ -15,19 +15,19 @@ class TestMethodCache(unittest.TestCase):
     c = Tensor([3])
     d = Tensor([4])
     (a+b).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     (c+d).realize()

   def test_nested_methodcache(self):
     a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
     ((a+b)+(a+b)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     ((c+d)+(c+d)).realize()

   def test_nested_methodcache_swap(self):
     a,b,c,d = Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])
     ((a+b)+(c+d)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     ((c+d)+(a+b)).realize()

   @unittest.skip("incorrect use of transformer")
@@ -38,7 +38,7 @@ class TestMethodCache(unittest.TestCase):
     # NOTE: you have to do this twice due to the k-v cache
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()
-    Device[Device.DEFAULT].compiler = None
+    Device[Device.DEFAULT].compiler.compile_cached = None
     for i in range(3): model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 10).bind(i)).realize()

 if __name__ == '__main__':
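The `compile_cached` swap above is the whole testing trick: warm the method cache, then make compilation impossible and re-run an identical kernel. A condensed sketch of the same pattern, mirroring the tests rather than introducing new API:

```python
from tinygrad import Tensor, Device

(Tensor([1]) + Tensor([2])).realize()                  # populate the method cache
backup = Device[Device.DEFAULT].compiler.compile_cached
Device[Device.DEFAULT].compiler.compile_cached = None  # any fresh compile now crashes
try:
  (Tensor([3]) + Tensor([4])).realize()                # same kernel shape: must hit the cache
finally:
  Device[Device.DEFAULT].compiler.compile_cached = backup
```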
diff --git a/tinygrad_repo/test/test_multitensor.py b/tinygrad_repo/test/test_multitensor.py
index dabcf1ed..86159ba3 100644
--- a/tinygrad_repo/test/test_multitensor.py
+++ b/tinygrad_repo/test/test_multitensor.py
@@ -2,7 +2,7 @@ import unittest, functools, random
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
 from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
-from tinygrad.helpers import CI, getenv, prod, Context
+from tinygrad.helpers import CI, getenv, prod, Context, RANGEIFY
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
 import numpy as np
@@ -178,16 +178,14 @@ class TestMultiTensor(unittest.TestCase):
       run_schedule(sched)
       np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])

-  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
+  @given(strat.sampled_from((devices_2, devices_3)),
         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
-        strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
-  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
-    N = N * len(devices)
-    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
+        strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
+  def test_simple_reduce(self, devices, rop, shard_axis, reduce_axis):
+    N = 4 * len(devices)
+    X = (Tensor.rand(N*N)-1).reshape(N, N).shard_(devices, shard_axis)
     n = X.numpy()
-    X.shard_(devices, shard_axis)
-    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
-         Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
+    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis), Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
     fX = f(X)
     fn = f(n)
     np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@@ -373,7 +371,8 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

   # NOTE: this is failing on LLVM CI, no idea why. Works locally.
-  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "LLVM", "CPU", "AMD"), "slow, and flaky on LLVM/CPU")
+  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
+  @unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
   def test_data_parallel_resnet(self):
     from extra.models.resnet import ResNet18

@@ -409,7 +408,8 @@ class TestMultiTensor(unittest.TestCase):
       # sometimes there is zeros in these grads... why?
       np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

-  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "LLVM", "CPU", "AMD"), "slow, and flaky on LLVM/CPU")
+  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
+  @unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
   def test_data_parallel_resnet_train_step(self):
     from extra.models.resnet import ResNet18
     fake_image = Tensor.rand((2, 3, 224//16, 224//16))
@@ -417,6 +417,7 @@ class TestMultiTensor(unittest.TestCase):
     m = ResNet18()
     self._test_model_train_step(m, fake_image, labels)

+  @unittest.skipIf(RANGEIFY, "TODO: pm_rangeify hangs")
   def test_data_parallel_simple_train_step(self):
     class Model:
       def __init__(self): self.conv1 = nn.Linear(128,128)
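The simplified `test_simple_reduce` above checks one invariant: reducing a sharded tensor must match reducing the unsharded array. A plain-numpy analogy of that invariant (illustrative only, not the multi-device implementation):

```python
import numpy as np

X = np.random.rand(8, 8)
shards = np.split(X, 2, axis=0)  # stand-in for X.shard_(devices, shard_axis=0)

# ADD reduce along the sharded axis: per-shard partial sums combine into the full result
assert np.allclose(sum(s.sum(axis=0) for s in shards), X.sum(axis=0))
# MAX reduce combines the same way, via an elementwise maximum across shards
assert np.allclose(np.maximum(*[s.max(axis=0) for s in shards]), X.max(axis=0))
```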
diff --git a/tinygrad_repo/test/test_nn.py b/tinygrad_repo/test/test_nn.py
index e46b34a7..417dab3f 100755
--- a/tinygrad_repo/test/test_nn.py
+++ b/tinygrad_repo/test/test_nn.py
@@ -2,7 +2,7 @@ import unittest
 import numpy as np
 import torch
-from tinygrad import Tensor, Device, TinyJit
+from tinygrad import Tensor, Device, TinyJit, dtypes
 from tinygrad.uop.ops import Ops
 from tinygrad.helpers import GlobalCounters, CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
@@ -229,7 +229,8 @@ class TestNN(unittest.TestCase):
     torch_z = torch_layer(torch_x)
     torch_z.sum().backward()

-    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
+    # TODO: why are the torch numbers all 0?
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=5e-6)

   def test_layernorm(self):
     N, C, H, W = 20, 5, 10, 10
@@ -332,7 +333,7 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
-    np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
+    np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=3e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)

   def test_rmsnorm(self):
@@ -465,7 +466,7 @@ class TestNN(unittest.TestCase):
     # used to fail bounds check
     with Context(FUSE_ARANGE=1):
       embedding = Embedding(100, 1024)
-      input_ids = Tensor.empty(16, 16)
+      input_ids = Tensor.empty(16, 16, dtype=dtypes.int)
       embedding(input_ids).realize()

   def test_load_state_dict(self):
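On the `dtypes.int` fix above: `Embedding` looks up rows by integer token ids (roughly, by comparing ids against an arange), and the test notes the float version used to fail a bounds check under FUSE_ARANGE. A minimal usage sketch; shapes and sizes here are illustrative, not from the tests:

```python
from tinygrad import Tensor, dtypes
from tinygrad.nn import Embedding

emb = Embedding(100, 64)                          # vocab of 100, embedding size 64
ids = Tensor([[1, 2], [3, 4]], dtype=dtypes.int)  # integer token ids, not floats
out = emb(ids)                                    # -> shape (2, 2, 64)
print(out.shape)
```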
diff --git a/tinygrad_repo/test/test_ops.py b/tinygrad_repo/test/test_ops.py
index 8878af0e..afb03335 100644
--- a/tinygrad_repo/test/test_ops.py
+++ b/tinygrad_repo/test/test_ops.py
@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
 import numpy as np
 from typing import List, Callable
 import torch
-from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, AMD_LLVM
+from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM, RANGEIFY, OSX
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
@@ -17,6 +17,9 @@ if CI:
 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 PRINT_TENSORS = getenv("PRINT_TENSORS", 0)

+def slow_test(test_func):
+  return unittest.skipIf(getenv("SKIP_SLOW_TEST"), "Skipping slow test")(test_func)
+
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
                    forward_only=False, vals=None, low=-2, high=2):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
@@ -231,7 +234,8 @@ class TestOps(unittest.TestCase):
   def test_unfold(self):
     helper_test_op([(8,)], lambda x: x.unfold(0, 2, 1))
     helper_test_op([(8,)], lambda x: x.unfold(0, 2, 2))
-    helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3))
+    # TODO: something is wrong with unfold
+    if not getenv("TINY_BACKEND"): helper_test_op([(8,)], lambda x: x.unfold(0, 7, 3))
     helper_test_op([(3,3,3)], lambda x: x.unfold(2, 2, 8))
     helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8))
     helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2))
@@ -308,6 +312,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1),
                        lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True)

+  @unittest.skipUnless(OSX or Device.DEFAULT=="CPU", "TODO: fails on some devices")
+  def test_sum_twice(self):
+    helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 1)).sum())
+    helper_test_op([(4, 4, 4)], lambda x: x.sum((0, 2)).sum())
+    helper_test_op([(4, 4, 4)], lambda x: x.sum((1, 2)).sum())
+
   # this is more complex and won't fold for a while
   def test_sum_cat_collapse(self):
     helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1),
@@ -1009,9 +1019,11 @@ class TestOps(unittest.TestCase):
   def test_small_cumsum(self):
     helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))

+  @slow_test
   def test_simple_cumsum(self):
     helper_test_op([(512)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
     helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
+  @slow_test
   def test_cumsum(self):
     helper_test_op([()], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
     self.helper_test_exception([()], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), expected=IndexError)
@@ -1029,9 +1041,11 @@ class TestOps(unittest.TestCase):
   def test_small_cumprod(self):
     helper_test_op([(10)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))

+  @slow_test
   def test_simple_cumprod(self):
     helper_test_op([(512)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
     helper_test_op([(1022)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
+  @slow_test
   def test_cumprod(self):
     helper_test_op([()],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
     self.helper_test_exception([()],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1),expected=IndexError)
@@ -1049,9 +1063,11 @@ class TestOps(unittest.TestCase):
   def test_small_cummax(self):
     helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))

+  @slow_test
   def test_simple_cummax(self):
     helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
     helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+  @slow_test
   def test_cummax(self):
     helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
     # TODO: torch allows this?
@@ -1128,12 +1144,12 @@ class TestOps(unittest.TestCase):
                      lambda x: x.argsort(dim, descending), forward_only=True)

   def test_topk(self):
-    helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
-    helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
+    helper_test_op([(8)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
+    helper_test_op([(8)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
     for dim in [0, 1, -1]:
       for largest in [True, False]:
         for sorted_ in [True]: # TODO support False
-          helper_test_op([(6,5,4)],
+          helper_test_op([(5,5,4)],
                          lambda x: x.topk(4, dim, largest, sorted_).values,
                          lambda x: x.topk(4, dim, largest, sorted_)[0], forward_only=True)
           helper_test_op([(5,5,4)],
@@ -1148,53 +1164,55 @@ class TestOps(unittest.TestCase):
     np.testing.assert_equal(indices.numpy(), [2, 4, 6])
     self.helper_test_exception([(4)], lambda x: x.topk(5), expected=(RuntimeError, ValueError))

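The `slow_test` decorator added above gates long-running cases behind an environment variable. A self-contained sketch of the same gate, using `os.getenv` in place of tinygrad's `getenv` (which parses the value as an int):

```python
import os, unittest
os.environ.setdefault("SKIP_SLOW_TEST", "1")  # set before the gate is evaluated

def slow_test(test_func):
  return unittest.skipIf(os.getenv("SKIP_SLOW_TEST"), "Skipping slow test")(test_func)

class Demo(unittest.TestCase):
  @slow_test
  def test_expensive(self):
    self.assertEqual(sum(range(10**6)), 499999500000)

if __name__ == "__main__":
  unittest.main()  # reports the test as skipped while SKIP_SLOW_TEST is set
```

With the variable unset, the decorated tests run normally; this lets CI trade coverage for speed without editing the test list.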
a), lambda a: Tensor.einsum('ij -> ji', a)) + helper_test_op([(10,10)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a)) + helper_test_op([(4,6,8)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a)) + helper_test_op([(4,6,8)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a)) # no -> and empty rhs - helper_test_op([(20,30),(30,40)], lambda a, b: torch.einsum('ij,jk', a, b), lambda a, b: Tensor.einsum('ij,jk', a, b)) + helper_test_op([(4,6),(6,8)], lambda a, b: torch.einsum('ij,jk', a, b), lambda a, b: Tensor.einsum('ij,jk', a, b)) # sum all elements - helper_test_op([(20,30,40)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a)) + helper_test_op([(4,6,8)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a)) # column sum - helper_test_op([(50,50)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a)) + helper_test_op([(5,5)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a)) # row sum - helper_test_op([(15,15)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a)) + helper_test_op([(5,5)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a)) # matrix-vector multiplication - helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b)) + helper_test_op([(3,4), (4,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b)) # matrix-matrix multiplication - helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b)) + helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b)) # matrix-matrix multiplication, different letter order - helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b)) + helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b)) # dot product - helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b])) + helper_test_op([(5),(5)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b])) # hadamard product - helper_test_op([(30,40),(30,40)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: Tensor.einsum('ij,ij->ij', a,b)) + helper_test_op([(5,6),(5,6)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: Tensor.einsum('ij,ij->ij', a,b)) # outer product - helper_test_op([(15,), (15,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b)) + helper_test_op([(5,), (5,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b)) # batch matrix multiplication - helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b])) + helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b])) # batch matrix multiplication, result permuted - helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b])) + helper_test_op([(2,4,5),(2,5,7)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b])) # batch matrix multiplication, result & input permuted - 
helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b])) + helper_test_op([(4,2,5),(2,5,7)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b])) # batch matrix multiplication, result with different letters - helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b])) + helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b])) # tensor contraction - helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b), + helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b), lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5) # tensor contraction, input permuted - helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b), + helper_test_op([(3,8,10,5),(11,5,7,13,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b), lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5) # tensor contraction, result with different letters - helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b), + helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b), lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5) # bilinear transformation helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c])) + @slow_test def test_einsum_ellipsis(self): """The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html""" # test ellipsis @@ -1209,32 +1227,24 @@ class TestOps(unittest.TestCase): # match torch ellipsis handling helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]), lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b])) - # multiple ellipsis in one operand are not allowed. This test shall raise an exception. - with self.assertRaises(RuntimeError): - helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]), - lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b])) - # multiple ellipsis must broadcast together. This test shall raise an exception. 
+  @slow_test
   def test_einsum_ellipsis(self):
     """The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html"""
     # test ellipsis
@@ -1209,32 +1227,24 @@ class TestOps(unittest.TestCase):
     # match torch ellipsis handling
     helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]),
                    lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b]))
-    # multiple ellipsis in one operand are not allowed. This test shall raise an exception.
-    with self.assertRaises(RuntimeError):
-      helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
-                     lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]))
-    # multiple ellipsis must broadcast together. This test shall raise an exception.
-    with self.assertRaises(RuntimeError):
-      helper_test_op([(2, 3, 4, 5), (5, 2, 7)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
-                     lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
+    # multiple ellipsis in one operand are not allowed
+    self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
+                               lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]), expected=(RuntimeError, IndexError))
+    # multiple ellipsis must broadcast together
+    self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
+                               lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)

   def test_einsum_shape_check(self):
-    a = Tensor.zeros(3,8,10,5)
-    b = Tensor.zeros(11,5,13,16,8)
-    with self.assertRaises(AssertionError):
-      Tensor.einsum('pqrs,tuqvr->pstuv',a,b)
+    self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
+                               lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)

   def test_einsum_arity_check1(self):
-    a = Tensor.zeros(10,15)
-    b = Tensor.zeros(15,20)
-    c = Tensor.zeros(20,10)
-    with self.assertRaises(AssertionError):
-      Tensor.einsum('ij,jk->ij', a,b,c)
+    self.helper_test_exception([(10,15), (15,20), (20,10)], lambda a, b, c: torch.einsum('ij,jk->ij', [a, b, c]),
+                               lambda a, b, c: Tensor.einsum('ij,jk->ij', [a, b, c]), expected=(ValueError, RuntimeError))

   def test_einsum_arity_check2(self):
-    a = Tensor.zeros(10,10)
-    with self.assertRaises(AssertionError):
-      Tensor.einsum('ij,jk->ij', a)
+    self.helper_test_exception([(10,10)], lambda a: torch.einsum('ij,jk->ij', a),
+                               lambda a: Tensor.einsum('ij,jk->ij', a), expected=(ValueError, RuntimeError))

   @unittest.skipIf(IMAGE>0, "no 1d dot for images")
   def test_dot_1d(self):
@@ -1246,6 +1256,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
     self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
     self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
+  @slow_test
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
     helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
@@ -1300,12 +1311,13 @@ class TestOps(unittest.TestCase):
                                                                    np.arange(64,128,dtype=np.float32).reshape(8,8)])
   def test_small_gemm_eye(self):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
-  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE
+  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and CPU_LLVM) or IMAGE
                    or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
   def test_gemm_fp16(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
   def test_gemm(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y))
+  @slow_test
   def test_big_gemm(self):
     helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), atol=1e-4)
   @unittest.skipIf(IMAGE>0, "no 0 in shape matmul on images")
@@ -1317,12 +1329,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(0,0), (0,0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
     helper_test_op([(0), (0,8)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
     helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
+  @slow_test
   def test_broadcastdot(self):
     helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
     with self.assertRaises(RuntimeError):
       a = Tensor(3.14)
       b = Tensor.ones(3,3)
       a @ b
+  @slow_test
   def test_multidot(self):
     helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
     helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
@@ -1400,6 +1414,11 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])

+  def test_const_reduce(self):
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).sum(), lambda x: (x.full_like(2)).sum(), forward_only=True)
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).prod(), lambda x: (x.full_like(2)).prod(), forward_only=True)
+    helper_test_op([(3,3)], lambda x: torch.full_like(x, 2).max(), lambda x: (x.full_like(2)).max(), forward_only=True)
+
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_any(self):
     helper_test_op([(3,4,5,6)], lambda x: x.any(), forward_only=True)
@@ -1451,12 +1470,14 @@ class TestOps(unittest.TestCase):
   def test_mean_zero_axis(self):
     helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3)))

+  @slow_test
   def test_var(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var())
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5))
     # TODO: fix this
     # helper_test_op([(10, 2)], lambda x: x.var(correction=50))
+  @slow_test
   def test_var_axis(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var(0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(2))
@@ -1485,10 +1506,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))

+  @slow_test
   def test_std(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std())
     helper_test_op([(15, 25, 35)], lambda x: x.std(correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.std(correction=5))
+  @slow_test
   def test_std_axis(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std(0))
     helper_test_op([(15, 25, 35)], lambda x: x.std(2))
@@ -1516,6 +1539,7 @@ class TestOps(unittest.TestCase):
   def test_std_keepdim(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))
+  @slow_test
   def test_std_mean(self):
     helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x)),
                    lambda x: Tensor.stack(*x.std_mean()))
@@ -1565,6 +1589,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)

+  @slow_test
   def test_logcumsumexp(self):
     helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
@@ -1634,6 +1659,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,1)], lambda x,y: x/y)
     helper_test_op([(45,65), ()], lambda x,y: x/y)

+  @slow_test
   def test_broadcast_partial(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
@@ -2029,12 +2055,14 @@ class TestOps(unittest.TestCase):
                    lambda x,w,b: torch.nn.functional.conv2d(x,w,b),
                    lambda x,w,b: Tensor.conv2d(x,w,b), grad_rtol=1e-5)

+  @slow_test
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_simple_conv3d(self):
     helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
                    lambda x,w: torch.nn.functional.conv3d(x,w),
                    lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5)

+  @slow_test
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_padded_conv3d(self):
     helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
@@ -2056,6 +2084,7 @@ class TestOps(unittest.TestCase):
                    lambda x,w: torch.nn.functional.conv2d(x,w),
                    lambda x,w: Tensor.conv2d(x,w), grad_rtol=1e-5)

+  @slow_test
   def test_nested_conv2d(self):
     helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)],
                    lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2),
@@ -2090,6 +2119,7 @@ class TestOps(unittest.TestCase):
                    lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2),
                    lambda x,w: Tensor.conv_transpose2d(x,w,groups=2), grad_rtol=1e-5)

+  @slow_test
   def test_padded_conv_transpose2d(self):
     for padding in [(1,2), (2,1), 2, 1, 0]:
       helper_test_op([(2,4,9,9), (4,4,3,3)],
@@ -2098,6 +2128,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(2,16,2,2), (32,16,3,3)], lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=(1,1,1)),
                                lambda x,w: Tensor.conv_transpose2d(x,w,padding=(1,1,1)), expected=(RuntimeError, ValueError))

+  @slow_test
   def test_dilated_conv_transpose2d(self):
     for dilation in [(1,2), (2,1), 2, 1]:
       helper_test_op([(2,4,9,9), (4,4,3,3)],
@@ -2110,12 +2141,14 @@ class TestOps(unittest.TestCase):
                    lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride),
                    lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride), atol=1e-5, grad_rtol=1e-5)

+  @slow_test
   def test_output_padded_conv_transpose2d(self):
     for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
       helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
                      lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride),
                      lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride), grad_rtol=1e-5)

+  @slow_test
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_simple_conv_transpose3d(self):
     helper_test_op([(2,4,9,9,9), (4,4,3,3,3)],
@@ -2170,8 +2203,10 @@ class TestOps(unittest.TestCase):
                    lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups),
                    lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5)
   def test_conv2d(self): self._test_conv2d(bs=1, cin=3)
+  @slow_test
   def test_conv2d_bs_4_cin_3(self): self._test_conv2d(bs=4, cin=3, cout=2)
   def test_conv2d_bs_1_cin_1(self): self._test_conv2d(bs=1, cin=1)
+  @slow_test
   def test_conv2d_bs_4_cin_1(self): self._test_conv2d(bs=4, cin=1)

   def test_conv2d_errors(self):
@@ -2185,6 +2220,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(2,16,2,2), (32,16,3,3)], lambda x,w:torch.nn.functional.conv2d(x,w,padding=(1,1,1)),
                                lambda x,w: Tensor.conv2d(x,w,padding=(1,1,1)), expected=(RuntimeError, ValueError))

+  @slow_test
   def test_large_input_conv2d(self):
     bs = 4
     cin = 16
@@ -2242,16 +2278,18 @@ class TestOps(unittest.TestCase):
                    lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups),
                    lambda x,w: Tensor.conv2d(x,w,groups=groups), grad_rtol=1e-5)

+  @slow_test
   def test_strided_conv2d_simple(self):
     bs,H,W = 2,3,1
     helper_test_op([(bs,1,5,1), (1,1,H,W)],
                    lambda x,w: torch.nn.functional.conv2d(x,w,stride=2),
                    lambda x,w: Tensor.conv2d(x,w,stride=2))

-  @unittest.skipIf(Device.DEFAULT != "LLVM", "DEVECTORIZE=0 only for LLVM")
+  @unittest.skipUnless(Device.DEFAULT == "CPU" and CPU_LLVM, "DEVECTORIZE=0 only for LLVM")
   def test_strided_conv2d_simple_vec(self):
     with Context(DEVECTORIZE=0): self.test_strided_conv2d_simple()

+  @slow_test
   def test_strided_conv2d(self):
     bs = 4
     cin = 3
@@ -2337,6 +2375,7 @@ class TestOps(unittest.TestCase):
                      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
                      lambda x: Tensor.max_pool2d(x, kernel_size=ksz))

+  @slow_test
   def test_max_pool2d(self):
     for ksz in [(2,2), (3,3), 2, 3, (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
@@ -2344,41 +2383,45 @@ class TestOps(unittest.TestCase):
                      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
                      lambda x: Tensor.max_pool2d(x, kernel_size=ksz))

+  @slow_test
   def test_max_pool2d_padding(self):
     for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
       for p in [1, (1,0), (0,1)]:
         with self.subTest(kernel_size=ksz, padding=p):
-          helper_test_op([(32,2,11,28)],
+          helper_test_op([(4,2,11,28)],
                          lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=p),
                          lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=p))
-    self.helper_test_exception([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)),
+    self.helper_test_exception([(4,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)),
                                lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), expected=(RuntimeError, ValueError))

+  @slow_test
   def test_max_pool2d_asymmetric_padding(self):
-    shape = (32,2,111,28)
     for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]:
       with self.subTest(padding=p):
-        helper_test_op([shape],
+        helper_test_op([(4,2,111,28)],
                        lambda x: torch.nn.functional.max_pool2d(torch.nn.functional.pad(x, p, value=float("-inf")), kernel_size=(5,5)),
                        lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), padding=p))

+  @slow_test
   def test_max_pool2d_padding_int(self):
     ksz = (2,2)
-    helper_test_op([(32,2,11,28)],
+    helper_test_op([(4,2,11,28)],
                    lambda x: torch.nn.functional.max_pool2d(x.int(), kernel_size=ksz, padding=1),
                    lambda x: Tensor.max_pool2d(x.int(), kernel_size=ksz, padding=1), forward_only=True)

+  @slow_test
   def test_max_pool2d_bigger_stride(self):
     for stride in [(2,3), (3,2), 2, 3]:
       with self.subTest(stride=stride):
-        helper_test_op([(32,2,11,28)],
+        helper_test_op([(4,2,11,28)],
                        lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
                        lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))

+  @slow_test
   def test_max_pool2d_bigger_stride_dilation(self):
     for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]):
       with self.subTest(stride=stride):
-        helper_test_op([(32,2,11,28)],
+        helper_test_op([(4,2,11,28)],
                        lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation),
                        lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation))
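(The @slow_test markers added throughout this file assume a gating decorator defined elsewhere in the test helpers; its definition is not part of this hunk. A plausible sketch, where the SLOW env-var name is an assumption:)

    import unittest
    from tinygrad.helpers import getenv

    # run the expensive cases only when explicitly requested, e.g. SLOW=1
    slow_test = unittest.skipUnless(getenv("SLOW", 0), "slow test, set SLOW=1 to run")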
@@ -2388,6 +2431,7 @@ class TestOps(unittest.TestCase):
                    lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1),
                    lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1))

+  @slow_test
   def test_max_pool2d_smaller_stride(self):
     for stride in [(2,3), (3,2), 2, 3]:
       with self.subTest(stride=stride):
         helper_test_op([(3,2,17,14)],
                        lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride),
                        lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride))

+  @slow_test
   def test_max_pool2d_dilation(self):
     for dilation in [(2, 3), (3, 2), 2, 3]:
       helper_test_op([(3, 2, 17, 14)],
@@ -2449,6 +2494,7 @@ class TestOps(unittest.TestCase):
                    lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1],
                    vals=[[[[[1,2]*3]*6]]], forward_only=True)  # Tensor([1,2,1,2,1,2]).expand(1,1,6,6)

+  @slow_test
   def test_max_unpool2d(self):
     args = {"kernel_size":(5,5), "stride":(6,5)}
     helper_test_op([(8,3,50,50)],
@@ -2479,6 +2525,7 @@ class TestOps(unittest.TestCase):
                    ),
                    forward_only=True)

+  @slow_test
   def test_avg_pool2d(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
@@ -2492,6 +2539,7 @@ class TestOps(unittest.TestCase):
                    lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)),
                    lambda x: Tensor.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), rtol=1e-5)

+  @slow_test
   def test_avg_pool2d_padding(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
@@ -2513,6 +2561,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(2,2), padding=(1,1,1)),
                                lambda x: Tensor.avg_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), expected=(RuntimeError, ValueError))

+  @slow_test
   def test_avg_pool2d_padding_not_counted(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
@@ -2594,12 +2643,14 @@ class TestOps(unittest.TestCase):
   def test_interpolate_nearest_exact(self): self.test_interpolate_nearest("nearest-exact")

+  @slow_test
   def test_interpolate_bilinear(self):
     for in_sz, out_sz in [((12,20),(9,31)), ((12,9),(31,20)), ((9,31),(20,12))]:
       helper_test_op([(2,3)+in_sz],
                      lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear"),
                      lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)

+  @slow_test
   def test_interpolate_bilinear_corners_aligned(self):
     for in_sz, out_sz in [((12,20),(9,31)), ((12,9),(31,20)), ((9,31),(20,12))]:
       helper_test_op([(2,3)+in_sz],
@@ -2618,6 +2669,7 @@ class TestOps(unittest.TestCase):
                      lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear", align_corners=True),
                      lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)
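(Note the mode-name mapping in the interpolate tests above: torch's mode="bilinear"/"trilinear" corresponds to tinygrad's mode="linear", which infers dimensionality from the input rank. A minimal check, shapes illustrative:)

    import numpy as np, torch
    from tinygrad import Tensor

    x = np.random.rand(2, 3, 12, 20).astype(np.float32)
    t = torch.nn.functional.interpolate(torch.tensor(x), size=(9, 31), mode="bilinear")
    u = Tensor(x).interpolate(size=(9, 31), mode="linear")
    np.testing.assert_allclose(u.numpy(), t.numpy(), atol=1e-4)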
+  @slow_test
   def test_cat(self):
     for dim in range(-2, 3):
       helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
@@ -2716,6 +2768,7 @@ class TestOps(unittest.TestCase):
     data = [math.inf, -math.inf, math.nan]
     helper_test_op((), lambda: torch.tensor(data)[torch.tensor([0, 1, 2])], lambda: Tensor(data)[Tensor([0, 1, 2])])

+  @slow_test
   def test_slice_fancy_indexing_no_dim_collapse(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     # no dim collapse from int or dim injection from None
@@ -2725,6 +2778,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p])

+  @slow_test
   def test_slice_fancy_indexing_dim_collapse_int(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     # dim collapse from int
@@ -2734,6 +2788,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,2,2,2,e], lambda x: x[i,2,2,2,p])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,:,3:11:2,d,0:2], lambda x: x[1,:,3:11:2,o,0:2])

+  @slow_test
   def test_slice_fancy_indexing_dim_inject_none(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     # dim injection from None
@@ -2767,6 +2822,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,1,-1],[-1,-2,0]]), torch.tensor([2,1,-1])], lambda x: x[Tensor([[0,1,-1],[-1,-2,0]]), Tensor([2,1,-1])])

+  @slow_test
   def test_slice_fancy_indexing_list_indices(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]])
@@ -2777,6 +2833,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[-2,1,0],e], lambda x: x[i,[2,1,0],k,[-2,1,0],p])

+  @slow_test
   def test_slice_fancy_indexing_tuple_indices(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[(((0,),),)], lambda x: x[(((0,),),)])
@@ -2786,6 +2843,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,((2,),(1,),(0,)),c,(2,1,0)], lambda x: x[i,((2,),(1,),(0,)),k,(2,1,0)])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,(2,1,0),None,c,(2,1,0),e], lambda x: x[1,(2,1,0),None,k,(2,1,0),p])

+  @slow_test
   def test_slice_fancy_indexing_list_with_tensors(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a]], lambda x: x[[i]])
@@ -2892,15 +2950,16 @@ class TestOps(unittest.TestCase):
     with self.assertRaises(TypeError):
       Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")

+  @slow_test
   def test_scatter_reduce(self):
     b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
-    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
+    a = Tensor(b.detach().cpu().numpy().astype(np.int32), requires_grad=False)
     for reduce in ("sum", "prod", "mean", "amin", "amax"):
       for dim in (-1,1,-3):
-        helper_test_op([(4,5,6), (4,5,6)],
+        helper_test_op([(3,4,5), (3,4,5)],
                        lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
                        lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
-        helper_test_op([(4,5,6), (4,5,6)],
+        helper_test_op([(3,4,5), (3,4,5)],
                        lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
                        lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)
@@ -2927,15 +2986,18 @@ class TestOps(unittest.TestCase):
                                lambda x,src: x.half().scatter_reduce(dim=0, index=a, src=src, reduce="sum"), RuntimeError)

+  @slow_test
   def test_scaled_dot_product_attention(self):
     helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
     helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
                    lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m),
                    lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))

+  @slow_test
   def test_scaled_dot_product_attention_mismatch_ls(self):
     helper_test_op([(32,8,4,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)

+  @slow_test
   def test_scaled_dot_product_attention_causal(self):
     helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
                    lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True),
@@ -2946,6 +3008,7 @@ class TestOps(unittest.TestCase):
                                lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m), expected=RuntimeError)

+  @slow_test
   def test_scaled_dot_product_attention_gqa(self):
     helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)],
                    lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True),
@@ -2977,6 +3040,8 @@ class TestOps(unittest.TestCase):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1), pos_weight=torch.tensor(pos_weight)),
                      lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))

+
+  @unittest.skipIf(RANGEIFY > 1, "broken on RANGEIFY > 1, TODO: fix")
   def test_cross_entropy_class_probabilities(self):
     helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
diff --git a/tinygrad_repo/test/test_opt_gemm.py b/tinygrad_repo/test/test_opt_gemm.py
index 1c9072ec..27aa767e 100644
--- a/tinygrad_repo/test/test_opt_gemm.py
+++ b/tinygrad_repo/test/test_opt_gemm.py
@@ -2,7 +2,7 @@ import numpy as np
 import unittest
 from tinygrad import Tensor
 from tinygrad.helpers import get_single_element
-from tinygrad.codegen.opt.kernel import Opt, OptOps
+from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program

 class TestOptGemm(unittest.TestCase):
diff --git a/tinygrad_repo/test/test_opts.py b/tinygrad_repo/test/test_opts.py
new file mode 100644
index 00000000..4a6310ef
--- /dev/null
+++ b/tinygrad_repo/test/test_opts.py
@@ -0,0 +1,22 @@
+import unittest
+from tinygrad import Tensor, Device
+from tinygrad.helpers import RANGEIFY, CPU_LLVM
+from tinygrad.codegen.opt import Opt, OptOps
+from tinygrad.engine.realize import get_program
+
+@unittest.skipIf(RANGEIFY>0, "arg is partial contig in rangeify")
+class TestOpts(unittest.TestCase):
+  def test_opt_upcast(self):
+    opts = (Opt(OptOps.UPCAST, 0, 4),)
+    a = Tensor.empty(16)
+    b = Tensor.empty(16)
+    out = (a+b).contiguous(arg=opts)
+    s = out.schedule()
+    self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
+    if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM:
+      prg = get_program(s[-1].ast)
+      self.assertIn('float4', prg.src)
+
+if __name__ == '__main__':
+  unittest.main()
+
diff --git a/tinygrad_repo/test/test_outerworld_range.py b/tinygrad_repo/test/test_outerworld_range.py
index 36fac3e2..cfc610cd 100644
--- a/tinygrad_repo/test/test_outerworld_range.py
+++ b/tinygrad_repo/test/test_outerworld_range.py
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad import Tensor, nn, Variable, UOp, dtypes
+from tinygrad import Tensor, nn, Variable, UOp

 # outerworld range should support three things
 # 1. full optimizer steps (test_model_bound_range)
@@ -136,7 +136,7 @@ class TestOuterworldRange(unittest.TestCase):
   def test_model_bound_range(self):
     m, opt = get_model_and_opt()
     # TODO: should ranges be unique so you don't have to pass in the -1?
-    rng = UOp.range(dtypes.int, self.STEPS, -1)
+    rng = UOp.range(self.STEPS, -1)
     vib = Variable('i', 0, self.STEPS-1).bind(rng)
     loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
     loss.backward()
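(UOp.range lost its explicit dtype argument in the change above; a range is now built from just the extent and the range id, presumably defaulting to an integer dtype. Before/after, following the usage in the test; STEPS is illustrative:)

    from tinygrad import UOp

    STEPS = 10  # illustrative
    # old API: rng = UOp.range(dtypes.int, STEPS, -1)
    rng = UOp.range(STEPS, -1)  # new API: no dtype argument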
diff --git a/tinygrad_repo/test/test_profiler.py b/tinygrad_repo/test/test_profiler.py
index e4a70f23..6143086c 100644
--- a/tinygrad_repo/test/test_profiler.py
+++ b/tinygrad_repo/test/test_profiler.py
@@ -17,7 +17,7 @@ def helper_collect_profile(*devs):
   cpu_events.clear()

   profile_list = []
-  with Context(PROFILE=1):
+  with Context(VIZ=1):
     yield profile_list
   for dev in devs: dev.synchronize()
   for dev in devs: dev._at_profile_finalize()
@@ -31,7 +31,7 @@ def helper_profile_filter_device(profile, device:str):
   return [x for x in profile if getattr(x, "device", None) == device], dev_events[0]

 # TODO: support in HCQCompiled
-is_cpu_hcq = Device.DEFAULT in {"CPU", "LLVM"}
+is_cpu_hcq = Device.DEFAULT in {"CPU"}
 @unittest.skipUnless((issubclass(type(Device[Device.DEFAULT]), HCQCompiled) and not is_cpu_hcq) or Device.DEFAULT in {"METAL"}, "Dev not supported")
 class TestProfiler(unittest.TestCase):
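(Profile collection now happens under Context(VIZ=1) rather than Context(PROFILE=1). Context temporarily overrides tinygrad ContextVars for the enclosed block; a minimal sketch of the pattern:)

    from tinygrad import Tensor, Context

    with Context(VIZ=1):              # the var is set only inside this block
      Tensor.rand(16).sum().realize() # work done here is traced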
diff --git a/tinygrad_repo/test/test_quantize_onnx.py b/tinygrad_repo/test/test_quantize_onnx.py
index 1555b6be..005d9789 100644
--- a/tinygrad_repo/test/test_quantize_onnx.py
+++ b/tinygrad_repo/test/test_quantize_onnx.py
@@ -3,11 +3,9 @@ import numpy as np
 import unittest
 from dataclasses import replace
 from tinygrad import Tensor, Context, Device, dtypes
-from tinygrad.uop.ops import Ops, UOp  # noqa: F401 # pylint: disable=unused-import
-from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps
+from tinygrad.uop.ops import Ops
+from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program
-from tinygrad.codegen.opt.search import bufs_from_lin
-from tinygrad.shape.shapetracker import ShapeTracker, View  # noqa: F401 # pylint: disable=unused-import

 N = 512

@@ -236,129 +234,5 @@ class TestQuantizeOnnx(unittest.TestCase):
     opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
     sexec(out, opts)

-@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
-class TestDSPCache(unittest.TestCase):
-  def test_cache_speed(self):
-    # string becuase this breaks Python language server for syntax highlight for some reason
-    ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
-  UOp(Ops.STORE, dtypes.void, arg=None, src=(
-    UOp(Ops.VIEW, dtypes.uchar.ptr(25088), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=(
-      UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),)),
-    UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
-      UOp(Ops.XOR, dtypes.int, arg=None, src=(
-        UOp(Ops.MAX, dtypes.int, arg=None, src=(
-          UOp(Ops.XOR, dtypes.int, arg=None, src=(
-            UOp(Ops.MAX, dtypes.int, arg=None, src=(
-              UOp(Ops.CAST, dtypes.int, arg=None, src=(
-                UOp(Ops.ADD, dtypes.float, arg=None, src=(
-                  UOp(Ops.ADD, dtypes.float, arg=None, src=(
-                    UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                      UOp(Ops.ADD, dtypes.float, arg=None, src=(
-                        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4,)), src=(
-                          UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                            UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                              UOp(Ops.CAST, dtypes.float, arg=None, src=(
-                                UOp(Ops.CAST, dtypes.int, arg=None, src=(
-                                  UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
-                                    UOp(Ops.VIEW, dtypes.uchar.ptr(150528), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=(
-                                      UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),)),)),)),)),
-                              UOp(Ops.CONST, dtypes.float, arg=0.012368360534310341, src=(
-                                x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-                            UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                              UOp(Ops.CAST, dtypes.float, arg=None, src=(
-                                UOp(Ops.CAST, dtypes.int, arg=None, src=(
-                                  UOp(Ops.LOAD, dtypes.char, arg=None, src=(
-                                    UOp(Ops.VIEW, dtypes.char.ptr(6144), arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=(
-                                      UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),)),)),)),)),
-                              UOp(Ops.CONST, dtypes.float, arg=0.007441135589033365, src=(
-                                x22,)),)),)),)),
-                      UOp(Ops.MUL, dtypes.float, arg=None, src=(
-                        UOp(Ops.CAST, dtypes.float, arg=None, src=(
-                          UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-                            UOp(Ops.VIEW, dtypes.int.ptr(32), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=(
-                              UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),)),)),)),
-                        UOp(Ops.CONST, dtypes.float, arg=9.203465015161783e-05, src=(
-                          x36:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-                    UOp(Ops.CONST, dtypes.float, arg=33.812857328652136, src=(
-                      x36,)),)),
-                  UOp(Ops.CONST, dtypes.float, arg=0.4999999, src=(
-                    x36,)),)),
-                UOp(Ops.CONST, dtypes.float, arg=136.0, src=(
-                  x36,)),)),)),
-              UOp(Ops.CONST, dtypes.int, arg=0, src=(
-                x36,)),)),
-            x41:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
-              x36,)),)),
-          UOp(Ops.CONST, dtypes.int, arg=-256, src=(
-            x36,)),)),
-        x41,)),)),)),))""")
-    opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    with Context(DEVECTORIZE=0, QUANTIZE=1):
-      prg = get_program(ast, opts=opts)
-
-    new_src = """
-typedef int int32 __attribute__((aligned(128),vector_size(128)));
-typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128)));
-typedef unsigned char unsigned_char8 __attribute__((aligned(8),vector_size(8)));
-typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4)));
-typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128)));
-__attribute__((noinline)) void r_196_32_4_24_8(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value(
-128))) data2, int* restrict __attribute__((align_value(128))) data3) {
-  int32 cast0 = (int32){0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
-  int32 val0 = *((int32*)((data3+0)));
-  for (int ridx0 = 0; ridx0 < 196; ridx0++) {
-    int32 acc0 = cast0;
-    int32 acc1 = cast0;
-    int32 acc2 = cast0;
-    int32 acc3 = cast0;
-    __builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768);
-    __builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+192);
-    __builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+384);
-    __builtin_HEXAGON_Y2_dcfetch(data1+ridx0*768+576);
-    for (int ridx1 = 0; ridx1 < 24; ridx1++) {
-      signed_char128 val1 = *((signed_char128*)((data2+(ridx1<<8))));
-      signed_char128 val2 = *((signed_char128*)((data2+((1+(ridx1<<1))<<7))));
-      int alu0 = ((ridx0*768)+(ridx1<<3));
-
-      unsigned_char8 val3 = *((unsigned_char8*)((data1+alu0)));
-      __builtin_HEXAGON_Y2_dcfetch(((data1+alu0)+16));
-      unsigned_char8 val4 = *((unsigned_char8*)((data1+(alu0+192))));
-      __builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+192))+16));
-      unsigned_char8 val5 = *((unsigned_char8*)((data1+(alu0+384))));
-      __builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+384))+16));
-      unsigned_char8 val6 = *((unsigned_char8*)((data1+(alu0+576))));
-      __builtin_HEXAGON_Y2_dcfetch(((data1+(alu0+576))+16));
-
-      unsigned_char4 alu5 = __builtin_shufflevector(val3, val3, 0, 1, 2, 3);
-      unsigned_char4 alu6 = __builtin_shufflevector(val4, val4, 0, 1, 2, 3);
-      unsigned_char4 alu7 = __builtin_shufflevector(val5, val5, 0, 1, 2, 3);
-      unsigned_char4 alu8 = __builtin_shufflevector(val6, val6, 0, 1, 2, 3);
-      acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val1, (*((unsigned int*)&alu5)));
-      acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val1, (*((unsigned int*)&alu6)));
-      acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val1, (*((unsigned int*)&alu7)));
-      acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val1, (*((unsigned int*)&alu8)));
-
-      unsigned_char4 alu9 = __builtin_shufflevector(val3, val3, 4, 5, 6, 7);
-      unsigned_char4 alu10 = __builtin_shufflevector(val4, val4, 4, 5, 6, 7);
-      unsigned_char4 alu11 = __builtin_shufflevector(val5, val5, 4, 5, 6, 7);
-      unsigned_char4 alu12 = __builtin_shufflevector(val6, val6, 4, 5, 6, 7);
-      acc0 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc0, val2, (*((unsigned int*)&alu9)));
-      acc1 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc1, val2, (*((unsigned int*)&alu10)));
-      acc2 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc2, val2, (*((unsigned int*)&alu11)));
-      acc3 = __builtin_HEXAGON_V6_vrmpybus_acc_128B(acc3, val2, (*((unsigned int*)&alu12)));
-    }
-    unsigned_char128 alu18 = __builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc3+val0)*203)+32767)/65536)+136), (((((acc2+val0)*203)+32767)/65536)+136)), __builtin_HEXAGON_V6_vpackwh_sat_128B((((((acc1+val0)*203)+32767)/65536)+136), (((((acc0+val0)*203)+32767)/65536)+136)));
-    *((unsigned_char128*)((data0+(ridx0<<7)))) = alu18;
-  }
-}
-"""
-    prg = replace(prg, src=new_src+prg.src.split("/* DSP boilerplate */ ")[1])
-    rt = CompiledRunner(prg)
-    #Device.default.compiler.disassemble(rt.lib)
-    ei = ExecItem(rt, bufs_from_lin(Kernel(ast)))
-    tm = ei.run(wait=True)
-    print(f"final time {tm*1e6:.2f} us")
-
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad_repo/test/test_randomness.py b/tinygrad_repo/test/test_randomness.py
index 5c0c5c8c..aeeb2fb3 100644
--- a/tinygrad_repo/test/test_randomness.py
+++ b/tinygrad_repo/test/test_randomness.py
@@ -1,15 +1,17 @@
 import unittest, math
 from functools import partial
-import numpy as np
-import torch
-from tinygrad import nn, dtypes, Tensor, Device, TinyJit
-from tinygrad.helpers import getenv, CI
+from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable
+from tinygrad.helpers import getenv, CI, OSX
 from tinygrad.device import is_dtype_supported
 from tinygrad.engine.realize import lower_schedule, CompiledRunner
-from hypothesis import given, settings, strategies as strat
+from tinygrad.renderer.ptx import PTXRenderer
 from test.helpers import not_support_multi_device

+import numpy as np
+import torch
+from hypothesis import given, settings, strategies as strat
+
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
@@ -98,7 +100,7 @@ class TestRandomness(unittest.TestCase):
     np.testing.assert_allclose(jr, r)

-  @unittest.skipIf(getenv("PTX"), "fails with PTX")
+  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "fails with PTX")
   def test_threefry_doesnt_use_long(self):
     for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
       if isinstance(ei.prg, CompiledRunner):
@@ -323,9 +325,9 @@ class TestRandomness(unittest.TestCase):
         torch_res = torch_res.unsqueeze(0)
       for i in range(torch_res.shape[0]): self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]))
-    _check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True)
-    _check_with_torch(w=[[0.2, 0.8]], num_samples=2000, replacement=True)  # 2D but only 1 row
-    _check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True)
+    _check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=300, replacement=True)
+    _check_with_torch(w=[[0.2, 0.8]], num_samples=300, replacement=True)  # 2D but only 1 row
+    _check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=300, replacement=True)
     # no-replacement isn't supported, unless taking only one sample
     w = [0.1, 0.9]
     self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
@@ -359,5 +361,20 @@ class TestRandomness(unittest.TestCase):
     assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach())
     assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach())

+# TODO: still fails with MAX_KERNEL_BUFFERS
+@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
+class TestSample(unittest.TestCase):
+  def test_sample(self):
+    X = Tensor.rand(10000, 50).realize()
+    BS = 16
+    idxs = np.random.randint(0, X.shape[0], size=(BS))
+    # this uncovered a bug with arg sort order
+    batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())]
+    x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)])
+    print(idxs)
+    ret = x.numpy()
+    base = X.numpy()[idxs]
+    np.testing.assert_equal(ret, base)
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad_repo/test/test_rangeify.py b/tinygrad_repo/test/test_rangeify.py
index 88c38809..3c86f1e3 100644
--- a/tinygrad_repo/test/test_rangeify.py
+++ b/tinygrad_repo/test/test_rangeify.py
@@ -1,6 +1,20 @@
 import unittest
-from tinygrad import Tensor
-from tinygrad.helpers import RANGEIFY
+from tinygrad import Tensor, nn
+from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
+from tinygrad.uop.ops import UOp
+
+@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
+class TestRangeifyAssign(unittest.TestCase):
+  def test_assign_permuted(self):
+    A = Tensor.empty(4, 4, dtype='int')
+    B = Tensor.arange(16).reshape(4,4)
+    ret = A.permute(1,0).assign(B)
+    lst = ret.tolist()
+    lst2 = A.tolist()
+    lst3 = B.tolist()
+    print(lst)
+    print(lst2)
+    print(lst3)

 N = 256

@@ -11,6 +25,26 @@ class TestRangeify(unittest.TestCase):
     ba = A.expand(N, N)
     ((ba+1).sum(axis=1) + (ba+2).sum(axis=0)).realize()

+  def test_partial_contig(self):
+    A = Tensor.empty(64, 64, 64)
+    ret = A.sum(axis=2).contiguous(arg=(1,)).sum(axis=1)
+    ret.realize()
+
+  def test_double_gemm_real(self):
+    def go():
+      with Context(DEBUG=0):
+        Tensor.manual_seed(1337)
+        A,B,C = [Tensor.randn(N, N) for _ in range(3)]
+        Tensor.realize(A, B, C)
+        GlobalCounters.reset()
+        return (A@B@C).realize()
+    rng = go()
+    with Context(RANGEIFY=0, DEBUG=2):
+      ref = go()
+    mse = ((rng-ref)**2).sum().item()
+    print(f"mse: {mse}")
+    self.assertLessEqual(mse, 1e-2)
+
   def test_double_gemm(self):
     A = Tensor.empty(N, N)
     B = Tensor.empty(N, N)
@@ -72,6 +106,16 @@ class TestRangeify(unittest.TestCase):
     w2 = Tensor.empty(12, 8, 3, 3)
     x.conv2d(w1).conv2d(w2).realize()

+  def test_conv_maxpool_contig(self): self.test_conv_maxpool(True)
+  def test_conv_maxpool(self, contig=False):
+    GlobalCounters.reset()
+    x = Tensor.empty(32, 16, 64, 64)
+    l1 = nn.Conv2d(16, 16, 3)
+    for p in nn.state.get_parameters(l1): p.replace(Tensor.empty(p.shape))
+    x = l1(x)
+    if contig: x = x.contiguous()
+    x.max_pool2d().realize()
+
   def test_double_conv2d_half_contig(self):
     x = Tensor.empty(1, 4, 32, 32)
     w1 = Tensor.empty(8, 4, 3, 3)
@@ -96,27 +140,41 @@ class TestRangeify(unittest.TestCase):
     out.realize()

   def test_flash_attention(self):
-    BS = 4
-    HEADS = 2
-    MATDIM = 16
-    EMB = 8
-    q = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    k = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    v = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    q.scaled_dot_product_attention(k, v).realize()
+    BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8

-from tinygrad import dtypes
-from tinygrad.uop.ops import UOp
+
+    # bigger
+    #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
+
+    # llama 8B
+    #BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
+
+    def fa():
+      Tensor.manual_seed(1337)
+      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
+      return q.scaled_dot_product_attention(k, v).realize()
+
+    with Context(DEBUG=4):
+      GlobalCounters.reset()
+      ret = fa()
+    with Context(RANGEIFY=0):
+      with Context(DEBUG=2):
+        GlobalCounters.reset()
+        cmp = fa()
+    with Context(DEBUG=0):
+      mse = ((cmp-ret)**2).sum().item()
+    print(f"mse: {mse}")
+    self.assertLessEqual(mse, 1e-6)
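(test_double_gemm_real and test_flash_attention above share one correctness pattern: run under the current RANGEIFY setting, rerun with RANGEIFY=0 as the reference, and bound the MSE between the two. Stripped to its core; shapes and threshold illustrative:)

    from tinygrad import Tensor
    from tinygrad.helpers import Context

    def run():
      Tensor.manual_seed(1337)
      a, b = Tensor.randn(64, 64).realize(), Tensor.randn(64, 64).realize()
      return (a @ b).realize()

    out = run()
    with Context(RANGEIFY=0):  # reference scheduler path
      ref = run()
    assert ((out - ref) ** 2).sum().item() <= 1e-2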
+@unittest.skip("okay to disable this for now") @unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY") class TestOuterworld(unittest.TestCase): def test_passthrough_range(self): t = Tensor.rand(10, 10).realize() # passthrough ranges - a = UOp.range(dtypes.int, 10, -1) + a = UOp.range(10, -1) sel = t[a] cpy = sel.contiguous(a).realize() @@ -126,7 +184,7 @@ class TestOuterworld(unittest.TestCase): t = Tensor.rand(10, 10).realize() # passthrough ranges - a = UOp.range(dtypes.int, 10, -1) + a = UOp.range(10, -1) sel = t[9-a] cpy = sel.contiguous(a).realize() @@ -138,7 +196,7 @@ class TestOuterworld(unittest.TestCase): x = Tensor.ones(3, 10, 2).contiguous() # vmap across axis 0 - a = UOp.range(dtypes.int, 3, -1) + a = UOp.range(3, -1) out = f(x[a]) out = out.contiguous(a) @@ -146,17 +204,39 @@ class TestOuterworld(unittest.TestCase): out.realize() print(out.numpy()) + @unittest.skip("opts don't work") def test_triple_gemm(self): x = Tensor.rand(1, 16).realize() W = Tensor.rand(3, 16, 16).realize() manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize() - a = UOp.range(dtypes.int, 3, -1) + a = UOp.range(3, -1) x = x.assign(x @ W[a]) out = x.contiguous(a)[-1].contiguous().realize() self.assertTrue((manual==out).all().item()) + def test_setitem_pyrange(self): + with Context(DEBUG=0): + t = Tensor.rand(10).realize() + o = Tensor.empty(10) + GlobalCounters.reset() + for i in range(10): + o[i] = t[i] + o.realize() + self.assertTrue((t==o).all().item()) + + @unittest.skip("TODO: fix this") + def test_setitem(self): + with Context(DEBUG=0): + t = Tensor.rand(10).realize() + o = Tensor.empty(10) + GlobalCounters.reset() + i = UOp.range(10, -1) + o[i] = t[i] + o.contiguous(i).realize() + self.assertTrue((t==o).all().item()) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/test_renderer_failures.py b/tinygrad_repo/test/test_renderer_failures.py index 3dd795ea..8092914b 100644 --- a/tinygrad_repo/test/test_renderer_failures.py +++ b/tinygrad_repo/test/test_renderer_failures.py @@ -25,7 +25,8 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None): initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] inbufs = [cast(UOp,x.uop).base.buffer for x in inputs] src = Device[Device.DEFAULT].renderer.render(uops) - ei = CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size)) + ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", + src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size)) ei.exec(outbufs+inbufs) return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] @@ -45,7 +46,7 @@ class TestRendererFailures(unittest.TestCase): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer") def test_gated_store_with_alu(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) - gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) + gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) @@ -55,8 +56,8 @@ class TestRendererFailures(unittest.TestCase): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python 
renderer") def test_gated_store_with_alu_2d(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) - gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) - gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0) + gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) + gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0) gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1))) sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,)) uops = full_rewrite(sink, Device[Device.DEFAULT].renderer) @@ -100,7 +101,7 @@ class TestPTXFailures(unittest.TestCase): @unittest.skip("INDEX can only have a gate ALU parent, not an IF") def test_gated_store_with_if(self): a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) - gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) + gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0) val = UOp.const(dtypes.int, 1) if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,)) gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val)) diff --git a/tinygrad_repo/test/test_sample.py b/tinygrad_repo/test/test_sample.py deleted file mode 100644 index d5347463..00000000 --- a/tinygrad_repo/test/test_sample.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest -import numpy as np -from tinygrad import Tensor, Variable, Device -from tinygrad.helpers import OSX - -# TODO: still fails with MAX_KERNEL_BUFFERS -@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers") -class TestSample(unittest.TestCase): - def test_sample(self): - X = Tensor.rand(10000, 50).realize() - BS = 16 - idxs = np.random.randint(0, X.shape[0], size=(BS)) - # this uncovered a bug with arg sort order - batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())] - x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)]) - print(idxs) - ret = x.numpy() - base = X.numpy()[idxs] - np.testing.assert_equal(ret, base) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/tinygrad_repo/test/test_schedule.py b/tinygrad_repo/test/test_schedule.py index 94d338b5..2c0d8bd1 100644 --- a/tinygrad_repo/test/test_schedule.py +++ b/tinygrad_repo/test/test_schedule.py @@ -12,9 +12,9 @@ from tinygrad import nn, dtypes, Device, Tensor from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites +from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites from tinygrad.uop.symbolic import symbolic_simple -from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp +from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, RANGEIFY from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule @@ -33,6 +33,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te # test lowering all the ScheduleItems to ExecItems kernel_cnt = len([si for si,ei 
in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink]) if kernel_cnt != allowed: + if RANGEIFY: return sched # allow different kernel count, TODO: fix the asserts print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") if DEBUG >= 3: for i,s in enumerate(sched): @@ -41,6 +42,8 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te raise KernelCountException(f"{kernel_cnt} != {allowed}") return sched +def expect_rangeify_fails(fxn): return (unittest.expectedFailure if RANGEIFY else (lambda f:f))(fxn) + def _realize_weights(m): for p in nn.state.get_parameters(m): p.realize() @@ -111,6 +114,7 @@ class TestSchedule(unittest.TestCase): self.assertListEqual(a.tolist(), [[15]]) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") + @expect_rangeify_fails def test_error_on_device_mismatch(self): a = Tensor.empty(10) b = Tensor.empty(10, device="CPU") @@ -118,11 +122,12 @@ class TestSchedule(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") + @expect_rangeify_fails def test_error_on_device_mismatch_alt(self): a = Tensor.empty(10) b = Tensor.empty((1,), device="CPU").expand(10).contiguous() c = a+b - with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1) + with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2 if RANGEIFY else 1) @unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1") @unittest.skip("CAST_AFTER_EXPAND is not supported") @@ -140,6 +145,7 @@ class TestSchedule(unittest.TestCase): np.testing.assert_equal(xt.numpy(), X.numpy()[1][0]) @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI") + @unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting") def test_add_chain_buffers(self): N = 31 with Context(TRACK_MATCH_STATS=0, DEBUG=0): @@ -198,9 +204,10 @@ class TestSchedule(unittest.TestCase): def test_simplify_padded_const(self): a = Tensor.empty(1022).cummax(axis=0) - sched = check_schedule(a, 5) - ast = sched[0].ast - self.assertLessEqual(len([u for u in ast.toposort() if u.op is Ops.WHERE]), 6) + check_schedule(a, 5) + # TODO: what is this testing? 
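(expect_rangeify_fails, defined just above, applies unittest.expectedFailure only when RANGEIFY is enabled, so one test body serves both schedulers. The one-liner, written out for clarity:)

    import unittest
    from tinygrad.helpers import RANGEIFY

    def expect_rangeify_fails(fxn):
      # under RANGEIFY the test is a known failure; otherwise leave it untouched
      return unittest.expectedFailure(fxn) if RANGEIFY else fxn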
 def _realize_weights(m):
   for p in nn.state.get_parameters(m): p.realize()
@@ -111,6 +114,7 @@ class TestSchedule(unittest.TestCase):
     self.assertListEqual(a.tolist(), [[15]])

   @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
+  @expect_rangeify_fails
   def test_error_on_device_mismatch(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10, device="CPU")
     c = a+b
     with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1)

   @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
+  @expect_rangeify_fails
   def test_error_on_device_mismatch_alt(self):
     a = Tensor.empty(10)
     b = Tensor.empty((1,), device="CPU").expand(10).contiguous()
     c = a+b
-    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1)
+    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2 if RANGEIFY else 1)

   @unittest.skipUnless(is_dtype_supported(dtypes.half) and getenv("CAST_AFTER_EXPAND"), "need half and CAST_AFTER_EXPAND=1")
   @unittest.skip("CAST_AFTER_EXPAND is not supported")
@@ -140,6 +145,7 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])

   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
+  @unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
   def test_add_chain_buffers(self):
     N = 31
     with Context(TRACK_MATCH_STATS=0, DEBUG=0):
@@ -198,9 +204,10 @@ class TestSchedule(unittest.TestCase):
   def test_simplify_padded_const(self):
     a = Tensor.empty(1022).cummax(axis=0)
-    sched = check_schedule(a, 5)
-    ast = sched[0].ast
-    self.assertLessEqual(len([u for u in ast.toposort() if u.op is Ops.WHERE]), 6)
+    check_schedule(a, 5)
+    # TODO: what is this testing?
+    #ast = sched[0].ast
+    #self.assertLessEqual(len([u for u in ast.toposort() if u.op is Ops.WHERE]), 6)

   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
@@ -278,7 +285,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.empty(10,10,10)
     b = Tensor.empty(10,10,1)
     c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
-    with self.assertRaises(KernelCountException): check_schedule(c, 1)
+    check_schedule(c, 2)

   def test_allow_push_permutes(self):
     a = Tensor.randn(10,10,10).realize()
@@ -316,7 +323,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.empty(10)
     c = a+b
     d = a.reshape(10,1)+b.reshape(10,1)
-    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
+    check_schedule(d, 1, [c])

   # failing in new lazy
   def test_cache_binaryop_transpose(self):
@@ -324,7 +331,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.empty(10,10)
     b = Tensor.empty(10,10)
     c = (a.T*b.T).T #.contiguous()
     d = a*b
-    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
+    check_schedule(d, 1, [c])

   def test_cache_two_reduceops(self):
     a = Tensor.empty(10)
@@ -339,7 +346,7 @@ class TestSchedule(unittest.TestCase):
     r1 = (x - r0).sum(axis=0).div(2)
     out = r0 + r1
     schedule = check_schedule(out, 2)
-    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op is Ops.REDUCE_AXIS]
+    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
     assert len(reduceops) == 2

   def test_cache_reduce_multiple_children(self):
@@ -349,9 +356,9 @@ class TestSchedule(unittest.TestCase):
     r1 = (x - r0).sum(axis=0).div(2)
     out0 = r0 + y
     out1 = r1 + y
-    schedule = check_schedule([out0, out1], 4)
-    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op is Ops.REDUCE_AXIS]
-    assert len(reduceops) == 2
+    schedule = check_schedule([out0, out1], 2 if RANGEIFY else 4)
+    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
+    assert len(reduceops) == (3 if RANGEIFY else 2)

   def test_div_collapse_buffer(self):
     a = Tensor.full((4,), 4.0).contiguous().realize()
@@ -394,6 +401,7 @@ class TestSchedule(unittest.TestCase):
     # a and b share the same underlying device memory
     self.assertIs(a.uop.realized, b.uop.realized)

+  @expect_rangeify_fails
   def test_clone_doesnt_dedup(self):
     src = Tensor.ones(4).contiguous().realize()
     a = src.clone()
@@ -417,6 +425,11 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.full((4, 4), 1.).contiguous().realize()
     check_schedule([a+b, a+b], 1)

+  def test_const_realize(self):
+    t = Tensor.ones(2)
+    check_schedule(t[0], 0)
+    check_schedule(t[1], 0)
+
   def test_fold_double_unary(self):
     y = Tensor.empty(2)
     out = y.sum(keepdim=True).sqrt().neg()
@@ -558,7 +571,7 @@ class TestSchedule(unittest.TestCase):
     c = a+b
     d = a.reshape(10,1)+b.reshape(10,1)
     out = c.sum() + d.sum()
-    with self.assertRaises(KernelCountException): check_schedule(out, 1)
+    check_schedule(out, 2)

   def test_children_dont_push(self):
     a = Tensor.empty(10, 10, 1)
@@ -569,6 +582,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(f, 2)

   # failing in new lazy
+  @unittest.skip("always fusing elementwise")
   def test_dont_fuse_binops_with_children(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
@@ -576,8 +590,8 @@ class TestSchedule(unittest.TestCase):
     keep_me = a+b
     e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
     d = keep_me+c
-    with self.assertRaises(KernelCountException): check_schedule(d, 2)
-    with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
+    check_schedule(d, 2)
+    check_schedule(keep_me, 0, [d])

   #@unittest.skip("failing in old lazy")
   def test_permute_breaks_fusion(self):
@@ -627,7 +641,8 @@ class TestSchedule(unittest.TestCase):
     x = x.image_conv2d(w3, b3)

     # NOOP, 3 convs, contiguous
-    with self.assertRaises(KernelCountException): check_schedule(x, 5)
+    #check_schedule(x, 5)
+    check_schedule(x, 8)

   def test_image_conv_fusion_minimal(self):
     b1 = Tensor.empty(16)
@@ -682,6 +697,7 @@ class TestSchedule(unittest.TestCase):
     c = (a.sum(2).contiguous() + b).contiguous()
     check_schedule(c, 2)

+  @expect_rangeify_fails
   def test_kernelize(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     c = a+b
     d = c+2
     check_schedule(d, 2)

+  @expect_rangeify_fails
   def test_kernelize_view(self):
     a = Tensor.empty(4,1)
     b = a*2
     c = b.kernelize()+Tensor.empty(4,4)
     check_schedule(c, 2)

+  @expect_rangeify_fails
   def test_kernelize_diamond(self):
     a = Tensor([0]).realize()
     prev_a = (a+1).contiguous()
@@ -703,6 +721,7 @@ class TestSchedule(unittest.TestCase):
     assert prev_a.uop in a.uop.src, "contiguous usage must run before assign"
     self.assertEqual((prev_a+a*3).item(), 1+2*3)

+  @expect_rangeify_fails
   def test_multioutput_ast(self):
     a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
     b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
@@ -714,6 +733,7 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(b.buffer.numpy(), [12])

   # unlike schedule, kernelize can be called multiple times on a Tensor
+  @expect_rangeify_fails
   def test_double_kerenlize(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     e = c.kernelize()+d.kernelize()
     check_schedule(e, 3)

+  @expect_rangeify_fails
   def test_kernelize_bw(self):
     a = Tensor.full((3,), 2.0, requires_grad=True).contiguous()
     b = Tensor.full((3,), 3.0, requires_grad=True).contiguous()
@@ -732,6 +753,7 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(z.item(), 18.0)
     self.assertEqual(z.grad.item(), 1.0)

+  @expect_rangeify_fails
   def test_kernelize_bw_view(self):
     a = Tensor.full((3,1), 2.0, requires_grad=True).contiguous()
     b = Tensor.full((3,1), 3.0, requires_grad=True).contiguous()
@@ -784,6 +806,13 @@ class TestSchedule(unittest.TestCase):
     out = x + 1
     check_schedule(out, 0, filter_sink=False)

+  def test_zero_size_assign(self):
+    f = Tensor.full((2,), 0.).contiguous().realize()
+    a = f.shrink_to((0,))
+    a.assign(Tensor.ones_like(a))
+    check_schedule(a, 0)
+    self.assertEqual(a.tolist(), [])
+
   def test_reduce_permute_nofuse(self):
     x = Tensor.empty(32, 32, 32)
     y = Tensor.empty(32, 32)
@@ -888,26 +917,24 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2, filter_sink=False)

-  @unittest.expectedFailure
   def test_reduce_same_size(self):
     Tensor.manual_seed(0)
     a = Tensor.randn(4, 4).realize()
     out0 = a.sum() + 2
     out1 = a.sum() + 4
     out2 = out0 * out1
-    run_schedule(check_schedule([out0, out1, out2], 1))
+    run_schedule(check_schedule([out0, out1, out2], 1 if RANGEIFY else 4))
     np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
     np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)

-  @unittest.expectedFailure
   def test_reduce_multiple_paths(self):
     Tensor.manual_seed(0)
     a = Tensor.randn(4, 4).realize()
     out0 = a.sum().exp2()
     # out1 has two paths to a.sum()
     out1 = a.sum() + out0
-    run_schedule(check_schedule([out0, out1], 1))
+    run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3))
     np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
@@ -983,7 +1010,6 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(e.numpy(), e_np:=b.numpy() + out0_np, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(out1.numpy(), r_np + e_np[0][0][0], atol=1e-4, rtol=1e-4)

-  # changed by multireduce
   def test_reduce_expand_child(self):
     Tensor.manual_seed(0)
     a = Tensor.randn((32, 32, 32)).realize()
@@ -995,13 +1021,12 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)

-  @unittest.expectedFailure
   def test_reduce_shrink_child(self):
     a = Tensor.empty(100, 100)
     b = Tensor.empty(10,)
     c = a.sum() + b[0]
     d = a.sum() + 2
-    check_schedule([c, d], 1)
+    check_schedule([c, d], 1 if RANGEIFY else 3)

   def test_reduce_multiple_paths_midshrink(self):
     a = Tensor.empty(4, 4)
@@ -1024,20 +1049,6 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)

-  def test_argmin_multireduce_fusion(self):
-    Tensor.manual_seed(0)
-    x = Tensor.randn(4, 32).realize()
-    out = x.argmin(-1)
-    run_schedule(check_schedule(out, 2))
-    np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
-
-  def test_argmax_multireduce_fusion(self):
-    Tensor.manual_seed(0)
-    x = Tensor.randn(4, 32).realize()
-    out = x.argmax(-1)
-    run_schedule(check_schedule(out, 2))
-    np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
-
   def test_scaled_dot_product_attention_multireduce_fusion(self):
     Tensor.manual_seed(0)
     q = Tensor.randn(32,8,16,8).realize()
@@ -1050,6 +1061,14 @@ class TestSchedule(unittest.TestCase):
       compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
       np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)

+    with Context(FUSE_ATTENTION=1):
+      out = Tensor.scaled_dot_product_attention(q,k,v)
+      run_schedule(check_schedule(out, 1))
+    if getenv("CHECK", 1):
+      import torch
+      compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
+      np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)
+
   def test_ugly_reduceop_pairing(self):
     Tensor.manual_seed(0)
     a = Tensor.randn(4, 32).realize()
@@ -1171,13 +1190,14 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)

   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  @expect_rangeify_fails
   def test_softmax_upcast(self):
     # input half, softmax in float
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
     out = x.softmax(dtype=dtypes.float)
     sched = out.schedule()
-    self.assertEqual(len(sched), 3)
+    self.assertEqual(len(sched), 2 if RANGEIFY else 3)
     self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)

     # input float, softmax in float
@@ -1194,7 +1214,6 @@ class TestSchedule(unittest.TestCase):
     x.softmax().sum().backward()
     run_schedule(check_schedule(x.grad, 4))

-  # changed by: multireduce spec
   def test_layernorm_onelayer_fusion(self):
     Tensor.manual_seed(0)
     layer = nn.LayerNorm([10, 10])
@@ -1308,6 +1327,7 @@ class TestSchedule(unittest.TestCase):
     with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)

   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  @expect_rangeify_fails
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
     # y = Tensor.ones(4).contiguous().realize()
@@ -1425,7 +1445,6 @@ class TestSchedule(unittest.TestCase):
     run_schedule(schedule)
     np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)

-  # changed by: multireduce spec
   # pattern in test_transformer
   def test_partial_fuse1(self):
     Tensor.manual_seed(0)
@@ -1438,7 +1457,6 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)

-  # changed by: multireduce spec
   # pattern in conv
   def test_partial_fuse2(self):
     Tensor.manual_seed(0)
@@ -1451,9 +1469,7 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), b.numpy().sum()-(a.numpy().sum()+2), atol=1e-4, rtol=1e-4)

-  # changed by: multireduce spec
   # pattern in adam
-  @unittest.expectedFailure
   def test_partial_fuse3(self):
     Tensor.manual_seed(0)
     a = Tensor.randn(16, 16).realize()
@@ -1463,14 +1479,12 @@ class TestSchedule(unittest.TestCase):
     e = c * d
     f = b.sum() - e
     # run_schedule(check_schedule([c, d, e, f], 1))
-    run_schedule(check_schedule([c, d, e, f], 2))
+    run_schedule(check_schedule([c, d, e, f], 2 if RANGEIFY else 5))
     np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)

-  # changed by: multireduce spec
-  @unittest.expectedFailure
   def test_partial_fuse4(self):
     Tensor.manual_seed(0)
     a = Tensor.randn(16, 16).realize()
@@ -1480,7 +1494,7 @@ class TestSchedule(unittest.TestCase):
     e = c * d
     f = (b - d).sum() - e
     # run_schedule(check_schedule([c, d, e, f], 1))
-    run_schedule(check_schedule([c, d, e, f], 3))
+    run_schedule(check_schedule([c, d, e, f], 5))
     np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
     np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
@@ -1615,11 +1629,11 @@ class TestSchedule(unittest.TestCase):
     out = x.argmax(1)
     run_schedule(check_schedule(out, 2))

-  def test_conv2d(self): _test_conv2d(7)
-  def test_conv2d_fused(self): _test_conv2d(5, FUSE_CONV_BW=1)
+  def test_conv2d(self): _test_conv2d(4 if RANGEIFY else 7)
+  def test_conv2d_fused(self): _test_conv2d(4 if RANGEIFY else 5, FUSE_CONV_BW=1)
   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
-  def test_conv2d_half(self): _test_conv2d(7, dtype=dtypes.half)
+  def test_conv2d_half(self): _test_conv2d(4 if RANGEIFY else 7, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
   @unittest.expectedFailure
@@ -1646,7 +1660,8 @@ class TestSchedule(unittest.TestCase):
     constv = Tensor.empty(2,
2).uop.const_like(10).contiguous() check_schedule(constv, 1) - @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU") + @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL") + @expect_rangeify_fails def test_image_matmul(self): with Context(IMAGE=2): x = Tensor.randn((9, 9)).realize() @@ -1682,6 +1697,7 @@ class TestSchedule(unittest.TestCase): def test_late_fusion_post_expand(self): self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2) + @expect_rangeify_fails def test_cast_padded_view(self): a = Tensor.arange(4).reshape(1, 4) casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float) @@ -1711,6 +1727,7 @@ class TestSchedule(unittest.TestCase): self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]) @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all)) + @expect_rangeify_fails def test_cast_padded_const(self, dt1, dt2): assume(is_dtype_supported(dt1) and is_dtype_supported(dt2)) a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None)) @@ -1720,53 +1737,41 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(realized_const_view, 1)) np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]]) -class TestIndexing(unittest.TestCase): - def check_schedule(self, xt:Tensor|list[Tensor], cnt:int): - with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)): - lst = [xt] if isinstance(xt, Tensor) else xt - s = Tensor.schedule(*lst) - lowered = [x[1] for x in lower_schedule(s.copy())] - kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)] - if FUSE_ARANGE and len(kernels) != cnt: - raise KernelCountException(f"{len(kernels)} != {cnt}") - for ei in lowered: ei.run(do_update_stats=True) - return s - def test_simple_indexing(self): X = Tensor.randn(10, 10).realize() idxs = Tensor([0, 2]).realize() xt = X[idxs] - self.check_schedule(xt, 2) + run_schedule(check_schedule(xt, 2)) np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()]) def test_simple_indexing_alt(self): X = Tensor.arange(16).reshape(4, 4) xt = X[[1, 2], [-1, 2]] - self.check_schedule(xt, 1) + run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]]) def test_advanced_indexing(self): X = Tensor.arange(10)+1 xt = X[[0, -1]] - self.check_schedule(xt, 1) + run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0, -1]]) def test_advanced_indexing_alt(self): X = Tensor.arange(6).reshape(3, 2)+1 xt = X[[Tensor([2]), Tensor([1])]] - self.check_schedule(xt, 3) + run_schedule(check_schedule(xt, 3)) np.testing.assert_equal(xt.numpy(), 6) def test_advanced_simple_indexing_combined(self): X = Tensor.arange(16).reshape(4, 4) xt = X[1:2, [-1, 2]] - self.check_schedule(xt, 1) + run_schedule(check_schedule(xt, 1)) def test_push_through_reshape(self): Tensor.manual_seed(0) x = Tensor.randn(10, 20).realize() out = x.argmax(1) - self.check_schedule(out, 2) + run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1)) def test_arange_push_through_expand(self): @@ -1774,35 +1779,35 @@ class TestIndexing(unittest.TestCase): a = Tensor.arange(4,) b = Tensor.randn(4, 4).realize() out = (a+b).sum() - self.check_schedule(out, 1) + run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (np.arange(4)+b.numpy()).sum(), atol=1e-5) def test_argmin(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = x.argmin(-1) - self.check_schedule(out, 2) + 
run_schedule(check_schedule(out, 2)) np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1)) def test_argmax(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = x.argmax(-1) - self.check_schedule(out, 2) + run_schedule(check_schedule(out, 2)) np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1)) def test_arange_transposed(self): Tensor.manual_seed(0) x = Tensor.randint(4, 1).realize() a = ((Tensor.arange(4,)*x).T).sum() - self.check_schedule(a, 1) + run_schedule(check_schedule(a, 1)) np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T.sum()) def test_div_padded_arange(self): x = Tensor.full((2,2), 16) y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1))) out = y.sum(axis=1) - with Context(FUSE_ARANGE=1): run_schedule(check_schedule(out, 2)) + run_schedule(check_schedule(out, 2)) self.assertListEqual(out.tolist(), [0, 12, 4, 0]) def test_arange_transposed_descendants(self): @@ -1811,7 +1816,7 @@ class TestIndexing(unittest.TestCase): a = (Tensor.arange(4,)*x).T b = Tensor.randint(4, 4).realize() out = (a+b).sum() - self.check_schedule(out, 1) + run_schedule(check_schedule(out, 1)) np.testing.assert_equal(out.numpy(), ((np.arange(4)*x.numpy()).T+b.numpy()).sum()) def test_arange_index(self): @@ -1819,7 +1824,7 @@ class TestIndexing(unittest.TestCase): x = Tensor.randn(5, 2).realize() a = Tensor.arange(10) out = (x + a[2]).sum() - self.check_schedule(out, 1) + run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6) def test_arange_index_shrink(self): @@ -1828,14 +1833,14 @@ class TestIndexing(unittest.TestCase): x = Tensor.randn(11).realize() a = Tensor.arange(22) out = (x + a[:11]).sum() - self.check_schedule(out, 1) + check_schedule(out, 1) def test_arange_index_contiguous(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = Tensor.arange(10).contiguous() out = (x + a[2]).sum() - self.check_schedule(out, 3) + run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6) def test_arange_index_child(self): @@ -1843,62 +1848,24 @@ class TestIndexing(unittest.TestCase): x = Tensor.randn(5, 2).realize() a = Tensor.arange(10)+1 out = (x + a[2]).sum() - self.check_schedule(out, 1) + run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6) - def test_arange_index_contiguous_child(self): + def test_user_contiguous(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = (Tensor.arange(10)+1).contiguous() out = (x + a[2]).sum() - self.check_schedule(out, 3) + run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6) - def test_arange_childless_base(self): - a = Tensor.arange(4) - self.check_schedule(a, 1) - np.testing.assert_equal(a.numpy(), np.arange(4)) - - def test_arange_childless_view(self): - a = Tensor.arange(4).reshape(2, 2) - a[0] = 4 - np.testing.assert_equal(a.numpy(), [[4, 4], [2, 3]]) - - def test_arange_group_childless_base(self): - Tensor.manual_seed(0) - x = Tensor.randint(4).realize() - a = Tensor.arange(4)+x - self.check_schedule(a, 1) - np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy()) - - def test_arange_group_childless_view(self): - Tensor.manual_seed(0) - x = Tensor.ones(4).contiguous().realize() - a = Tensor.arange(4)+x - a[0] = 6 - 
np.testing.assert_equal(a.numpy(), [6., 2., 3., 4.]) - @unittest.skip("BUFFER_VIEW no longer supported on non-disk devices") def test_arange_view_op(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous() - sched = self.check_schedule(a, 1) + sched = run_schedule(check_schedule(a, 1)) self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW) np.testing.assert_equal(a.numpy(), [[4, 5]]) - @unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from ext device") - def test_arange_shrink_copy(self): - a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CPU") - sched = self.check_schedule(a, 2) # NOTE: there is a contiguous between REDUCE_AXIS and COPY - self.assertIs(sched[-1].ast.op, Ops.COPY) - np.testing.assert_equal(a.numpy(), [[4, 5]]) - - @unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from ext device") - def test_arange_expand_copy(self): - a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CPU") - sched = self.check_schedule(a, 2) # NOTE: there is a contiguous between REDUCE_AXIS and COPY - self.assertIs(sched[2].ast.op, Ops.COPY) - np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]) - @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_precompute_freqs_cis(self): from extra.models.llama import precompute_freqs_cis @@ -1914,23 +1881,33 @@ class TestIndexing(unittest.TestCase): def test_fuse_assign_contiguous(self): x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize() a = Tensor.arange(8).reshape(4, 2) - self.check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2) + run_schedule(check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2)) np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]]) - def test_assign_non_contiguous(self): - x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize() - y = Tensor.randint(4, 2) - a = Tensor.arange(8).reshape(4, 2)+y - x.shrink((None, (0, 2))).assign(a).realize() - xref = np.zeros((4, 4), dtype=int) - xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy() + def test_assign_non_contiguous_alt(self): self.test_assign_non_contiguous(alt=True) + def test_assign_non_contiguous(self, alt=False): + x = (Tensor.arange(16)-100).reshape(4,4).contiguous().realize() + xref = x.numpy() + if alt: + y = Tensor.randint(2, 4).contiguous().realize() + a = Tensor.arange(8).reshape(2, 4)+y + tst = x.shrink(((0, 2), None)).assign(a).realize() + xref[:2, :] = np.arange(8).reshape(2, 4)+y.numpy() + else: + y = Tensor.randint(4, 2).contiguous().realize() + a = Tensor.arange(8).reshape(4, 2)+y + tst = x.shrink((None, (0, 2))).assign(a).realize() + xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy() np.testing.assert_equal(x.numpy(), xref) + if RANGEIFY > 0: + # NOTE: this is a bug on non rangeify + np.testing.assert_equal(tst.numpy(), a.numpy()) def test_sparse_categorical_crossentropy_simple(self): X = Tensor([[0, 2, 3], [1, 2, 3]]).realize() Y = Tensor([1, 2]).realize() loss = X.sparse_categorical_crossentropy(Y) - self.check_schedule(loss, 4) + run_schedule(check_schedule(loss, 4)) np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6) @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU") @@ -1942,28 +1919,21 @@ class TestIndexing(unittest.TestCase): yt = Tensor.randn(BS, 10).realize() with Context(SPLIT_REDUCEOP=0): loss = yt.sparse_categorical_crossentropy(Y_train[samples]) - self.check_schedule(loss, 6) + run_schedule(check_schedule(loss, 6)) loss_fused = 
loss.numpy() loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) - @unittest.expectedFailure def test_arange_fuse_grouped_children(self): X = Tensor.randn(4, 4).realize() r = (X+Tensor.arange(16).reshape(4, 4)).sum() out0 = r+2 out1 = r+3 - self.check_schedule([out0, out1], 1) + run_schedule(check_schedule([out0, out1], 1 if RANGEIFY else 3)) r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum() np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7) np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7) - def test_dont_fold_arange_contiguous_view(self): - X = Tensor.randn(4, 4).realize() - r = (X+Tensor.arange(16).reshape(4, 4).contiguous()).sum(1, keepdim=True) - self.check_schedule([r], 2) - np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True), atol=1e-5, rtol=1e-6) - @unittest.skip("multi output isn't supported") def test_multiview_arange_children(self): X = Tensor.randn(2,3,4,4).numpy() @@ -2093,6 +2063,7 @@ class TestView(unittest.TestCase): run_schedule(sched) np.testing.assert_equal(b.numpy(), 0) + @expect_rangeify_fails def test_mask_dim_1(self): # mask out dim = 1 works too a = Tensor.rand(10, 10).realize() @@ -2119,6 +2090,7 @@ class TestView(unittest.TestCase): # a*VIEW(x), where VIEW(x) = 0 # x collapses along with its children + @unittest.skipIf(RANGEIFY, "this only fails if you run all of TestSchedule, some global tensor map bug?") def test_parent_view_collapses(self): a = Tensor([1, 2]) b = Tensor.arange(3).contiguous() @@ -2136,6 +2108,7 @@ class TestView(unittest.TestCase): # a*VIEW(x), where VIEW(x) = 0 # x+2 # as long as one child realizes, x does not collapse + @expect_rangeify_fails def test_parent_multiple_children_no_collapse(self): a = Tensor([1, 2]) b = Tensor.arange(3).contiguous() @@ -2200,84 +2173,6 @@ class TestSimplifier(unittest.TestCase): assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False" assert sink.shape == a.shape -tensor_const_pm = PatternMatcher([ - (UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True), - (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), UPat(Ops.CONST))), lambda: True), -]) -class TestConst(unittest.TestCase): - # ** part 1: basic functionality of a tensor directly created from CONST - - def test_tensor_const(self): - a = Tensor(1) - print(a.uop) - self.assertTrue(tensor_const_pm.rewrite(a.uop)) - - def test_tensor_variable(self): - vv = UOp.variable("a", 0, 10).bind(1) - a = Tensor(vv) - print(a.uop) - self.assertTrue(tensor_const_pm.rewrite(a.uop)) - - def test_const_schedule(self): - a = Tensor.ones((4, 4)) - sched = a.schedule() - self.assertEqual(len(sched), 0) - - def test_const_contiguous_schedule(self): - # this ends up in the big graph - a = Tensor.ones((4,)).contiguous() - sched = a.schedule() - self.assertEqual(len(sched), 1) - - # ** part 2: scheduler behavior when const folding happens later - - def test_const_folding_no_realize(self): - a = Tensor([1, 2, 3, 4])*0 - sched = a.schedule() - self.assertEqual(len(sched), 0) - - def test_src_const_folding(self): - with Context(TRACK_MATCH_STATS=0): - a = Tensor.full((4,), 1).contiguous().realize() - b = Tensor.full((4,), 2).contiguous().realize() - mul0 = a*0 - add = b+mul0 - sched = add.schedule() - self.assertEqual(len(sched), 0) - # b+0 and b share the 
same underlying device memory - self.assertIs(add.uop.buffer, b.uop.buffer) - self.assertListEqual(add.tolist(), [2, 2, 2, 2]) - - def test_src_masked_const_folding(self): - with Context(TRACK_MATCH_STATS=0): - a = Tensor.full((4,), 1).contiguous().realize() - b = Tensor.full((6,), 2).contiguous().realize() - mul0 = a*0 - add = b+mul0.pad((1, 1), value=2) - sched = add.schedule() - self.assertEqual(len(sched), 1) - run_schedule(sched) - # add gets assigned to a new buffer - self.assertIsNot(add.uop.base.realized, b.uop.base.realized) - self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4]) - - # ** part 3: Tensor variable bindings - - #@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable? - def test_var_schedule(self): - vv = UOp.variable("a", 0, 10).bind(1) - a = Tensor(vv) - sched = a.schedule() - self.assertEqual(len(sched), 0) - - def test_add_tvar(self): - vv = UOp.variable("a", 0, 10).bind(1) - a = Tensor(vv)+2 - sched, var_vals = a.schedule_with_vars() - self.assertEqual(len(sched), 1) - run_schedule(sched, var_vals) - self.assertEqual(a.tolist(), 3) - @unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu") class TestCopyFolding(unittest.TestCase): def test_const_copy_is_free(self): @@ -2291,6 +2186,7 @@ class TestCopyFolding(unittest.TestCase): b = (a*zeros).to("CPU") run_schedule(check_schedule(b, 0, filter_sink=False)) self.assertListEqual(b.tolist(), [0, 0, 0]) + self.assertEqual(b.device, "CPU") def test_alu_after_copy(self): a = Tensor.ones((4,)).to("CPU") @@ -2299,6 +2195,12 @@ class TestCopyFolding(unittest.TestCase): add.kernelize() assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}" + def test_alu_before_copy(self): + buf = Tensor.ones(1).contiguous().realize() + a = buf+1 + b = a.to("CPU") + self.assertListEqual(b.tolist(), [2.]) + def test_copy_to_same_device(self): a = Tensor.empty(4).uop b = a.copy_to_device(a.device) @@ -2345,6 +2247,7 @@ class TestCopyFolding(unittest.TestCase): b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + @expect_rangeify_fails def test_permute_on_disk(self): with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") @@ -2491,6 +2394,7 @@ class TestUOpBecome(unittest.TestCase): self.assertEqual(add.uop.shape, (8, 2)) assert add.uop is not add.uop.base + @expect_rangeify_fails def test_new_flat_buffer(self): a = Tensor.empty(4,) b = Tensor.empty(4,) @@ -2516,6 +2420,7 @@ class TestUOpBecome(unittest.TestCase): z = (img*x) / y check_schedule(z, 1) + @expect_rangeify_fails def test_become_existing_buffer(self): a = Tensor.empty(4, 4) b = a*1 @@ -2543,6 +2448,7 @@ class TestUOpBecome(unittest.TestCase): check_schedule(b, 0) assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER) + @expect_rangeify_fails def test_become_const_in_view(self): # if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged. 
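# [editor's note] Many hunks in this file gate tests behind an `expect_rangeify_fails`
# decorator whose definition is not included in this diff. A minimal sketch of what
# such a helper could look like (hypothetical, assuming RANGEIFY is the context
# variable this patch imports from tinygrad.helpers in test_symbolic_jit.py):
#
#   import unittest
#   from tinygrad.helpers import RANGEIFY
#
#   def expect_rangeify_fails(fn):
#     # mark the test as a known failure only when the rangeify scheduler is enabled
#     return unittest.expectedFailure(fn) if RANGEIFY else fn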
add = Tensor.empty(2, 2)+Tensor.empty(2, 2) @@ -2560,6 +2466,7 @@ class TestUOpBecome(unittest.TestCase): assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {}) # tensors can become another realized tensor source + @expect_rangeify_fails def test_become_existing_buf_simple(self): a = Tensor.empty(4, 4) b = a+0 @@ -2568,12 +2475,14 @@ class TestUOpBecome(unittest.TestCase): self.assertIs(a.uop, b.uop) # they can also chain other movement ops on top of the tensor source + @expect_rangeify_fails def test_become_existing_buf_view(self): a = Tensor.empty(4, 4) b = a.permute((1, 0))+0 check_schedule(b, 0) self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st) + @expect_rangeify_fails def test_become_existing_buf_view_alt(self): a = Tensor.empty(4, 4) b = a.permute((1, 0)).reshape((8, 2))+0 @@ -2581,6 +2490,7 @@ class TestUOpBecome(unittest.TestCase): self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st) # they can also have other base parents that simplified, in that case we just backtrack to the chained mops + @expect_rangeify_fails def test_become_existing_buf_complex(self): a = Tensor.empty(4, 4) b = (a.permute((1, 0))+0).reshape((8, 2))+0 @@ -2588,6 +2498,7 @@ class TestUOpBecome(unittest.TestCase): self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st) assert b.uop.base.op is Ops.BUFFER + @expect_rangeify_fails def test_become_multiple_choices(self): a = Tensor.empty(16) b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0 @@ -2599,6 +2510,7 @@ class TestUOpBecome(unittest.TestCase): assert b.uop is c.uop assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.uop, {}) + @expect_rangeify_fails def test_setitem_becomes_subbuffer(self): a = Tensor.full((4,), 2.).contiguous().realize() b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0)) diff --git a/tinygrad_repo/test/test_search.py b/tinygrad_repo/test/test_search.py deleted file mode 100644 index 011a4948..00000000 --- a/tinygrad_repo/test/test_search.py +++ /dev/null @@ -1,146 +0,0 @@ -import unittest - -from tinygrad.codegen.opt.kernel import Opt, OptOps, Kernel -from tinygrad.uop.ops import UOp, Ops -from tinygrad.codegen.opt.search import bufs_from_lin, actions, beam_search -from tinygrad.device import Device -from tinygrad.tensor import Tensor -from tinygrad.dtype import dtypes -from tinygrad.helpers import Context, GlobalCounters -from tinygrad.engine.realize import capturing -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import View -from extra.optimization.helpers import time_linearizer - -class TestBEAM(unittest.TestCase): - def test_dynamic_beam(self): - # TODO: make this infra globally usable - class Capture: - def __init__(self): self.captured = [] - def add(self, x): self.captured.append(x) - - capturing.append(Capture()) - kernel_count = GlobalCounters.kernel_count - with Context(BEAM=1): Tensor.zeros(16).contiguous().realize() - assert GlobalCounters.kernel_count == kernel_count + 1 - k_beam_1 = capturing[0].captured - capturing.clear() - - capturing.append(Capture()) - kernel_count = GlobalCounters.kernel_count - with Context(BEAM=0): Tensor.zeros(16).contiguous().realize() - assert GlobalCounters.kernel_count == kernel_count + 1 - k_beam_0 = capturing[0].captured - capturing.clear() - self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src) - - def test_get_kernel_actions_dedup(self): - from test.test_linearizer import helper_realized_ast - from tinygrad.codegen.opt.search import get_kernel_actions - a = 
Tensor.empty(4, 3) - b = Tensor.empty(3) - realized_ast, _ = helper_realized_ast(a @ b) - candidates = [ - Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4), - Opt(op=OptOps.LOCAL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=4), - Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=3), - Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=3), - Opt(op=OptOps.GROUPTOP, axis=0, arg=0), Opt(op=OptOps.GROUPTOP, axis=0, arg=3), - ] - lins = get_kernel_actions(Kernel(realized_ast), include_0=False, candidates=candidates).values() - - # ensure amt=0 are not duplicated - assert all(len(x.applied_opts) == 1 for x in lins) - kernel_actions = [x.applied_opts[0] for x in lins] - assert Opt(OptOps.UPCAST, axis=0, arg=4) not in kernel_actions, "did not de-dup UPCAST" - assert Opt(OptOps.LOCAL, axis=0, arg=4) not in kernel_actions, "did not de-dup LOCAL" - assert Opt(OptOps.UNROLL, axis=0, arg=3) not in kernel_actions, "did not de-dup UNROLL" - assert Opt(OptOps.GROUP, axis=0, arg=3) not in kernel_actions, "did not de-dup GROUP" - assert Opt(OptOps.GROUPTOP, axis=0, arg=3) not in kernel_actions, "did not de-dup GROUPTOP" - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - def test_search_over_shape(self): - from test.test_linearizer import helper_realized_ast - from tinygrad.codegen.opt.search import get_kernel_actions - - dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores] - multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1] - - if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over") - - for (dtype_in, dtype_out) in multi_shape_dtype_pairs: - a = Tensor.rand(16, 16, dtype=dtype_in) - b = Tensor.rand(16, 16, dtype=dtype_in) - realized_ast, _ = helper_realized_ast(a.matmul(b, dtype=dtype_out)) - - lins = get_kernel_actions(Kernel(realized_ast)).values() - assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1 - - def test_get_kernel_actions_preserves_actions_state(self): - from test.test_linearizer import helper_realized_ast - from tinygrad.codegen.opt.search import get_kernel_actions - a = Tensor.rand(16, 16) - b = Tensor.rand(16, 16) - realized_ast, _ = helper_realized_ast(a @ b) - actions_before = actions.copy() - get_kernel_actions(Kernel(realized_ast)) - actions_after = actions.copy() - assert actions_after == actions_before, "actions state was not preserved" - - @unittest.skip("invalid reduce now") - def test_filter_global_buffer(self): - # taken from https://github.com/tinygrad/tinygrad/issues/4612 - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=( - UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, 
mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=2, src=()),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=3, src=()),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=4, src=()),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=5, src=()),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=6, src=()),)),)),)), - UOp(Ops.CONST, dtypes.float, arg=1.4285714285714286, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 - lin = Kernel(ast) - - bufs = bufs_from_lin(lin) - best_lin = beam_search(lin, bufs, 2) - assert best_lin - # need disable_cache to trigger. 
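# [editor's note] test_search.py is deleted wholesale by this patch. For reference, the
# core pattern it exercised (beam-searching a kernel, then timing the winner against
# freshly allocated buffers) used only the APIs visible in the deleted code:
#
#   from tinygrad.codegen.opt.kernel import Kernel
#   from tinygrad.codegen.opt.search import bufs_from_lin, beam_search
#   from extra.optimization.helpers import time_linearizer
#
#   lin = Kernel(ast)                      # ast: a SINK UOp like the one above
#   bufs = bufs_from_lin(lin)              # buffers matching the kernel's globals
#   best_lin = beam_search(lin, bufs, 2)   # beam width 2 over the opt action space
#   tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)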
- tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True) - assert tm - - def test_beam_unnamed_kernels(self): - from test.test_linearizer import push_views - a = Tensor.rand(100) - b = Tensor.rand(100) - si = (a+b).schedule()[-1] - lin = Kernel(push_views(si.ast)) - bufs = bufs_from_lin(lin) - # TODO: beam should have better instrumentation so we don't have to check this indirect thing - kcount = len(Kernel.kernel_cnt) - beam_search(lin, bufs, 3, disable_cache=True) - self.assertEqual(kcount, len(Kernel.kernel_cnt)) - -if __name__ == '__main__': - unittest.main() diff --git a/tinygrad_repo/test/test_setitem.py b/tinygrad_repo/test/test_setitem.py index 967acc29..2005b7c8 100644 --- a/tinygrad_repo/test/test_setitem.py +++ b/tinygrad_repo/test/test_setitem.py @@ -1,4 +1,6 @@ import unittest +import random +from os import getenv from tinygrad import Tensor, TinyJit, Variable, dtypes from tinygrad.helpers import Context import numpy as np @@ -165,6 +167,41 @@ class TestSetitem(unittest.TestCase): t[idx] = val self.assertEqual(t.tolist(), [val]*idx_size+[idx_size]) + def test_setitem_advanced_indexing(self): + # Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + t = Tensor.zeros(10,20,30,40,50).contiguous() + ind_1 = Tensor([5,3,7,8]) + ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]]) + v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50) + t[:, ind_1, :, ind_2, :] = v + n = np.zeros((10,20,30,40,50)) + n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy() + np.testing.assert_allclose(t.numpy(), n) + + def test_setitem_2d_tensor_indexing(self): + t = Tensor.zeros(2).contiguous() + index = Tensor([[0, 1], [1,0]]) + v = Tensor.arange(2*2).reshape(2, 2).contiguous() + t[index] = v + n = np.zeros((2,)) + n[index.numpy()] = v.numpy() + np.testing.assert_allclose(t.numpy(), n) + + @unittest.skip("slow") + def test_setitem_tensor_indexing_fuzz(self): + random.seed(getenv("SEED", 42)) + for _ in range(getenv("ITERS", 100)): + size = random.randint(5, 10) + d0, d1, d2 = random.randint(1,5), random.randint(1,5), random.randint(1,5) + t = Tensor.zeros(size).contiguous() + n = np.zeros((size,)) + index = Tensor.randint((d0, d1, d2), low=0, high=size) + v = Tensor.arange(d0*d1*d2).reshape(d0, d1, d2) + t[index] = v + n[index.numpy()] = v.numpy() + np.testing.assert_allclose(t.numpy(), n, err_msg=f"failed with index={index.numpy().tolist()} and v={v.numpy().tolist()}") + + class TestWithGrad(unittest.TestCase): def test_no_requires_grad_works(self): z = Tensor.rand(8, 8) diff --git a/tinygrad_repo/test/test_symbolic_jit.py b/tinygrad_repo/test/test_symbolic_jit.py index 881ce334..f28d274d 100644 --- a/tinygrad_repo/test/test_symbolic_jit.py +++ b/tinygrad_repo/test/test_symbolic_jit.py @@ -2,50 +2,55 @@ import unittest from test.helpers import assert_jit_cache_len from tinygrad import Variable, Tensor, TinyJit -from tinygrad.helpers import Context +from tinygrad.helpers import RANGEIFY import numpy as np class TestSymbolicJit(unittest.TestCase): - def setUp(self): - # A lot of these test are out of bounds, so we ignore the bounds check - self.context = Context(IGNORE_OOB=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) - def test_plus1(self): def f(a): return (a+1).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy() - expected = 
f(a).numpy() + symbolic = jf(a[:, :vi])[:3, :i].numpy() + expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) + def test_plus1_pad(self): + # TODO: without contiguous, the pad is not captured in jit + def f(a): return (a+1).pad((None, (0, 10-a.shape[1]))).contiguous().realize() + jf = TinyJit(f) + a = Tensor.rand(3, 10) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + symbolic = jf(a[:, :vi]).numpy() + expected = f(a[:, :i]).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1 if RANGEIFY else 2) # one add and one pad, can be one kernel? + def test_add(self): def f(a, b): return (a+b).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:, :vi]) + symbolic = symbolic[:3, :i].numpy() + expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_matmul(self): def f(a, b): return (a@b).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) @@ -55,119 +60,119 @@ class TestSymbolicJit(unittest.TestCase): s = (s+s).realize() # this one does not have symbols in input return s jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 2) def test_attention(self): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize() jf = TinyJit(f) + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, 10, 4, 8) + v = Tensor.rand(2, 10, 4, 8) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) - symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() + symbolic = jf(q, k[:, :vi], v[:, :vi])[:2, :4, :1, :8].numpy() + expected = f(q, k[:, :i], v[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 5) + assert_jit_cache_len(jf, 4 if RANGEIFY else 5) def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) + a = Tensor.rand(10, 3) + b = Tensor.rand(2, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - b = Tensor.rand(2, 3) - symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi], b)[:i+2, :3].numpy() + expected = f(a[:i], b).numpy() 
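# [editor's note] The recurring rewrite in these symbolic-JIT hunks: rather than
# reallocating Tensor.rand(3, i) and reshaping it to a bound Variable on every
# iteration, a single tensor is allocated at the Variable's upper bound, sliced by
# the bound Variable on the way in, and sliced back to a concrete shape before
# .numpy(). A minimal sketch of the idiom, taken from the hunks above:
#
#   from tinygrad import Tensor, Variable
#
#   a = Tensor.rand(3, 10)              # allocate once at the upper bound of i
#   vi = Variable("i", 1, 10).bind(3)   # symbolic length, bound to a concrete value
#   sym = a[:, :vi] + 1                 # computed with symbolic shape (3, vi)
#   out = sym[:3, :3].numpy()           # slice back to concrete shape to compare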
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 2) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, 2) - symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b)[:3, :i+2].numpy() + expected = f(a[:, :i], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim0_two_vars(self): def f(a, b): return a.cat(b, dim=0).realize() jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(j, 3) - symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi], b[:vj])[:i+j, :3].numpy() + expected = f(a[:i], b[:j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_cat_dim1_two_vars(self): def f(a, b): return a.cat(b, dim=1).realize() jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(3, i) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:, :vi], b[:, :vj])[:3, :i+j].numpy() + expected = f(a[:, :i], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_two_vars_plus1_ij(self): def f(a, b): return (a@b+1).realize() jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vi, :], b[:, :vj])[:i, :j].numpy() + expected = f(a[:i, :], b[:, :j]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_two_vars_plus1_ji(self): def f(a, b): return (a@b+1).realize() jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 3) + b = Tensor.rand(3, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(j, 3) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() - expected = f(a, b).numpy() + symbolic = jf(a[:vj, :], b[:, :vi])[:j, :i].numpy() + expected = f(a[:j, :], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_jit_symbolic_shape_mismatch(self): @TinyJit def add(a, b): return (a+b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - b = Tensor.rand(3, i).reshape(3, vi) - add(a, b) + add(a[:, :vi], b[:, :vi]) vi2 = Variable("i", 1, 
10).bind(7) - a = Tensor.rand(3, 7).reshape(3, vi2) - bad = Tensor.rand(4, 7).reshape(4, vi2) + a = Tensor.rand(3, 7)[:, :vi2] + bad = Tensor.rand(4, 7)[:, :vi2] with self.assertRaises(AssertionError): add(a, bad) @@ -175,9 +180,9 @@ class TestSymbolicJit(unittest.TestCase): # shrink is a movement, so we pair it with a simple function to test the JIT interaction def f(a): return (a+1).realize() jf = TinyJit(f) + a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) symbolic = jf(symbolic).numpy() expected = f(a.shrink(((3,5),(i,i+2)))).numpy() @@ -188,9 +193,9 @@ class TestSymbolicJit(unittest.TestCase): # slice is a movement, so we pair it with a simple function to test the JIT interaction def f(a): return (a+1).realize() jf = TinyJit(f) + a = Tensor.rand(7, 11) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) symbolic = a[3:5, vi:vi+2] symbolic = jf(symbolic).numpy() expected = f(a[3:5, i:i+2]).numpy() @@ -204,19 +209,19 @@ class TestSymbolicJit(unittest.TestCase): vi = Variable("i", 1, 10).bind(i) a = Tensor.ones(vi, 11).contiguous() symbolic = a[:, 1:2] - symbolic = jf(symbolic).reshape(i, 1).numpy() - expected = f(a.reshape(i, 11)[:, 1:2]).numpy() + symbolic = jf(symbolic)[:i, :1].numpy() + expected = f(a[:i, :][:, 1:2]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) assert_jit_cache_len(jf, 1) def test_ones_sum(self): def f(a): return a.sum().realize() jf = TinyJit(f) + t = Tensor.ones(10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - t = Tensor.ones(i) - symbolic = jf(t.reshape(vi)).item() - expected = f(t).item() + symbolic = jf(t[:vi]).item() + expected = f(t[:i]).item() np.testing.assert_equal(symbolic, expected) def test_mean(self): @@ -226,22 +231,22 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) + c = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - # aixs = None - a = Tensor.rand(i, 3) - symbolic = jf(a.reshape(vi, 3)).numpy() - expected = a.mean().numpy() + # axis = None + symbolic = jf(a[:vi]).numpy() + expected = a[:i].mean().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, 3) - symbolic = jf0(a.reshape(vi, 3)).numpy() - expected = a.mean(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi]).numpy() + expected = b[:i].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, 3) - symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy() - expected = a.mean(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi])[:i].numpy() + expected = c[:i].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): @@ -251,24 +256,24 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 10) + b = Tensor.rand(10, 10) + c = Tensor.rand(10, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - # aixs = None - a = Tensor.rand(i, j) - symbolic = jf(a.reshape(vi, vj)).numpy() - expected = a.mean().numpy() + # axis = None + symbolic = jf(a[:vi, :vj]).numpy() + expected = a[:i, :j].mean().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a 
= Tensor.rand(i, j) - symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy() - expected = a.mean(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi, :vj])[:j].numpy() + expected = b[:i, :j].mean(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, j) - symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy() - expected = a.mean(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi, :vj])[:i].numpy() + expected = c[:i, :j].mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var(self): @@ -278,22 +283,22 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) + a = Tensor.rand(10, 3) + b = Tensor.rand(10, 3) + c = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - # aixs = None - a = Tensor.rand(i, 3) - symbolic = jf(a.reshape(vi, 3)).numpy() - expected = a.var().numpy() + # axis = None + symbolic = jf(a[:vi]).numpy() + expected = a[:i].var().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, 3) - symbolic = jf0(a.reshape(vi, 3)).numpy() - expected = a.var(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi]).numpy() + expected = b[:i].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, 3) - symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy() - expected = a.var(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi])[:i].numpy() + expected = c[:i].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var_2d(self): @@ -303,24 +308,24 @@ class TestSymbolicJit(unittest.TestCase): jf = TinyJit(f) jf0 = TinyJit(f0) jf1 = TinyJit(f1) - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 10) + b = Tensor.rand(10, 10) + c = Tensor.rand(10, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - # aixs = None - a = Tensor.rand(i, j) - symbolic = jf(a.reshape(vi, vj)).numpy() - expected = a.var().numpy() + # axis = None + symbolic = jf(a[:vi, :vj]).numpy() + expected = a[:i, :j].var().numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, j) - symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy() - expected = a.var(0).numpy() + # axis = 0 + symbolic = jf0(b[:vi, :vj])[:j].numpy() + expected = b[:i, :j].var(0).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, j) - symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy() - expected = a.var(1).numpy() + # axis = 1 + symbolic = jf1(c[:vi, :vj])[:i].numpy() + expected = c[:i, :j].var(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) if __name__ == '__main__': diff --git a/tinygrad_repo/test/test_symbolic_ops.py b/tinygrad_repo/test/test_symbolic_ops.py index f6264463..991a9dcc 100644 --- a/tinygrad_repo/test/test_symbolic_ops.py +++ b/tinygrad_repo/test/test_symbolic_ops.py @@ -1,62 +1,62 @@ import unittest -from tinygrad import Tensor, Variable +from tinygrad import Tensor, Variable, GlobalCounters from tinygrad.shape.shapetracker import View -from tinygrad.helpers import Context, GlobalCounters from tinygrad.uop.ops import sym_infer from tinygrad.dtype import dtypes -from tinygrad.device import Device +from tinygrad.device import is_dtype_supported from examples.gpt2 import Attention import numpy as np class 
TestSymbolicOps(unittest.TestCase): - def setUp(self): - # A lot of these test are out of bounds, so we ignore the bounds check - self.context = Context(IGNORE_OOB=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) - def test_plus1(self): def f(a): return (a+1).realize() + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a).numpy() + symbolic = f(a[:, :vi])[:3, :i].numpy() + expected = f(a[:, :i]).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_plus1_pad(self): + def f(a): return (a+1).pad((None, (0, 10-a.shape[1]))).realize() + a = Tensor.rand(3, 10) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + symbolic = f(a[:, :vi]).numpy() + expected = f(a[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_add(self): def f(a, b): return (a+b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:, :vi], b[:, :vi])[:, :i].numpy() + expected = f(a[:, :i], b[:, :i]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_matmul(self): def f(a, b): return (a@b).realize() + a = Tensor.rand(3, 10) + b = Tensor.rand(10, 5) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:, :vi], b[:vi, :]).numpy() + expected = f(a[:, :i], b[:i, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True): def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize() + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, 10, 4, 8) + v = Tensor.rand(2, 10, 4, 8) for i in range(imin, imax): vi = Variable("i", 1, 10).bind(i) if use_symbolic else i - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) Tensor.realize(q, k, v) GlobalCounters.reset() - symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() + symbolic = f(q, k[:, :vi, :, :], v[:, :vi, :, :])[:2, :4, :1, :8].numpy() + expected = f(q, k[:, :i, :, :], v[:, :i, :, :]).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_attention_cmp_symbolic(self): @@ -90,73 +90,80 @@ class TestSymbolicOps(unittest.TestCase): def test_cat_dim0(self): def f(a, b): return a.cat(b, dim=0).realize() + a = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) b = Tensor.rand(2, 3) - symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() + symbolic = f(a[:vi, :], b)[:i+2, :3].numpy() + expected = f(a[:i, :], b).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_cat_dim1(self): def f(a, b): return a.cat(b, dim=1).realize() + a = Tensor.rand(3, 10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) b = Tensor.rand(3, 2) - symbolic = f(a.reshape(3, vi), b).reshape(3, 
i+2).numpy()
-      expected = f(a, b).numpy()
+      symbolic = f(a[:, :vi], b)[:3, :i+2].numpy()
+      expected = f(a[:, :i], b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_cat_dim0_two_vars(self):
     def f(a, b): return a.cat(b, dim=0).realize()
-    for i in range(1, 5):
-      for j in range(1, 5):
+    a = Tensor.rand(10, 3)
+    b = Tensor.rand(10, 3)
+    for i in range(2, 5):
+      for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(j, 3)
-        symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vi, :], b[:vj, :])[:i+j, :3].numpy()
+        expected = f(a[:i, :], b[:j, :]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_cat_dim1_two_vars(self):
     def f(a, b): return a.cat(b, dim=1).realize()
-    for i in range(1, 5):
-      for j in range(1, 5):
+    a = Tensor.rand(3, 10)
+    b = Tensor.rand(3, 10)
+    for i in range(2, 5):
+      for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(3, i)
-        b = Tensor.rand(3, j)
-        symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:, :vi], b[:, :vj])[:3, :i+j].numpy()
+        expected = f(a[:, :i], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
-    for i in range(1, 5):
-      for j in range(1, 5):
+    a = Tensor.rand(10, 3).realize()
+    b = Tensor.rand(3, 10).realize()
+    for i in range(2, 5):
+      for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(i, 3)
-        b = Tensor.rand(3, j)
-        symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vi, :], b[:, :vj])[:i, :j].numpy()
+        expected = f(a[:i, :], b[:, :j]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_two_vars_plus1_ji(self):
     # reverse the order of variables
     def f(a, b): return (a@b+1).realize()
-    for i in range(1, 5):
-      for j in range(1, 5):
+    a = Tensor.rand(10, 3).realize()
+    b = Tensor.rand(3, 10).realize()
+    for i in range(2, 5):
+      for j in range(2, 5):
         vi = Variable("i", 1, 10).bind(i)
         vj = Variable("j", 1, 10).bind(j)
-        a = Tensor.rand(j, 3)
-        b = Tensor.rand(3, i)
-        symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
-        expected = f(a, b).numpy()
+        symbolic = f(a[:vj, :], b[:, :vi])[:j, :i].numpy()
+        expected = f(a[:j, :], b[:, :i]).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

+  def test_invalid_symbolic_reshape(self):
+    a = Tensor.rand(30)
+    for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
+      # Cannot reshape into symbolic from non-symbolic
+      with self.assertRaises(ValueError): a.reshape((3, vi))
+
   def test_shrink(self):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
@@ -171,16 +178,16 @@ class TestSymbolicOps(unittest.TestCase):
       vi = Variable("i", 1, 10).bind(i)
       a = Tensor.rand(7, 11)
       symbolic = a[3:5, vi:vi+2]
       symbolic = symbolic.numpy()
       expected = a[3:5, i:i+2].numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_slice_no_start(self):
+    a = Tensor.rand(7, 11)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
-      a = Tensor.rand(7, 11)
-      symbolic = a[3:5, :vi:1].reshape(2,i)
-      symbolic = symbolic.numpy()
+      symbolic = a[3:5, :vi:1][:2,
:i].numpy() expected = a[3:5, :i:1].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -188,7 +195,7 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) a = Tensor(1).unsqueeze(0).pad((0, 1)).unsqueeze(0) - symbolic = a.expand(vi, 2).reshape(i, 2).numpy() + symbolic = a.expand(vi, 2)[:i, :2].numpy() expected = a.expand(i, 2).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) @@ -196,36 +203,44 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) a = Tensor.ones(vi, 11).contiguous() - symbolic = a[:, 1:2].reshape(i, 1).numpy() - expected = a.reshape(i, 11)[:, 1:2].numpy() + symbolic = a[:, 1:2][:i, :1].numpy() + expected = Tensor.ones(i, 11)[:, 1:2].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_ones_sum(self): + t = Tensor.ones(10) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - t = Tensor.ones(i) - symbolic = t.reshape(vi).sum().item() - expected = t.sum().item() + symbolic = t[:vi].sum().item() + expected = t[:i].sum().item() np.testing.assert_equal(symbolic, expected) def test_mean(self): + a = Tensor.rand(10, 3) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: - a = Tensor.rand(i, 3) - expected = a.mean(axis).numpy() - symbolic = a.reshape(vi, 3).mean(axis).reshape(expected.shape).numpy() + expected = a[:i].mean(axis).numpy() + symbolic = a[:vi].mean(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: - a = Tensor.rand(i, j) - expected = a.mean(axis).numpy() - symbolic = a.reshape(vi, vj).mean(axis).reshape(expected.shape).numpy() + expected = a[:i, :j].mean(axis).numpy() + symbolic = a[:vi, :vj].mean(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var(self): @@ -233,43 +248,59 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) for axis in [None, 0, 1]: - expected = a[:i, :].var(axis).numpy() - symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy() + expected = a[:i].var(axis).numpy() + symbolic = a[:vi].var(axis) + if axis is None: + symbolic = symbolic.numpy() + else: + symbolic = symbolic[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_var_2d(self): - for i in range(1, 5): - for j in range(1, 5): + a = Tensor.rand(10, 10) + for i in range(2, 5): + for j in range(2, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) for axis in [None, 0, 1]: - a = Tensor.rand(i, j) - expected = a.var(axis).numpy() - symbolic = a.reshape(vi, vj).var(axis).reshape(expected.shape).numpy() + expected = a[:i, :j].var(axis).numpy() + symbolic_result = a[:vi, :vj].var(axis) + if axis is None: + symbolic = symbolic_result.numpy() + else: + symbolic = symbolic_result[:expected.shape[0]].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_bitcast_down(self): + a = Tensor.rand(10, 3) for i in range(1, 5): vi = 
Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - expected = a.bitcast(dtypes.uint8).numpy() - symbolic = a.reshape(vi, 3).bitcast(dtypes.uint8).reshape(expected.shape).numpy() + expected = a[:i].bitcast(dtypes.uint8).numpy() + symbolic_result = a[:vi].bitcast(dtypes.uint8) + if len(expected.shape) == 2: + symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy() + else: + symbolic = symbolic_result[:].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no uint64") + @unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64") def test_bitcast_up(self): + a = Tensor.rand(10, 4) for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 4) - expected = a.bitcast(dtypes.uint64).numpy() - symbolic = a.reshape(vi, 4).bitcast(dtypes.uint64).reshape(expected.shape).numpy() + expected = a[:i].bitcast(dtypes.uint64).numpy() + symbolic_result = a[:vi].bitcast(dtypes.uint64) + if len(expected.shape) == 2: + symbolic = symbolic_result[:expected.shape[0], :expected.shape[1]].numpy() + else: + symbolic = symbolic_result[:].numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0) @unittest.expectedFailure def test_conv2d_ceildiv_edge_case(self): v = Variable('v', 11, 50_000) val = 39601 - x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val)) + x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)] weight = Tensor.randn(256, 22, 12) result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3)) diff --git a/tinygrad_repo/test/test_tensor.py b/tinygrad_repo/test/test_tensor.py index 21a86883..defeaac5 100644 --- a/tinygrad_repo/test/test_tensor.py +++ b/tinygrad_repo/test/test_tensor.py @@ -4,12 +4,12 @@ import torch import unittest, copy, mmap, random, math, array from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _METADATA -from tinygrad.helpers import getenv, temp, mv_address +from tinygrad.helpers import getenv, temp, mv_address, RANGEIFY from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from hypothesis import given, settings, strategies as strat from tinygrad.device import is_dtype_supported from tinygrad.uop.ops import Ops, UOp -from tinygrad.runtime.support.compiler_cuda import PTX +from tinygrad.renderer.ptx import PTXRenderer from tinygrad.codegen import full_rewrite from tinygrad.dtype import DType @@ -415,6 +415,21 @@ class TestTinygrad(unittest.TestCase): data = _generate_data(depth) np.testing.assert_allclose(Tensor(data).numpy(), np.array(data)) + def test_tensor_list_implicit_cast(self): + data = [True, False] + np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy()) + np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy()) + np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy()) + data = [-1, 0, 1, 2, 3] + np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy()) + np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy()) + np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy()) + data = [-3.5, -2.5, -1.5, 0, 1.5, 2.5, 3.5] + np.testing.assert_equal(Tensor(data, dtype=dtypes.int).numpy(), torch.tensor(data, dtype=torch.int).numpy()) + # NOTE: torch and jax raise 
OverflowError: Python integer -3 out of bounds for uint8 + # np.testing.assert_equal(Tensor(data, dtype=dtypes.uint8).numpy(), torch.tensor(data, dtype=torch.uint8).numpy()) + np.testing.assert_equal(Tensor(data, dtype=dtypes.float).numpy(), torch.tensor(data, dtype=torch.float).numpy()) + def test_tensor_list_special_values(self): if is_dtype_supported(dtypes.float16): data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1] @@ -501,10 +516,6 @@ class TestTinygrad(unittest.TestCase): print(c) def test_env_overwrite_default_device(self): - subprocess.run(['DISK=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"DISK\\""'], - shell=True, check=True) - subprocess.run(['NPY=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT != \\"NPY\\""'], - shell=True, check=True) subprocess.run([f'{Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'], shell=True, check=True) subprocess.run([f'DISK=1 {Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'], @@ -539,6 +550,11 @@ class TestTinygrad(unittest.TestCase): def test_shrink(self): t = Tensor.arange(32).contiguous().realize() self.assertListEqual(t[16:20].tolist(), [16,17,18,19]) + self.assertListEqual(t.shrink_to(16).tolist(), list(range(16))) + t = t.reshape(4, 8).contiguous().realize() + self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]]) + with self.assertRaises(ValueError): t.shrink_to(2) + with self.assertRaises(ValueError): t.shrink_to(2, 2, 2) @unittest.skip("this test is just flaky, sync issue") class TestMoveTensor(unittest.TestCase): @@ -633,17 +649,22 @@ class TestZeroShapeTensor(unittest.TestCase): def test_pad(self): t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), value=1) - assert t.shape == (3, 2, 2) + self.assertEqual(t.shape, (3, 2, 2)) np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2))) t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), value=1) - assert t.shape == (3, 4, 0) + self.assertEqual(t.shape, (3, 4, 0)) np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0))) t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), value=1) - assert t.shape == (5, 2, 0) + self.assertEqual(t.shape, (5, 2, 0)) np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0))) + np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0]) + np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]]) + with self.assertRaises(TypeError): Tensor([1, 2]).pad_to(2, 3) + with self.assertRaises(TypeError): Tensor([[1, 2]]).pad_to(3) + def test_shrink_into_zero(self): t = Tensor.rand(3, 4).realize() assert t.shrink((None, (2, 2))).realize().shape == (3, 0) @@ -850,11 +871,18 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid") self.assertTrue(y.grad.uop.metadata[0].backward) si = Tensor.schedule(out, x.grad, y.grad)[-1] - self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}") - self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"}) - bw = [m for m in si.metadata if m.backward] - self.assertEqual(len(bw), 2) - self.assertEqual(bw[0].name, "sigmoid") + if not RANGEIFY: + self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}") + self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "__mul__", "relu"}) + bw = [m for m in si.metadata if m.backward] + self.assertEqual(len(bw), 2) + self.assertEqual(bw[0].name, "sigmoid") + else: + 
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}") + self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"}) + bw = [m for m in si.metadata if m.backward] + self.assertEqual(len(bw), 1) + self.assertEqual(bw[0].name, "sigmoid") class TestIdxUpcast(unittest.TestCase): def _find_op(self, ast: UOp, op: Ops): @@ -900,19 +928,24 @@ class TestIdxUpcast(unittest.TestCase): def test_regular_sym(self): self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32)) - @unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX always convert Ops.INDEX to int64") def test_symfold(self): # This would cause an overflow, but after sym fold it's within int32 a = Tensor.arange(65535) uops = self._schedule_render(a) assert all(uop.dtype is not dtypes.long for uop in uops) + def test_arange_raise_overflow(self): + with self.assertRaises(ValueError): + self._schedule_render(Tensor.arange(2**33, dtype=dtypes.int)) + @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") def test_int64_unsupported_overflow_sym(self): with self.assertRaises(KeyError): self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32)) @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported") + @unittest.expectedFailure # bug in gpu dims limiting def test_int64_unsupported_overflow(self): with self.assertRaises(KeyError): self.do_op_then_assert(dtypes.long, 2048, 2048, 2048) diff --git a/tinygrad_repo/test/test_tensor_variable.py b/tinygrad_repo/test/test_tensor_variable.py index a680cf9a..a046555d 100644 --- a/tinygrad_repo/test/test_tensor_variable.py +++ b/tinygrad_repo/test/test_tensor_variable.py @@ -1,7 +1,6 @@ import unittest import numpy as np from tinygrad import Tensor, Variable -from tinygrad.helpers import Context class TestTensorVariable(unittest.TestCase): def test_add_tvar(self): @@ -23,43 +22,38 @@ class TestTensorVariable(unittest.TestCase): assert (Tensor(3) * (vv * 4)).item() == 24 def test_symbolic_mean(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.mean().item() - assert ret == 1 + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 10).contiguous()[:, :vv] + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean().item() - assert ret == 1 + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) + t = Tensor.ones(10, 10).contiguous()[:vv2, :vv] + ret = t.mean().item() + assert ret == 1 def test_symbolic_mean_2d_axis_1(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - vv2 = Variable("b", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) - ret = t.mean(axis=1).reshape(2, 1).numpy() - assert np.all(ret == 1) + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) + t = Tensor.ones(10, 10).contiguous()[:vv2, :vv] + ret = t.mean(axis=1)[:2].reshape(2, 1).numpy() + assert np.all(ret == 1) def test_symbolic_mean_2d_add(self): - with Context(IGNORE_OOB=1): - add_term = Variable("c", 0, 10).bind(1) - vv = Variable("a", 1, 10).bind(1) - vv2 = Variable("b", 1, 10).bind(1) - t = Tensor.ones(2, 2).contiguous().reshape(vv2+add_term, vv+add_term) - ret = t.mean().item() - assert ret 
== 1 + add_term = Variable("c", 0, 10).bind(1) + vv = Variable("a", 1, 10).bind(1) + vv2 = Variable("b", 1, 10).bind(1) + t = Tensor.ones(20, 20).contiguous()[:vv2+add_term, :vv+add_term] + ret = t.mean().item() + assert ret == 1 def test_symbolic_var(self): - with Context(IGNORE_OOB=1): - vv = Variable("a", 1, 10).bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.var().item() - assert ret == 0 + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 10).contiguous()[:, :vv] + ret = t.var().item() + assert ret == 0 def test_symbolic_pad(self): vv = Variable("a", 1, 10).bind(2) @@ -72,25 +66,35 @@ class TestTensorVariable(unittest.TestCase): def test_symbolic_arange(self): vv = Variable("a", 1, 10) ret = Tensor.arange(0, vv.bind(4)) - self.assertListEqual(ret.reshape(4).tolist(), [0,1,2,3]) + self.assertListEqual(ret[:4].tolist(), [0,1,2,3]) def test_symbolic_arange_sym_start(self): vv = Variable("a", 1, 6) ret = Tensor.arange(vv.bind(4), 7) - self.assertListEqual(ret.reshape(3).tolist(), [4,5,6]) + self.assertListEqual(ret[:3].tolist(), [4,5,6]) # TODO: add vmin/vmax pattern for symbolic denominator @unittest.expectedFailure def test_symbolic_arange_sym_step(self): vv = Variable("step", 1, 3) ret = Tensor.arange(0, 10, vv.bind(2)) - self.assertListEqual(ret.reshape(5).tolist(), [0,2,4,6,8]) + self.assertListEqual(ret[:5].tolist(), [0,2,4,6,8]) def test_symbolic_arange_two_vars(self): begin = Variable("b", 1, 5) end = Variable("e", 6, 10) ret = Tensor.arange(begin.bind(4), end.bind(7)) - self.assertListEqual(ret.reshape(3).tolist(), [4,5,6]) + self.assertListEqual(ret[:3].tolist(), [4,5,6]) + + def test_variable_empty(self): + v = Variable("i", 1, 10) + # TODO: Tensor creation from unbound variable should assert + # with self.assertRaises(AssertionError): t = Tensor.empty(3, v) + vb = v.bind(3) + t = Tensor.empty(3, vb) + assert t.uop.base.buffer.size == 30 + assert t.uop.st.shape == (3, vb) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/test_tiny.py b/tinygrad_repo/test/test_tiny.py index e86a7269..31bb84f5 100644 --- a/tinygrad_repo/test/test_tiny.py +++ b/tinygrad_repo/test/test_tiny.py @@ -1,12 +1,20 @@ # basic self-contained tests of the external functionality of tinygrad import unittest, random from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn -from tinygrad.helpers import IMAGE, CI +from tinygrad.helpers import IMAGE, CI, getenv class TestTiny(unittest.TestCase): # *** basic functionality *** + def test_const(self): + const = Tensor(2.0) + self.assertEqual(const.item(), 2.0) + + def test_copy(self): + out = Tensor([1.,2,3]) + self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0]) + def test_plus(self): out = Tensor([1.,2,3]) + Tensor([4.,5,6]) self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0]) @@ -27,10 +35,21 @@ class TestTiny(unittest.TestCase): out = Tensor.ones(256).contiguous().sum() self.assertEqual(out.item(), 256) - def test_gemm(self, N=64, out_dtype=dtypes.float): + def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float): a = Tensor.ones(N,N).contiguous() b = Tensor.eye(N).contiguous() - self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N)) + lst = (out:=a@b).tolist() + for y in range(N): + for x in range(N): + self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})") + if IMAGE < 2: self.assertEqual(out.dtype, out_dtype) + + def test_gemv(self, N=getenv("GEMV_N", 64), out_dtype=dtypes.float): + a = Tensor.ones(1,N).contiguous() + b = Tensor.eye(N).contiguous() + lst = 
(out:=a@b).tolist() + for x in range(N): + self.assertEqual(lst[0][x], 1.0, msg=f"mismatch at {x}") if IMAGE < 2: self.assertEqual(out.dtype, out_dtype) # *** randomness *** @@ -76,7 +95,7 @@ class TestTiny(unittest.TestCase): ones = Tensor.ones(10).contiguous() for s in [2,5]: ret = ones[:i.bind(s)] + 1 - self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s) + self.assertListEqual(ret.contiguous()[:s].tolist(), [2.0]*s) def test_symbolic_reduce(self): i = Variable('i', 1, 10) @@ -88,7 +107,7 @@ class TestTiny(unittest.TestCase): # *** a model *** # TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE - @unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images") + @unittest.skipIf(CI and Device.DEFAULT == "DSP", "failing because of make things that can't be images not images") def test_mnist(self): layers = [ nn.Conv2d(1, 32, 5), Tensor.relu, @@ -107,7 +126,7 @@ class TestTiny(unittest.TestCase): self.assertEqual(len(probs[0]), 10) # TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE - @unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images") + @unittest.skipIf(CI and Device.DEFAULT == "DSP", "failing because of make things that can't be images not images") def test_mnist_backward(self): # NOTE: we don't have the whole model here for speed layers = [ @@ -126,7 +145,7 @@ class TestTiny(unittest.TestCase): # *** image *** - @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU") + @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL") def test_image(self): with Context(IMAGE=2): self.test_gemm(N=4, out_dtype=dtypes.imagef((4, 1, 4))) diff --git a/tinygrad_repo/test/test_uop_graph.py b/tinygrad_repo/test/test_uop_graph.py index 945c2b6a..b1a95c03 100644 --- a/tinygrad_repo/test/test_uop_graph.py +++ b/tinygrad_repo/test/test_uop_graph.py @@ -3,7 +3,7 @@ import unittest, pytest from tinygrad import dtypes, Variable from tinygrad.dtype import AddrSpace from tinygrad.helpers import DEBUG, Context -from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp +from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp, KernelInfo from tinygrad.uop.symbolic import sym from tinygrad.codegen import full_rewrite, full_rewrite_to_sink from tinygrad.codegen.late.expander import expander @@ -17,7 +17,7 @@ simple_pm = PatternMatcher([ def to_uops_list(u:List[UOp]) -> List[UOp]: # we strip the SINK here for legacy reasons - ret = full_rewrite(UOp.sink(*u)) + ret = full_rewrite(UOp.sink(*u, arg=KernelInfo(opts_to_apply=()))) assert ret[-1].op is Ops.SINK return ret[:-1] @@ -212,7 +212,7 @@ class TestUOpGraph(unittest.TestCase): def test_where_same_fold(self): v = UOp.variable('tmp', 0, 1) - c0 = UOp(Ops.CONST, dtypes.int, arg=0) + c0 = UOp(Ops.CONST, dtypes.index, arg=0) vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0)) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1)) @@ -398,7 +398,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp.variable("tmp", 0, 1) + v = UOp.variable("tmp", 0, 1, dtypes.int) c2 = UOp(Ops.CONST, dtypes.int, arg=2) c4 = UOp(Ops.CONST, dtypes.int, arg=4) vc = UOp(Ops.ADD, dtypes.int, (v, c2)) @@ -417,6 +417,49 @@ 
class TestUOpGraph(unittest.TestCase): uops = to_uops_list([v.bitcast(dt)]) self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") + def test_where_on_gated_load_fold(self): + ridx0 = UOp.range(100, 0) + d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) + ld = d0.index(ridx0, ridx0<50).load() + w = (ridx0<50).where(ld, 5) + uops = to_uops_list([w]) + for u in uops: + assert u.op is not Ops.WHERE + if u.op is Ops.LOAD: assert u.src[1].arg==5 + + def test_where_on_gated_load_folds_swapped_branches(self): + ridx0 = UOp.range(100, 0) + d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) + ld = d0.index(ridx0, (ridx0<50).logical_not()).load() + w = (ridx0<50).where(5, ld) + uops = to_uops_list([w]) + for u in uops: + assert u.op is not Ops.WHERE + if u.op is Ops.LOAD: assert u.src[1].arg==5 + + def test_where_in_store_becomes_gate(self): + ridx0 = UOp.range(100, 0) + d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) + idx = d0.index(ridx0) + ld = idx.load() + val = (ridx0<50).where(5, ld) + st = idx.store(val, ridx0) + uops = to_uops_list([st]) + for u in uops: + assert u.op is not Ops.WHERE + if u.op is Ops.STORE: assert u.src[1].arg==5 + + def test_load_idx_becomes_int(self): + d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0) + d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1) + l0 = UOp(Ops.LOAD, dtypes.long, (d0.index(UOp.const(dtypes.int, 0)),)) + idx = l0 * 600 + valid = (l0<-1).ne(True)&(l0<3000) + l1 = UOp(Ops.LOAD, dtypes.long, (d1.index(idx, valid),)) + uops = to_uops_list([l1]) + for u in uops: + if u.op is Ops.INDEX: self.assertEqual(u.src[1].dtype, dtypes.int) + def test_in_out_of_bounds_access(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) @@ -445,7 +488,7 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0) v = Variable("v", 0, 20) - st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v), UOp.const(dtypes.int, 0), UOp(Ops.IF, src=(v<16,)))) + st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v, v<16), UOp.const(dtypes.int, 0))) to_uops_list([st0]) st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20)) @@ -458,12 +501,12 @@ class TestUOpGraph(unittest.TestCase): sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0") # Define indices, valids and barrier - gidx = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 416)) - lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 10)) + gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0") + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0") gate = (gidx<400) & (lidx<8) - local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), UOp(Ops.IF, src=(lidx<8,)))) + local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1))) barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,)) if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier)) @@ -477,7 +520,7 @@ class TestUOpGraph(unittest.TestCase): def test_load_with_float_in_index(self): with Context(IGNORE_OOB=0): - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),)) @@ -490,7 +533,7 @@ class TestUOpGraph(unittest.TestCase): def test_load_cast_to_bool(self): with Context(IGNORE_OOB=0): glbl0 = 
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0) - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),)) to_uops_list([ld0]) @@ -499,7 +542,7 @@ class TestUOpGraph(unittest.TestCase): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0) - ridx = UOp.range(dtypes.int, 20, 0) + ridx = UOp.range(20, 0) ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),))) to_uops_list([ld0]) @@ -512,7 +555,7 @@ class TestUOpGraph(unittest.TestCase): def test_in_out_bounds_access_with_mask(self): with Context(IGNORE_OOB=0): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0) - gidx0 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 42)) + gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0") ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5=0)&(ld0<32)),)) + gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 42),), "gidx0") + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, gidx0<8),)).cast(dtypes.index) + ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<32)),)).cast(dtypes.index) to_uops_list([ld1]) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(ld0*2, (ld0>=0)&(ld0<64)),)) @@ -549,7 +592,7 @@ class TestUOpGraph(unittest.TestCase): glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) - ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(idx, UOp.const(dtypes.bool, False)),)) + ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),)) ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),)) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) ld0 = uops[-1].src[-1] @@ -559,10 +602,10 @@ class TestUOpGraph(unittest.TestCase): def test_fold_gated_load_local(self): glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp") - lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16)) + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0") st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int))) barrier = UOp(Ops.BARRIER, dtypes.void, (st, )) - ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier)) + ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier)) ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier)) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))]) @@ -575,7 +618,7 @@ class TestUOpGraph(unittest.TestCase): idx0 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) - st0 = glbl.index(idx0, UOp.const(dtypes.bool, False)).store(val) + st0 = glbl.index(UOp.invalid()).store(val) st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val) uops = to_uops_list([st0, st1]) # only the second store happens @@ -592,8 +635,8 @@ class TestUOpGraph(unittest.TestCase): def test_switched_range_order(self): glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) cf = UOp.const(dtypes.float, 0.0) - r1 = UOp.range(dtypes.int, 2, 0) - r2 = UOp.range(dtypes.int, 2, 1) + r1 = UOp.range(2, 0) + r2 = UOp.range(2, 1) alu = UOp(Ops.MUL, dtypes.int, (r2, r1)) store = UOp(Ops.STORE, dtypes.void, 
(glbl.index(alu), cf)) uops = to_uops_list([store]) @@ -756,8 +799,8 @@ class TestIFUOps(unittest.TestCase): def test_create_ifs(self): gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, addrspace=AddrSpace.LOCAL), (), "smem") - valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5 - lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4)) + valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5 + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0") gate = valid&(lidx.ne(2)) idx = UOp.const(dtypes.int, 0) st = UOp(Ops.STORE, dtypes.void, (sbuf.index(idx), UOp.const(dtypes.float, 42))) @@ -775,8 +818,8 @@ class TestIFUOps(unittest.TestCase): def test_expand_ifs_one_gate(self): gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem") - valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1 - lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16)) + valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1 + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0") gate = valid&(lidx.ne(2)) st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42))) barrier = UOp(Ops.BARRIER, dtypes.void, (st,)) @@ -794,8 +837,8 @@ class TestIFUOps(unittest.TestCase): @unittest.expectedFailure def test_expand_ifs_dumb(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) - valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5 - lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4)) + valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5 + lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0") gate = valid&(lidx.ne(2)) stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(Ops.SINK, dtypes.void, tuple(stores)) diff --git a/tinygrad_repo/test/test_uops.py b/tinygrad_repo/test/test_uops.py index dfd6a7f9..4a7ed308 100644 --- a/tinygrad_repo/test/test_uops.py +++ b/tinygrad_repo/test/test_uops.py @@ -14,7 +14,8 @@ from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.codegen import full_rewrite from tinygrad.uop.symbolic import sym from tinygrad.device import is_dtype_supported -from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps +from tinygrad.renderer.ptx import PTXRenderer def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return full_rewrite(UOp.sink(*u), opts) @@ -22,7 +23,7 @@ def _uops_to_prg(uops_list): uops = full_rewrite(ast:=UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer) src = Device[Device.DEFAULT].renderer.render(uops) has_local = Device[Device.DEFAULT].renderer.has_local - return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, ast, uops=uops, + return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops, global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None)) def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: @@ -130,9 +131,9 @@ class TestFloatUOps(TestUOps): class TestNonFloatUOps(TestUOps): def test_add_int32(self): self._test_bop_fxn(Ops.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32)) def test_mul_int32(self): self._test_bop_fxn(Ops.MUL, lambda 
a,b: int(a)*int(b), (dtypes.int32, dtypes.int32)) - @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts") + @unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "only ptx uses bitshifts") def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) - @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts") + @unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "only ptx uses bitshifts") def test_shl_int32(self): self._test_bop_fxn(Ops.SHL, lambda a,b: int(a)< bytes: return src.encode() @@ -56,7 +101,7 @@ class TestCompiler(unittest.TestCase): class TestRunAsModule(unittest.TestCase): def test_module_runs(self): p = subprocess.run([sys.executable, "-m", "tinygrad.device"],stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env={**os.environ, "DEBUG": "1"}, timeout=10,) + env={**os.environ, "DEBUG": "1"}, timeout=30,) out = (p.stdout + p.stderr).decode() self.assertEqual(p.returncode, 0, msg=out) self.assertIn("CPU", out) # for sanity check diff --git a/tinygrad_repo/test/unit/test_disk_tensor.py b/tinygrad_repo/test/unit/test_disk_tensor.py index 94fa484d..57584df2 100644 --- a/tinygrad_repo/test/unit/test_disk_tensor.py +++ b/tinygrad_repo/test/unit/test_disk_tensor.py @@ -307,7 +307,7 @@ class TestDiskTensor(unittest.TestCase): ret = t.bitcast(dtypes.uint16).to("CPU") + 1 assert ret.tolist() == [2827, 3341, 3855, 4369] - @unittest.skipIf(OSX, "new LLVM has an issue on OSX") + @unittest.skipIf(OSX or Device.DEFAULT == "CL", "new LLVM has an issue on OSX, CL=1 gives the wrong output") def test_bf16_disk_write_read(self): t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32) t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize() @@ -318,7 +318,7 @@ class TestDiskTensor(unittest.TestCase): with open(temp('dt_bf16_disk_write_read_bf16'), "wb") as f: f.write(adat) t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('dt_bf16_disk_write_read_bf16')}") - ct = t.llvm_bf16_cast(dtypes.float) + ct = t.to(Device.DEFAULT).cast(dtypes.float) assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20] def test_copy_from_disk(self): diff --git a/tinygrad_repo/test/unit/test_dtype.py b/tinygrad_repo/test/unit/test_dtype.py index d8fc7626..78be06aa 100644 --- a/tinygrad_repo/test/unit/test_dtype.py +++ b/tinygrad_repo/test/unit/test_dtype.py @@ -21,6 +21,10 @@ class TestEqStrDType(unittest.TestCase): def test_ptr_eq(self): assert dtypes.float32.ptr() == dtypes.float32.ptr() assert not (dtypes.float32.ptr() != dtypes.float32.ptr()) + def test_ptr_nbytes(self): + assert dtypes.float16.ptr(32).nbytes() == 32 * dtypes.float16.itemsize + def test_ptr_nbytes_unlimited(self): + self.assertRaises(RuntimeError, lambda: dtypes.float32.ptr().nbytes()) def test_strs(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") @@ -56,6 +60,7 @@ class TestCastConvenienceMethod(unittest.TestCase): class TestDtypeTolist(unittest.TestCase): def test_bfloat16(self): self.assertEqual(Tensor([-60000, 1.5, 3.1, 60000], device="PYTHON", dtype=dtypes.bfloat16).tolist(), [-59904.0, 1.5, 3.09375, 59904.0]) + def test_fp8(self): # 448 self.assertEqual(Tensor([-30000, 1.5, 3.1, 30000], device="PYTHON", dtype=dtypes.fp8e4m3).tolist(), [-448.0, 1.5, 3.0, 448.0]) # 57344 diff --git a/tinygrad_repo/test/unit/test_dtype_spec.py b/tinygrad_repo/test/unit/test_dtype_spec.py index 
0e78af19..3c1ffc0e 100644 --- a/tinygrad_repo/test/unit/test_dtype_spec.py +++ b/tinygrad_repo/test/unit/test_dtype_spec.py @@ -1,12 +1,11 @@ -import unittest, math, operator, subprocess +import unittest, math, operator, subprocess, struct from tinygrad.tensor import Tensor, dtypes, Device -from tinygrad.dtype import DType, DTYPES_DICT, truncate, truncate_fp16, truncate_bf16, _to_np_dtype, least_upper_dtype, least_upper_float +from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float from tinygrad.device import is_dtype_supported from tinygrad.helpers import getenv, CI, DEBUG from hypothesis import given, settings, strategies as strat import numpy as np import torch -import ml_dtypes settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) settings.load_profile("my_profile") @@ -22,10 +21,15 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float if DEBUG >= 2: print(tensor.numpy()) try: assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype)) + np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, + dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype)) + except AssertionError as e: raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e +def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0] +def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0] + class TestHelpers(unittest.TestCase): signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) @@ -97,35 +101,105 @@ class TestHelpers(unittest.TestCase): np.testing.assert_equal(dt.min, False) np.testing.assert_equal(dt.max, True) - def test_truncate_fp16(self): - self.assertEqual(truncate_fp16(1), 1) - self.assertEqual(truncate_fp16(65504), 65504) - self.assertEqual(truncate_fp16(65519.999), 65504) - self.assertEqual(truncate_fp16(65520), math.inf) + def test_dtype_range_vec(self): + for dt in core_dtypes: + self.assertEqual(dt.min, dt.vec(4).min) + self.assertEqual(dt.max, dt.vec(4).max) - def test_truncate_bf16(self): - self.assertEqual(truncate_bf16(1), 1) - self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7) - for a in [1234, 23456, -777.777]: - self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item()) - # TODO: torch bfloat 1.1 gives 1.1015625 instead of 1.09375 + def test_float_to_fp16(self): + self.assertEqual(float_to_fp16(1), 1) + self.assertEqual(float_to_fp16(65504), 65504) + self.assertEqual(float_to_fp16(65519.999), 65504) + self.assertEqual(float_to_fp16(65520), math.inf) + self.assertEqual(float_to_fp16(1e-8), 0.0) + self.assertEqual(float_to_fp16(-65504), -65504) + self.assertEqual(float_to_fp16(-65519.999), -65504) + self.assertEqual(float_to_fp16(-65520), -math.inf) + self.assertTrue(math.isnan(float_to_fp16(math.nan))) + + def test_float_to_bf16(self): + # TODO: fuzz this better max_bf16 = torch.finfo(torch.bfloat16).max - self.assertEqual(truncate_bf16(max_bf16), max_bf16) - self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16) - self.assertEqual(truncate_bf16(max_bf16 * 1.00001), math.inf) - self.assertEqual(truncate_bf16(min_bf16 * 1.00001), 
-math.inf) + for a in [1, 1.1, 1234, 23456, -777.777, max_bf16, max_bf16 * 1.00001, -max_bf16, -max_bf16 * 1.00001, math.inf, -math.inf]: + self.assertEqual(float_to_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item()) + self.assertTrue(math.isnan(float_to_bf16(math.nan))) + + def test_float_to_bf16_nan(self): + # In f32, NaN = exp 0xFF and mantissa ≠ 0. Quiet-vs-signaling is bit 22 of the mantissa: 1 = qNaN, 0 = sNaN. + # qNaN(+/-), sNaN(+/-), overflow(+/-) + patterns = [0x7FC00001, 0xFFC00001, 0x7F800001, 0xFF800001, 0x7FFFFFFF, 0xFFFFFFFF] + for u in patterns: + x = u32_to_f32(u) + y = float_to_bf16(x) + t = torch.tensor([x], dtype=torch.bfloat16).item() + self.assertTrue(math.isnan(y)) + self.assertTrue(math.isnan(t)) + + def test_float_to_bf16_round(self): + # round_to_nearest_even + uppers = [0x3f800000, 0x41230000, 0xC1460000] # 1.0, 10.1875, -12.375 + for upper in uppers: + base = upper & 0xFFFF0000 + base_f32 = u32_to_f32(base) + base_f32_round_up = u32_to_f32(base + 0x00010000) + + # low < 0x8000(0.5ULP) -> round down + x = u32_to_f32(base | 0x00007000) + self.assertEqual(float_to_bf16(x), base_f32) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32) + + # low > 0x8000(0.5ULP) -> round up + x = u32_to_f32(base | 0x0000C000) + self.assertEqual(float_to_bf16(x), base_f32_round_up) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up) + + # low == 0x8000(0.5ULP) and LSB even -> round down + if ((upper >> 16) & 1) == 0: + x = u32_to_f32(base | 0x00008000) + self.assertEqual(float_to_bf16(x), base_f32) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32) + # low == 0x8000(0.5ULP) and LSB odd -> round up + else: + x = u32_to_f32(base | 0x00008000) + self.assertEqual(float_to_bf16(x), base_f32_round_up) + self.assertEqual(torch.tensor([x], dtype=torch.bfloat16).item(), base_f32_round_up) + + def test_float_to_bf16_boundary(self): + # bf16 max finite: exp=0xFE, fraction=0x7F => 0x7F7F0000(f32) + # bf16 inf(+/-): exp=0xFF + base = 0x7F7F0000 + inf_u32 = 0x7F800000 + + # low < 0.5ULP + x = u32_to_f32(base | 0x00007FFF) + self.assertEqual(f32_to_u32(float_to_bf16(x)), base) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), base) + + # low > 0.5ULP -> overflows to +inf + x = u32_to_f32(base | 0x0000C000) + self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32) + + # low == 0.5ULP and LSB odd -> overflows to +inf + x = u32_to_f32(base | 0x00008000) + self.assertEqual(f32_to_u32(float_to_bf16(x)), inf_u32) + self.assertEqual(f32_to_u32(torch.tensor([x], dtype=torch.bfloat16).item()), inf_u32) @given(strat.floats(width=32, allow_subnormal=True, allow_nan=True, allow_infinity=True)) def test_truncate_fp8e4m3(self, x): - if x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX) + if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), x) + elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), math.copysign(math.nan, x)) + elif x > FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), FP8E4M3_MAX) elif x < -FP8E4M3_MAX: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), -FP8E4M3_MAX) - else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), ml_dtypes.float8_e4m3fn(x)) + else: np.testing.assert_equal(truncate[dtypes.fp8e4m3](x), torch.tensor(x, dtype=torch.float8_e4m3fn).float().item()) @given(strat.floats(width=32, 
allow_subnormal=True, allow_nan=True, allow_infinity=True)) def test_truncate_fp8e5m2(self, x): - if x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX) + if math.isnan(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x) + elif math.isinf(x): np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), x) + elif x > FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), FP8E5M2_MAX) elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX) - else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), ml_dtypes.float8_e5m2(x)) + else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item()) class TestTypeSpec(unittest.TestCase): def setUp(self): @@ -305,7 +379,7 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64 assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64 # similar to jax but we don't use weak type - assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16 + assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.fp8e4m3 assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32 assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64 @@ -314,6 +388,14 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16 assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16 assert least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2) == dtypes.half + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.bfloat16) == dtypes.bfloat16 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.bfloat16) == dtypes.bfloat16 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.float16) == dtypes.float16 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.float16) == dtypes.float16 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.int64) == dtypes.fp8e4m3 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.uint64) == dtypes.fp8e4m3 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2 class TestAutoCastType(unittest.TestCase): def setUp(self): @@ -370,8 +452,10 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16 assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64 @@ -402,8 +486,10 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == 
dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16 assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64 @@ -417,8 +503,10 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32 assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64 + assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3 + assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2 assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16 - #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64 @@ -490,10 +578,10 @@ class TestAutoCastType(unittest.TestCase): def test_gradient_dtype(self): old_default_float = dtypes.default_float - for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + for default_dtype in dtypes.floats: if not is_dtype_supported(default_dtype): continue dtypes.default_float = default_dtype - for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + for dtype in dtypes.floats: if not is_dtype_supported(dtype): continue if DEBUG >= 2: print(f"testing {default_dtype=}, {dtype=}") @@ -549,4 +637,4 @@ class TestAutoCastType(unittest.TestCase): np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3) out = t.log_softmax(0, dtype=dtypes.float) self.assertEqual(out.dtype, dtypes.float) - np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3) \ No newline at end of file + np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3) diff --git a/tinygrad_repo/test/unit/test_gguf.py b/tinygrad_repo/test/unit/test_gguf.py index b36ad14f..1c9cd5bc 100644 --- a/tinygrad_repo/test/unit/test_gguf.py +++ b/tinygrad_repo/test/unit/test_gguf.py @@ -53,11 +53,38 @@ class TestGGUF(unittest.TestCase): def test_load_tinyllama_q4_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true") def test_load_gpt2_q4_1(self): self._test_gguf_load("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true") def test_load_sample_q6_k(self): self._test_gguf_load("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true") + def test_load_sample_mxfp4(self): self._test_gguf_load("https://huggingface.co/ngxson/boring-testing-tiny/resolve/main/stories260K-mxfp4.gguf?download=true") def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0) def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1) def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0) def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K) + def test_dequantization_mxfp4(self): + 
MXFP4 = 39 + + def encode(nibbles, E): + packed = [(low & 0xF) | ((high & 0xF) << 4) for low, high in zip(nibbles[:16], nibbles[16:])] + return np.array([E] + packed, dtype=np.uint8) + + def decode(code, E): + sign = -1.0 if (code & 0b1000) else 1.0 + exp = (code >> 1) & 0b11 + mant = code & 0b1 + val = (1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant + scale = np.exp2(E - 128) if E >= 2 else np.exp2(-127 if E == 1 else -128) + return sign * val * scale + + blocks, expected = [], [] + rng = np.random.default_rng(42) + for _ in range(4): + E = rng.integers(0, 256) + codes = rng.integers(0, 16, size=32, dtype=np.uint8) + blocks.append(encode(codes, E)) + expected.extend(decode(c, E) for c in codes) + tensor = Tensor(np.concatenate(blocks)) + out = ggml_data_to_tensor(tensor, len(expected), MXFP4) + # TODO: should this be exact equal? somehow failed on CI + np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6) def test_expected_failure_unknown_type(self): with self.assertRaises(ValueError): diff --git a/tinygrad_repo/test/unit/test_graph_rewrite.py b/tinygrad_repo/test/unit/test_graph_rewrite.py index 30d74542..8655019a 100644 --- a/tinygrad_repo/test/unit/test_graph_rewrite.py +++ b/tinygrad_repo/test/unit/test_graph_rewrite.py @@ -65,21 +65,21 @@ class TestFoldingAndReduction(unittest.TestCase): def test_full_graph_rewrite_reduction_with_unused_range(self): const1 = UOp.const(dtypes.int32, 15) const2 = UOp.const(dtypes.int32, 25) - rng = UOp.range(dtypes.int32, 10, idx=0) + rng = UOp.range(10, idx=0) optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng)) expected_sum = 10 * (15 + 25) self.assertEqual(optimized_sink.arg, expected_sum) @unittest.skip("currently failing") def test_full_graph_rewrite_range_reduction(self): - simple_range = UOp.range(dtypes.int32, 5, idx=0) + simple_range = UOp.range(5, idx=0) optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range)) expected_sum = sum(range(5)) self.assertEqual(optimized_sink.arg, expected_sum) @unittest.skip("currently failing") def test_full_graph_rewrite_simple_reduction_folding(self): - simple_range = UOp.range(dtypes.int32, 4, idx=0) + simple_range = UOp.range(4, idx=0) add_uop = simple_range + UOp.const(dtypes.int32, 1) optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range)) expected_sum = sum(i + 1 for i in range(4)) @@ -87,8 +87,8 @@ class TestFoldingAndReduction(unittest.TestCase): @unittest.skip("currently failing") def test_full_graph_rewrite_nested_loop_collapse(self): - outer_range = UOp.range(dtypes.int32, 8, 0) - inner_range = UOp.range(dtypes.int32, 4, 1) + outer_range = UOp.range(8, 0) + inner_range = UOp.range(4, 1) expr = (outer_range * 10) + inner_range optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range)) self.assertEqual(optimized_reduce_uop.op, Ops.CONST) @@ -97,26 +97,29 @@ class TestFoldingAndReduction(unittest.TestCase): class TestModuloAndDivisionFolding(unittest.TestCase): def test_full_graph_rewrite_modulo_folding_with_define_var(self): - x_var_uop = UOp.variable('x', 0, 100) + # index dtype because div-mod rules only work on index + x_var_uop = UOp.variable('x', 0, 100).cast(dtypes.index) optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4) self.assertEqual(optimized_mod_uop.op, Ops.CONST) self.assertEqual(optimized_mod_uop.arg, 2) def test_full_graph_rewrite_division_folding_with_define_var(self): - n_var_uop = UOp.variable('n', 1, 1000) + # index dtype because div-mod rules only work on index + 
n_var_uop = UOp.variable('n', 1, 1000).cast(dtypes.index) optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3) self.assertEqual(optimized_div_uop.op, Ops.MUL) self.assertEqual(optimized_div_uop.src[1].arg, 2) def test_full_graph_rewrite_complex_mod_div_folding(self): - k_var_uop = UOp.variable('k', 0, 50) + # index dtype because div-mod rules only work on index + k_var_uop = UOp.variable('k', 0, 50).cast(dtypes.index) optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2) self.assertEqual(optimized_div_uop.op, Ops.CONST) self.assertEqual(optimized_div_uop.arg, 1) def test_graph_rewrite_div_folding_bug(self): lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=( - UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4), + UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg='lidx0', src=(UOp.const(dtypes.int, 32),)),)*4), UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=()))) rhs = UOp.const(dtypes.int.vec(4), 2) unopt = lhs 5 using symbolic. """ # In some versions of TinyGrad, you might do: (-(-five_node)) - five_node = UOp.const(dtypes.int, 5) + five_node = UOp.const(dtypes.index, 5) # If your code allows UOp(...), do that; else you might do something like: # double_neg_five = -(-five_node) # But let's be explicit: @@ -85,8 +85,8 @@ class TestRewriteMap(unittest.TestCase): """ Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5 """ - zero_node = UOp.const(dtypes.int, 0) - five_node = UOp.const(dtypes.int, 5) + zero_node = UOp.const(dtypes.index, 0) + five_node = UOp.const(dtypes.index, 5) neg_five = -five_node double_neg_five = -neg_five root_add = zero_node + double_neg_five @@ -103,7 +103,7 @@ class TestRewriteMap(unittest.TestCase): def test_multi_var_rewrites(self): x_var = UOp.variable('x', 0, 10) y_var = UOp.variable('y', -5, 5) - zero_node = UOp.const(dtypes.int, 0) + zero_node = UOp.const(dtypes.index, 0) sum_with_zero = y_var + zero_node # (y + 0) combined = x_var + sum_with_zero # x + (y + 0) @@ -155,8 +155,8 @@ class TestRewriteMap(unittest.TestCase): x_var = UOp.variable('x', 1, 10) y_var = UOp.variable('y', -5, 5) z_var = UOp.variable('z', 0, 5) - zero_node = UOp.const(dtypes.int, 0) - one_node = UOp.const(dtypes.int, 1) + zero_node = UOp.const(dtypes.index, 0) + one_node = UOp.const(dtypes.index, 1) # Build sub-expressions yz_sum = y_var + z_var # (y + z) diff --git a/tinygrad_repo/test/unit/test_search.py b/tinygrad_repo/test/unit/test_search.py deleted file mode 100644 index 33b07e3a..00000000 --- a/tinygrad_repo/test/unit/test_search.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -from tinygrad import Tensor, Device -from tinygrad.codegen.opt.kernel import Kernel -from tinygrad.device import Buffer -from tinygrad.codegen.opt.search import get_test_global_size, bufs_from_lin -from tinygrad.helpers import GlobalCounters -from extra.optimization.helpers import time_linearizer -from test.test_linearizer import push_views - -class TestSearchUtil(unittest.TestCase): - def test_get_test_global_size(self): - self.assertEqual(get_test_global_size([256, 256, 256], 65536, {}), ([256, 16, 16], 256.0)) - self.assertEqual(get_test_global_size([65536, 1, 1], 256, {}), ([256, 1, 1], 256.0)) - self.assertEqual(get_test_global_size([77, 1, 1], 16, {}), ([9, 1, 1], 77/9)) - - def test_bufs_from_lin(self): - a = Tensor([1,2,3,4]).realize() - si = (a+1).schedule()[0] - rawbufs = bufs_from_lin(Kernel(si.ast)) - assert len(rawbufs) == 2 - assert all(r is not None for r in 
rawbufs) - assert all(isinstance(r, Buffer) for r in rawbufs) - assert all(r.size > 0 for r in rawbufs) - - def test_bufs_from_lin_alt(self): - a = Tensor.randn(4, 4).realize() - b = a+a[0] - si = b.schedule()[0] - rawbufs = bufs_from_lin(Kernel(push_views(si.ast))) - assert len(rawbufs) == 2 - assert all(r is not None for r in rawbufs) - assert all(isinstance(r, Buffer) for r in rawbufs) - assert all(r.size > 0 for r in rawbufs) - -class TestTimeLinearizer(unittest.TestCase): - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0") - def test_reasonable_time(self): - a = Tensor([1,2,3,4]).realize() - si = (a+1).schedule()[0] - # create fresh empty buffers - rawbufs = [Buffer(b.device, b.size, b.dtype).allocate() for b in si.bufs] - tm = time_linearizer(Kernel(push_views(si.ast)), rawbufs, allow_test_size=False, cnt=10, disable_cache=True) - assert tm > 0 and tm != float('inf') - - # Ensure that the kernel count is not incremented by time_linearizer when clearing l2 - def test_kernel_count(self): - ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast - lin = Kernel(push_views(ast)) - bufs = bufs_from_lin(lin) - - kernel_count = GlobalCounters.kernel_count - time_linearizer(lin, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True) - assert GlobalCounters.kernel_count == kernel_count, "kernel count was incremented by time_linearizer" - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/tinygrad_repo/test/unit/test_shapetracker.py b/tinygrad_repo/test/unit/test_shapetracker.py index bf2bf3d3..849aa19d 100644 --- a/tinygrad_repo/test/unit/test_shapetracker.py +++ b/tinygrad_repo/test/unit/test_shapetracker.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, Invalid from tinygrad.helpers import prod from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad import Variable @@ -10,7 +10,8 @@ from tinygrad.codegen.late.devectorizer import sym from itertools import product def shapetracker_getitem(st:ShapeTracker, val:int): - idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)]) + valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)]) + idx, valid = valid_idx.get_idx(), valid_idx.get_valid() idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym) assert idx.op is Ops.CONST and valid.op is Ops.CONST return idx.arg, valid.arg @@ -68,7 +69,7 @@ class CheckingShapeTracker: def contiguous(self): return self.st.contiguous def assert_same(self): - x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))] + x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))] y = [self[i] for i in range(prod(self.shape))] assert self.st.shape == self.shape assert x == y, f"mismatch shapetracker:{x} real:{y}" @@ -154,7 +155,7 @@ class TestRealStrides(unittest.TestCase): View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))), View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None), )) - self.assertEqual(st.real_strides(), (132, None, None, None, None)) + self.assertEqual(st.real_strides(), (132, 12, None, None, None)) class TestRealSimplifies(unittest.TestCase): def tearDown(self): @@ -619,20 +620,6 @@ class TestMaskedShapeTracker(unittest.TestCase): st3.reshape((4, 3, 6, 5)) st3.assert_same() - def 
test_axis_is_masked(self): - st = ShapeTracker.from_shape((100, 100, 100, 100)).pad(((0,1),(0,0),(2,0), (0,0))) - assert st.axis_is_masked(0) - assert not st.axis_is_masked(1) - assert st.axis_is_masked(2) - assert not st.axis_is_masked(3) - - def test_axis_is_masked_rw1(self): - st = ShapeTracker(views=(View(shape=(1, 2, 1, 4, 4, 13, 4, 13), strides=(0, 324, 0, 81, 0, 9, 0, 1), offset=-20, - mask=((0, 1), (0, 2), (0, 1), (0, 4), (0, 4), (2, 11), (0, 4), (2, 11)), contiguous=False), - View(shape=(2, 4, 11, 11, 4, 3, 3), strides=(10816, 0, 52, 1, 2704, 728, 14), offset=0, - mask=None, contiguous=False))) - assert not st.axis_is_masked(0) - class TestShapeTracker(unittest.TestCase): def setUp(self): self.st = CheckingShapeTracker((7,4)) @@ -830,34 +817,33 @@ class TestShapeTrackerSize(unittest.TestCase): class TestRender(unittest.TestCase): def test_render(self): st = ShapeTracker.from_shape((2, 3)) - idx, valid = st.to_indexed_uops() + valid_idx = st.to_valid_uop() + idx, valid = valid_idx.get_idx(), valid_idx.get_valid() self.assertEqual(idx.render(), "((ridx0*3)+ridx1)") self.assertEqual(valid.render(), "True") st = st.pad(((0, 1), (0, 0))) - idx, valid = st.to_indexed_uops() + valid_idx = st.to_valid_uop() + idx, valid = valid_idx.get_idx(), valid_idx.get_valid() self.assertEqual(idx.render(), "((ridx0*3)+ridx1)") self.assertEqual(valid.render(), "(ridx0<2)") -class TestVariableReshape(unittest.TestCase): - def test_reshape(self): - st = ShapeTracker.from_shape((3,)) - st = st.reshape((Variable("i", 1, 10),)) +class TestVariableShrink(unittest.TestCase): + def test_shrink(self): + st = ShapeTracker.from_shape((10,)) + st = st.shrink(((0, Variable("i", 1, 10)),)) assert len(st.views) == 1 - def test_reshape_stride_0(self): - st = ShapeTracker.from_shape((3,), (0,)) - st = st.reshape((Variable("i", 1, 10).bind(3),)) - assert len(st.views) == 1, f"multiview {st}" - - def test_reshape_bound(self): - st = ShapeTracker.from_shape((3,)) - st = st.reshape((Variable("i", 1, 10).bind(3),)) + def test_shrink_bound(self): + st = ShapeTracker.from_shape((10,)) + st = st.shrink(((0, Variable("i", 1, 10).bind(3)),)) assert len(st.views) == 1 - def test_add(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10),)) +class TestVariableMerge(unittest.TestCase): + def test_add_reshape(self): + vi = Variable("i", 1, 10) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = st1+st2 assert len(st.views) == 1 @@ -867,15 +853,17 @@ class TestVariableReshape(unittest.TestCase): st = st1+st2 assert len(st.views) == 1, f"multiview {st}" - def test_add_bound(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),)) + def test_add_reshape_bound(self): + vi = Variable("i", 1, 10).bind(3) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = st1+st2 assert len(st.views) == 1 def test_simplify(self): - st1 = ShapeTracker.from_shape((3,)) - st2 = ShapeTracker.from_shape((Variable("i", 1, 10).bind(3),)) + vi = Variable("i", 1, 10).bind(3) + st1 = ShapeTracker.from_shape((vi,)) + st2 = ShapeTracker.from_shape((1, vi,)) st = ShapeTracker((st1.views[0], st2.views[0])) st = st.simplify() assert len(st.views) == 1 diff --git a/tinygrad_repo/test/unit/test_shapetracker_math.py b/tinygrad_repo/test/unit/test_shapetracker_math.py index efd017f5..3a74ae30 100644 --- a/tinygrad_repo/test/unit/test_shapetracker_math.py +++ b/tinygrad_repo/test/unit/test_shapetracker_math.py 
@@ -87,20 +87,6 @@ class TestShapeTrackerAdd(unittest.TestCase): assert not (st_equal(st1, st2)) class TestShapeTrackerAddVariable(unittest.TestCase): - def test_self_add(self): - j = Variable("j", 0, 20).bind(10) - a = ShapeTracker.from_shape((10,10)) - x = a.reshape((10, j)) - out = x + x - assert out == x - - def test_self_add_reshape(self): - j = Variable("j", 0, 20).bind(10) - a = ShapeTracker.from_shape((10,10)) - x = a.reshape((10, j)) - out = x.reshape((5, 2, j)) + x - assert out == x - def test_merge_symbolic_views(self): var_i = Variable('i', 1, 10) var_j = Variable('i', 1, 10) diff --git a/tinygrad_repo/test/unit/test_simplify_valid_idx.py b/tinygrad_repo/test/unit/test_simplify_valid_idx.py index 47b8ce4f..b9690dae 100644 --- a/tinygrad_repo/test/unit/test_simplify_valid_idx.py +++ b/tinygrad_repo/test/unit/test_simplify_valid_idx.py @@ -4,6 +4,7 @@ from tinygrad.codegen import full_rewrite_to_sink from tinygrad.dtype import dtypes from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.symbolic import simplify_valid +from tinygrad.helpers import Context def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(Ops.LOAD, dtypes.float, ( @@ -17,9 +18,9 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) )) -def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax)) +def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, nmax),), expr) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) -def Range(n, nmax): return UOp.range(dtypes.int, nmax, n) +def Range(n, nmax): return UOp.range(nmax, n) class TestHelpers(unittest.TestCase): def test_is_increasing(self): @@ -45,7 +46,8 @@ class TestHelpers(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase): def check(self, load, sidx, svalid): - load = full_rewrite_to_sink(load.sink()).src[0] + with Context(NOOPT=1): + load = full_rewrite_to_sink(load.sink()).src[0] idx, valid = load.src[0].src[1], load.src[0].src[2] self.assertEqual(idx.render(simplify=False), sidx) self.assertEqual(valid.render(simplify=False), svalid) @@ -195,9 +197,21 @@ class TestValidIdxSimplification(unittest.TestCase): "1", "((((ridx0+ridx1)<1)!=True)&(((ridx2+ridx3)<1)!=True))") + def test_valid_with_non_const_rhs(self): + ridx0 = Range(0, 2**16) + ridx1 = Range(1, 4) + ridx2 = Range(2, 4) + valid = (ridx0<(ridx1*4 + ridx2))&(ridx0<-1).ne(True) + idx = ridx0%1024 + load = get_gated_load_uop(valid, idx) + self.check(load, + "ridx0", + "(ridx0<((ridx1*4)+ridx2))") + class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): - load = full_rewrite_to_sink(load.sink()).src[0] + with Context(NOOPT=1): + load = full_rewrite_to_sink(load.sink()).src[0] idx = load.src[0].src[1] self.assertEqual(idx.op, Ops.VECTORIZE) self.assertEqual(len(idx.src), 2) @@ -255,6 +269,7 @@ class TestImageSimplification(unittest.TestCase): load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5)) self.check(load, None, "gidx0", "(gidx1+5)") + @unittest.skip("this should be constructed with an invalid gate") def test_valid_empty_set(self): gidx0 = Special("gidx0", 32) gidx1 = Special("gidx1", 32) @@ -345,7 +360,7 @@ class TestImageSimplification(unittest.TestCase): self.check(load, None, "((gidx*3)+-1438)", "0") def test_simplify2(self): - # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d + # from CL=1 
DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d lidx = Special("lidx", 4) valid = (lidx<3) & (lidx<1).ne(True) idx = ((lidx+1)%2, (lidx+1)//2-1) diff --git a/tinygrad_repo/test/unit/test_symbolic_shapetracker.py b/tinygrad_repo/test/unit/test_symbolic_shapetracker.py index ed065d05..c89419a2 100644 --- a/tinygrad_repo/test/unit/test_symbolic_shapetracker.py +++ b/tinygrad_repo/test/unit/test_symbolic_shapetracker.py @@ -13,7 +13,6 @@ class TestSymbolic(unittest.TestCase): assert st.shape == (x, 3) assert st.real_strides() == (3, 1) - @unittest.expectedFailure def test_real_strides_0(self): st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 self.assertEqual(st.real_strides(), (8, None)) @@ -48,11 +47,11 @@ class TestSymbolic(unittest.TestCase): i = Variable("i", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3) k = Variable("k", 1, 5).bind(3) - t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) + t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (i+j+k, 4)) assert st.real_strides() == (4, 1) - t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) + t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0) st = t.uop.st self.assert_tuple_equal(st.shape, (2*i+3, 3)) assert st.real_strides() == (3, 1) @@ -61,7 +60,7 @@ class TestSymbolic(unittest.TestCase): i = Variable("i", 1, 5).bind(4) j = Variable("j", 1, 5).bind(4) k = Variable("k", 1, 5).bind(4) - t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) + t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1) st = t.uop.st self.assert_tuple_equal(st.shape, (3, i+j+k)) self.assert_tuple_equal(st.real_strides(), (i+j+k, 1)) @@ -73,19 +72,19 @@ class TestSymbolicVarVals(unittest.TestCase): def test_var_vals_shape(self): x = Variable("x", 1, 100).bind(3) - assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3} + assert ShapeTracker.from_shape((x, 3)).var_vals == {"x": 3} def test_var_vals_offset(self): x = Variable("x", 1, 100).bind(3) st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3))) self.assert_equal(st.views[-1].offset, x * 3) - assert st.var_vals == {Variable("x", 1, 100): 3} + assert st.var_vals == {"x": 3} def test_var_vals_mask(self): x = Variable("x", 1, 100).bind(3) view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4))) st = ShapeTracker(views=(view,)) - assert st.var_vals == {Variable("x", 1, 100): 3} + assert st.var_vals == {"x": 3} def test_var_vals_complex(self): x = Variable("x", 1, 100).bind(3) @@ -93,13 +92,13 @@ class TestSymbolicVarVals(unittest.TestCase): z = Variable("z", 1, 100).bind(5) st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3))) self.assert_equal(st.views[-1].offset, y * z) - assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5} + assert st.var_vals == {"x": 3, "y": 4, "z": 5} def 
test_shrink_reshape(self): x = Variable("x", 1, 100).bind(3) st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5))) st = st.reshape((3*4*3,)) - assert st.var_vals == {Variable("x", 1, 100): 3} + assert st.var_vals == {"x": 3} class TestShapeTrackerUnbind(unittest.TestCase): def test_view_unbind(self): @@ -109,60 +108,44 @@ class TestShapeTrackerUnbind(unittest.TestCase): assert unbound_view == View.create(shape=(v, 4)) assert var_val == {v: 3} - def test_reshape_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(3) - t = Tensor.rand(3, 4).reshape(bv, 4) - unbound_st, var_val = t.uop.st.unbind() - assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),)) - assert var_val == {v: 3} - def test_shrink_unbind(self): v = Variable("v", 1, 100) bv = Variable("v", 1, 100).bind(2) + t = Tensor.rand(3, 4).shrink(((0,bv),(0,4))) + unbound_st, var_val = t.uop.st.unbind() + assert unbound_st == ShapeTracker((View.create(shape=(v, 4)),)) + assert var_val == {v: 2} t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4))) unbound_st, var_val = t.uop.st.unbind() assert unbound_st == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) assert var_val == {v: 2} -class TestSymbolicReshapeFromContiguous(unittest.TestCase): - def test_reshape_into_symbols_simple(self): +class TestSymbolicReshape(unittest.TestCase): + def test_reshape(self): + a = Tensor.rand(5, 4) + b = Tensor.rand(5, 6) for i in range(1, 6): vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = Tensor.rand(i, 6).reshape(vi, 2, 3) - assert t.shape == (vi, 2, 3) - - def test_reshape_symbols_reshape_ints(self): - for i in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = t.reshape(i, 4) - assert t.shape == (i, 4) - - @unittest.skip("works now") - def test_reshape_into_symbols_bad_shape(self): - vi = Variable("i", 1, 10).bind(4) - # TODO: this never actually worked, it relied on lazy - #with self.assertRaises(ValueError): - # Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape - with self.assertRaises(AssertionError): - Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node + ret = a[:vi] + ret = ret.reshape((vi, 4)) + assert ret.shape == (vi, 4) + ret = b[:vi] + ret = ret.reshape((vi, 2, 3)) + assert ret.shape == (vi, 2, 3) def test_two_symbol_reshape(self): + t = Tensor.rand(5, 5) for i in range(1, 6): for j in range(1, 6): vi = Variable("i", 1, 5).bind(i) vj = Variable("j", 1, 5).bind(j) - t = Tensor.rand(i, j).reshape(vi, vj) - assert t.shape == (vi, vj) - # NOTE: this is currently not allowed - # t = t.reshape(1, vi*vj) - # assert t.shape == (1, vi*vj) - t = t.reshape(vj, vi) - assert t.shape == (vj, vi) + ret = t[:vi, :vj] + ret = ret.reshape(vj, vi) + assert ret.shape == (vj, vi) + ret = ret.reshape(vi, vj) + assert ret.shape == (vi, vj) + ret = ret.reshape(1, vi*vj) + assert ret.shape == (1, vi*vj) def test_symbolic_mask(self): # taken from gpt2 single kvcache @@ -175,41 +158,6 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase): new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64) assert view.reshape(new_shape) is None -class TestSymbolicReshapeFromNonContiguous(unittest.TestCase): - def test_reshape_from_const(self): - vi = Variable("i", 1, 5).bind(4) - t = Tensor.ones(3, 4).reshape(3, vi) - assert t.shape == (3, vi) - assert not t.uop.st.contiguous - assert len(t.uop.st.views) == 
1 - - def test_reshape_not_allowed(self): - vi = Variable("i", 1, 5).bind(4) - with self.assertRaises(ValueError): - # different shape length # TODO: cases where contractions matched might be fine - Tensor.ones(3, 4, 1).reshape(3, vi) - with self.assertRaises(ValueError): - # size matched, but dimensions do not match - Tensor.ones(4, 3).reshape(3, vi) - - def test_reshape_from_padded(self): - vi = Variable("i", 1, 5).bind(4) - t = Tensor.ones(3, 4).contiguous().expand(2, 3, 4).pad(((1, 1), None, None)).shrink((None, None, (1, 3))) - st = t.uop.st - assert len(st.views) == 1 - view = st.views[0] - assert view.shape == (4, 3, 2) - t = t.reshape(vi, 3, 2) - st2 = t.uop.st - assert len(st2.views) == 1 - view2 = st2.views[0] - # check only shape changed. strides, offset, mask, contiguous remained the same - assert view2.shape == (vi, 3, 2) - assert view.strides == view2.strides == (0, 4, 1) - assert view.offset == view2.offset == 1 - assert view.mask == view2.mask == ((1, 3), (0, 3), (0, 2)) - assert not view.contiguous and not view2.contiguous - class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): vi = Variable("i", 1, 5).bind(3) @@ -220,11 +168,12 @@ class TestSymbolicExpand(unittest.TestCase): assert a.shape == (3, vi, vj) def test_plus_expands_constant(self): + a = Tensor.rand(3, 5) for i in range(1, 6): vi = Variable("i", 1, 5).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - a = a + 1 - self.assertTupleEqual(a.shape, (3, vi)) + ret = a[:, :vi] + ret = ret + 1 + self.assertTupleEqual(ret.shape, (3, vi)) def test_pad_then_expand_into_symbols(self): vi = Variable("i", 1, 10).bind(3) @@ -234,6 +183,11 @@ class TestSymbolicExpand(unittest.TestCase): self.assertEqual(a.reshape(vi*25).shape, (vi*25,)) class TestSymbolicShrink(unittest.TestCase): + def test_shrink_symbols_simple(self): + vi = Variable("i", 1, 5) + t = Tensor.rand(5, 5).shrink(((0, 5),(0,vi))) + assert t.shape == (5, vi) + def test_shrink_symbols(self): vi = Variable("i", 1, 5) t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1))) @@ -242,10 +196,10 @@ class TestSymbolicShrink(unittest.TestCase): class TestSymbolicPad(unittest.TestCase): def test_pad(self): v = Variable("v", 1, 100).bind(5) - t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9) - assert t.shape == (9,) - st = t.uop.st - print(st) + t = Tensor.ones(100)[:v].pad(((4, 0),)) + t = t[:9] + assert t.tolist() == [0,0,0,0,1,1,1,1,1] + if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/test_tensor_data.py b/tinygrad_repo/test/unit/test_tensor_data.py similarity index 100% rename from tinygrad_repo/test/test_tensor_data.py rename to tinygrad_repo/test/unit/test_tensor_data.py diff --git a/tinygrad_repo/test/unit/test_tensor_uop_representation.py b/tinygrad_repo/test/unit/test_tensor_uop_representation.py index 59d28a43..fe8d47bb 100644 --- a/tinygrad_repo/test/unit/test_tensor_uop_representation.py +++ b/tinygrad_repo/test/unit/test_tensor_uop_representation.py @@ -34,7 +34,7 @@ class TestTensorMutates(unittest.TestCase): is_pattern_uop(c.uop.base, realized_pattern) # NOTE: we keep movement ops on top of the buffer view is_pattern_uop(c.uop, UPat(Ops.BUFFER)) - is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,))) + assert d.uop is not d.uop.base def test_reshape_is_same_child(self): a = Tensor([1,2,3]) @@ -58,46 +58,12 @@ class TestTensorUopRepresentation(unittest.TestCase): print(c.uop) is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern))) - def test_const_pattern(self): - a = Tensor(1) - 
print(a.uop) - is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src - is_pattern(a, UPat.cvar("x")) # even cvar works! - - def test_consts_do_not_realize(self): - a = Tensor(1) - print(a.uop) - pre_realize = a.uop - a.realize() - assert a.uop is pre_realize - - def test_viewed_consts_do_not_realize(self): - a = Tensor.ones(10, 10) - print(a.uop) - a.realize() - is_pattern(a, const_pattern) - self.assertEqual(a.uop.shape, (10, 10)) - - # CONST is EXPAND -> RESHAPE -> CONST -> DEVICE - def test_consts_dont_have_buffers(self): - a = Tensor.ones(10, 10) - buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER] - self.assertEqual(len(buffers_in_parents), 0) - is_pattern(a, UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(const_pattern,)),))) - - # COPY has a copyin source and a device. - def test_copyin(self): - a = Tensor([1.,2,3]).realize() - c = a.to("TEST") # NOTE: this isn't checked - print(c.uop) - is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE)), arg=None)) - def test_empty_buf(self): a = Tensor.empty(3, 3) is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))) vi = UOp.variable("i", 1, 3).bind(1) a = Tensor.empty(3, vi) - is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))) + is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.SHRINK, src=(UPat(Ops.BUFFER),))),)) self.assertEqual(a.uop.base.buffer.size, 9) if __name__ == '__main__': diff --git a/tinygrad_repo/test/unit/test_uop_spec.py b/tinygrad_repo/test/unit/test_uop_spec.py index 0244ba35..65559897 100644 --- a/tinygrad_repo/test/unit/test_uop_spec.py +++ b/tinygrad_repo/test/unit/test_uop_spec.py @@ -81,5 +81,16 @@ class TestUOpSpec(unittest.TestCase): with self.assertRaisesRegex(RuntimeError, "UOp verification failed"): type_verify([a], tensor_uop_spec) +class TestUOpSink(unittest.TestCase): + def test_0(self): + s = UOp.sink() + self.assertEqual(len(s.src), 0) + + def test_1(self): + a = UOp.const(dtypes.int, 0) + s1 = UOp.sink(a) + s2 = a.sink() + self.assertIs(s1, s2) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad_repo/test/unit/test_uop_symbolic.py b/tinygrad_repo/test/unit/test_uop_symbolic.py index 5d116c1a..c7ee4488 100644 --- a/tinygrad_repo/test/unit/test_uop_symbolic.py +++ b/tinygrad_repo/test/unit/test_uop_symbolic.py @@ -2,22 +2,20 @@ import unittest, pickle, functools, math import z3 -from tinygrad.dtype import dtypes, ConstType +from tinygrad.dtype import dtypes, ConstType, DType, Invalid from tinygrad.codegen import full_rewrite -from tinygrad.codegen.late.devectorizer import sym from tinygrad.helpers import Context -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer -from tinygrad import Variable -from tinygrad.uop.spec import z3_renderer +from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites +from tinygrad.uop.symbolic import sym +from tinygrad.uop.spec import uops_to_z3 -def render(self) -> tuple[str, ConstType, ConstType]: - # NOTE: we need STORE so the ALU op has children - glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) - uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()) - rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] - return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax +@track_rewrites(name="simplify symbolic uop") +def render(v) -> UOp: + v_simplified = graph_rewrite(v, sym) + return v_simplified -def uconst(val): return UOp.const(dtypes.int, val) +def Variable(name: str, 
min_val: ConstType, max_val: ConstType, dtype: DType=dtypes.index): return UOp.variable(name,min_val,max_val,dtype) +def uconst(val): return UOp.const(dtypes.index, val) def usum(ops): return functools.reduce(lambda x,y: x+y, ops) def uand(ops): return functools.reduce(lambda x,y: x*y, ops) @@ -30,12 +28,12 @@ class TestSymbolicPickle(unittest.TestCase): class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s, test_z3:bool=True): + v_simplified = render(v) if test_z3: solver = z3.Solver() - z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {})) - expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg - self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original") - rendered, nmin, nmax = render(v) + expr, expr_simplified = uops_to_z3(solver, v, v_simplified) + self.assertEqual(solver.check(expr != expr_simplified), z3.unsat, "simplified expression not equal to original") + rendered, nmin, nmax = v_simplified.render(simplify=False), v_simplified.vmin, v_simplified.vmax if isinstance(s, tuple): self.assertIn(rendered, s) else: self.assertEqual(rendered, s) self.assertEqual(nmin, n) @@ -95,6 +93,37 @@ class TestSymbolic(unittest.TestCase): assert idx1+idx2 is not idx2 assert idx1*idx2 is not idx2*idx1 + def test_uop_gcd_method(self): + a = Variable("a", 0, 8) + b = Variable("b", 0, 8) + self.assertEqual(UOp.gcd(a, a*b, a*3).simplify(), a) + self.assertEqual(UOp.gcd(a*a*a, a*b*a, a*3*a).simplify(), a*a) + self.assertEqual(UOp.gcd(a*a*10, b*a*5, a*a*5).simplify(), a*5) + self.assertEqual(UOp.gcd(a*10, b*5, a*5).simplify(), a.const_like(5)) + self.assertEqual(UOp.gcd(a, b*5, a*5).simplify(), a.const_like(1)) + + def test_divides_exact(self): + a = Variable("a", 1, 8) + b = Variable("b", 1, 8) + self.assertEqual((a*a*3).divide_exact(a).simplify(), a*3) + self.assertEqual((a*a*3).divide_exact(a*a*3).simplify(), a.const_like(1)) + self.assertEqual((a*b*3).divide_exact(a.const_like(3)).simplify(), a*b) + self.assertEqual((a*a*3).divide_exact(a*a.const_like(-3)).simplify(), a*-1) + self.assertEqual((a*a*b*3).divide_exact(a*b).simplify(), a*3) + self.assertEqual((a*3+a*b).divide_exact(a).simplify(), b+3) + self.assertEqual((a*b*3+a*b*b).divide_exact(a*b).simplify(), b+3) + self.assertEqual((((a*-2)+14)*b).divide_exact(((a*-2)+14)).simplify(), b) + + def test_divide_exact_not(self): + a = Variable("a", 1, 8) + b = Variable("b", 1, 8) + x = Variable("x", -20, 0) + self.assertEqual((a).divide_exact(b), None) + self.assertEqual((a+2).divide_exact(a), None) + self.assertEqual((x*-1).divide_exact(a), None) + self.assertEqual((a*5).divide_exact(a*10), None) + self.assertEqual((a*10-1).divide_exact(a*10), None) + def test_factorize(self): a = Variable("a", 0, 8) b = Variable("b", 0, 8) @@ -112,7 +141,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") def test_xor_0(self): - self.helper_test_variable(Variable("a", 0, 8) ^ 0, 0, 8, "a") + self.helper_test_variable(Variable("a", 0, 8, dtypes.int) ^ 0, 0, 8, "a", test_z3=False) def test_add_1(self): self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)") @@ -209,6 +238,16 @@ class TestSymbolic(unittest.TestCase): self.assertEqual((Variable("x", -10, 0)%Variable("y", -10, -1))._min_max, (-9, 0)) self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0)) + def test_range_div_its_symbolic_bound(self): + a = Variable("a", 1, 10, dtypes.index) + ridx0 = UOp.range(a+2, 0) + 
self.helper_test_variable(ridx0//(a+2), 0, 0, "0") + + def test_range_mod_its_symbolic_bound(self): + a = Variable("a", 1, 10, dtypes.index) + ridx = UOp.range(a+2, 0) + self.helper_test_variable(ridx%(a+2), 0, 11, "ridx0") + def test_div_min_max(self): self.helper_test_variable(Variable("a", 2, 7) // 2, 1, 3, "(a//2)") self.helper_test_variable(Variable("a", 0, 6) // 2, 0, 3, "(a//2)") @@ -366,6 +405,16 @@ class TestSymbolic(unittest.TestCase): def test_and_remove(self): self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a") + def test_bool_or_not_tautology(self): + a = Variable("a", 0, 10) + c = a<10 + self.helper_test_variable(c | c.logical_not(), True, True, "True") + + def test_bool_and_not_contradiction(self): + a = Variable("a", 0, 10) + c = a<10 + self.helper_test_variable(c & c.logical_not(), False, False, "False") + def test_mod_factor_negative(self): self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)") self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)") @@ -432,6 +481,33 @@ class TestSymbolic(unittest.TestCase): def test_mul_div_factor_div_neg(self): self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -4, 0, "(((a*-1)+1)//2)") + def test_div_symbolic_const_gcd(self): + a = Variable("a", -10, 10) + b = Variable("b", -10, 10) + d = Variable("d", 1, 10) + self.helper_test_variable((3*a+9*b)//(3*d), -40, 40, "((a+(b*3))//d)") + + def test_symbolic_gcd_div(self): + a = Variable("a", -10, 10) + b = Variable("b", -10, 10) + c = Variable("c", -10, 10) + d1 = Variable("d1", 1, 10) + d2 = Variable("d2", -10, -1) + self.helper_test_variable((d1*a*b*d1)//(d1), -1000, 1000, "(a*(b*d1))") + self.helper_test_variable((d1*a*d2*b*d1)//(d1*d2), -1000, 1000, "(a*(b*d1))") + self.helper_test_variable((d1*a + b*d1)//(d1), -20, 20, "(a+b)") + self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "(c+(a+b))") + self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "(((a+(b*3))//(d2*-1))*-1)") + self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "(((((a*d1)+((b*d1)*3))+1)//((d1*d2)*-1))*-1)") + + def test_symbolic_factor_remainder_div(self): + a = Variable("a", 0, 10) + b = Variable("b", 0, 10) + d = Variable("d", 1, 10) + self.helper_test_variable((d*a+b)//d, 0, 20, "(a+(b//d))") + self.helper_test_variable((d*a*20+b)//(5*d), 0, 42, "((a*4)+(b//(d*5)))") + self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "((b+(a*4))+(2//d))") + def test_mod_gcd_factor_neg(self): self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)") @@ -454,8 +530,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") def test_nest_div_negative_factor(self): - ridx0=UOp.variable("ridx0", 0, 9) - ridx1=UOp.variable("ridx1", 0, 6) + ridx0=Variable("ridx0", 0, 9) + ridx1=Variable("ridx1", 0, 6) self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)") def test_div_into_mod(self): @@ -524,8 +600,8 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(x//y, 2, 2, "2") self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))") # ensure all 4 corners are checked - x = Variable("x", -10, 10) - y = Variable("y", -8, 9) + x = Variable("x", -10, 10, dtypes.int) + y = Variable("y", -8, 9, dtypes.int) self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)") 
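# why corner sampling alone is not enough here: the four (vmin, vmax) corners
# give (-10)//(-8)=1, (-10)//9=-2, 10//(-8)=-2 and 10//9=1, a candidate range
# of only (-2, 1), yet x=10, y=1 evaluates to 10. since y's range also crosses
# zero, the simplifier cannot bound the quotient at all and falls back to the
# int32 extremes asserted on either side of this comment.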
self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)") @@ -551,8 +627,13 @@ def test_div_mod_recombine(self): gidx = Variable("gidx", 0, 124) + lidx = Variable("lidx", 0, 124) self.helper_test_variable(gidx%4+(gidx//4)*4, 0, 124, "gidx") self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx") + self.helper_test_variable(lidx+gidx%4+(gidx//4)*4, 0, 248, "(gidx+lidx)") + self.helper_test_variable(lidx+(gidx//4)*4+gidx%4, 0, 248, "(gidx+lidx)") + self.helper_test_variable(lidx+(gidx//4)*8+2*(gidx%4), 0, 372, "(lidx+(gidx*2))") + self.helper_test_variable(lidx+2*(gidx%4)+(gidx//4)*8, 0, 372, "(lidx+(gidx*2))") def test_div_mod_recombine_folded_mod(self): a = Variable("a", 0, 2) @@ -568,39 +649,6 @@ with self.assertRaises(AssertionError): self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)") - def test_arange_unrolled4(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4 - self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)") - - def test_arange_unrolled4_mul(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4) - self.helper_test_variable(unrolled_div, 5118, 10236, "((gidx*2)+5118)") - - def test_arange_unrolled4_small(self): - gidx = Variable("gidx", 0, 3) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 3, "gidx") - - gidx = Variable("gidx", 0, 2) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 2, "gidx") - - gidx = Variable("gidx", 0, 1) - unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4 - self.helper_test_variable(unrolled_div, 0, 1, "gidx") - - def test_arange_unrolled2(self): - gidx = Variable("gidx", 0, 2559) - unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3 - self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)") - - def test_arange_unrolled2_neg(self): - ridx = Variable("ridx", 0, 255) - unrolled_div = -((255-ridx)//2) - ((256-ridx)//2) - self.helper_test_variable(unrolled_div, -255, 0, "(ridx+-255)") - def test_gated_load(self): idx = Variable("idx", 0, 24) self.helper_test_variable(idx//4, 0, 6, "(idx//4)") @@ -640,15 +688,16 @@ cond = Variable("x", 0, 3) < 2 a = Variable("a", 0, 3) b = Variable("b", 0, 3) + c = Variable("c", 0, 3) aa = cond.where(a, a.ufix(0)) bb = cond.where(b, b.ufix(1)) self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)") self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)") self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)") self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)") + self.helper_test_variable((c+aa)+bb, 0, 9, "(c+((a+b) if (x<2) else 1))") # not combining because it increased total ALU - c = Variable("c", 0, 3) cc = cond.where(c, c+1) self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))") @@ -673,10 +722,10 @@ self.helper_test_variable(-a<-b, False, True, "(b<a)") +def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]: + lst = get_viz_list() + assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx+1}" + return get_details(tracked_ctxs[rewrite_idx][step]) class BaseTestViz(unittest.TestCase): def setUp(self): @@ -66,7 +70,6 @@ class TestViz(BaseTestViz): def 
test_rewrite_location(self): def inner(sink): return graph_rewrite(sink, PatternMatcher([])) - @track_rewrites(name=True) def outer(sink): return inner(sink) outer(UOp.variable("a", 1, 10)) lst = get_viz_list() @@ -130,7 +133,7 @@ class TestViz(BaseTestViz): (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), ]) with self.assertRaises(RuntimeError): exec_rewrite(a, [pm]) - graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0])) + graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0)) self.assertEqual(graphs[0], uop_to_json(a)[id(a)]) self.assertEqual(graphs[1], uop_to_json(b)[id(b)]) # fallback to NOOP with the error message @@ -139,10 +142,12 @@ class TestViz(BaseTestViz): def test_const_node_visibility(self): a = UOp.variable("a", 0, 10) - z = UOp.const(dtypes.int, 0) + z = UOp.const(dtypes.index, 0) alu = a*z exec_rewrite(alu, [sym]) - graphs = [x["graph"] for x in get_details(tracked_ctxs[0][0])] + lst = get_viz_list() + self.assertEqual(len(lst), 1) + graphs = [x["graph"] for x in get_viz_details(0, 0)] # embed const in the parent node when possible self.assertEqual(list(graphs[0]), [id(a), id(alu)]) self.assertEqual(list(graphs[1]), [id(z)]) @@ -175,7 +180,6 @@ class TestVizTree(BaseTestViz): c = UOp.variable("c",0,10) d = UOp.variable("d",0,10) sink = UOp.sink(a+b, c+d) - @track_rewrites() def tree_rewrite(): return graph_rewrite(sink, root, name="root") tree_rewrite() lst = get_viz_list() @@ -246,9 +250,46 @@ class TestVizIntegration(BaseTestViz): b = Tensor.empty(1) metadata = (alu:=a+b).uop.metadata alu.kernelize() - graph = next(get_details(tracked_ctxs[0][0]))["graph"] + graph = next(get_viz_details(0, 0))["graph"] self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1) + # tracing also works without a track_rewrites context + # all graph_rewrites get put into the default group + def test_default_tracing(self): + def test(root): + return graph_rewrite(root, sym) + test(c:=UOp.const(dtypes.int, 1)) + test(c+1) + ls = get_viz_list() + self.assertEqual(len(ls), 1) + self.assertEqual(ls[0]["name"], "default graph_rewrite") + + # using @track_rewrites organizes function calls into groups + # and nicely counts function calls. + def test_group_traces(self): + @track_rewrites() + def test(root): + return graph_rewrite(root, sym) + test(c:=UOp.const(dtypes.int, 1)) + test(c+1) + ls = get_viz_list() + self.assertEqual(len(ls), 2) + for i in range(2): self.assertEqual(ls[i]["name"], f"test n{i+1}") + + # @track_rewrites always starts a new group. + def test_group_combined(self): + def default_test(root): return graph_rewrite(root, sym) + tracked_test = track_rewrites()(default_test) + c = UOp.const(dtypes.int, 1) + default_test(c+1) # goes to the default group + tracked_test(c) # all rewrites after this go inside the second group. 
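+ # the assertions below pin down the grouping: get_viz_list() now returns two
+ # groups, group 0 holding the first untracked call (c+1) and group 1 holding
+ # the tracked call (c) plus the later untracked call (c+2).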
+ default_test(c+2) + ls = get_viz_list() + self.assertEqual(len(ls), 2) + self.assertEqual(list(next(get_viz_details(0, 0))["graph"]), [id(c+1)]) + self.assertEqual(list(next(get_viz_details(1, 0))["graph"]), [id(c)]) + self.assertEqual(list(next(get_viz_details(1, 1))["graph"]), [id(c+2)]) + from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry from tinygrad.viz.serve import get_profile @@ -265,27 +306,27 @@ def option(i:int) -> int|None: return None if i == 0 else i-1 def load_profile(lst:list[ProfileEvent]) -> dict: ret = get_profile(lst) u = TinyUnpacker(ret) - dur, global_peak, index_len, layout_len = u(" NV:1') self.assertEqual(nv1_events[0]['st'], 954) #self.assertEqual(j['devEvents'][7]['pid'], j['devEvents'][3]['pid']) - graph_events = j['layout']['NV Graph']['shapes'] + graph_events = j['layout']['NV Graph']['events'] self.assertEqual(graph_events[0]['st'], nv_events[0]['st']) self.assertEqual(graph_events[0]['st']+graph_events[0]['dur'], nv1_events[0]['st']+nv1_events[0]['dur']) @@ -364,6 +407,21 @@ class TestVizProfiler(unittest.TestCase): with self.assertRaises(struct.error): get_profile(prof) + def test_python_marker(self): + with Context(VIZ=1): + a = Tensor.empty(1, device="NULL") + b = Tensor.empty(1, device="NULL") + (a+b).realize() + profile_marker("test 1") + (a*b).realize() + profile_marker("test 2") + profile_ret = load_profile(cpu_events) + markers = profile_ret["markers"] + kernels = profile_ret["layout"]["NULL"]["events"] + self.assertEqual(len(markers), 2) + assert kernels[0]["st"] <= markers[0]["ts"] <= kernels[1]["st"] + assert markers[1]["ts"] >= kernels[1]["st"]+kernels[1]["dur"] + def _alloc(b:int): a = Tensor.empty(b, device="NULL", dtype=dtypes.char) a.uop.buffer.allocate() @@ -376,7 +434,7 @@ class TestVizMemoryLayout(BaseTestViz): profile_ret = load_profile(Buffer.profile_events) ret = profile_ret["layout"][f"{a.device} Memory"] self.assertEqual(ret["peak"], 2) - self.assertEqual(len(ret["shapes"]), 2) + self.assertEqual(len(ret["events"]), 2) def test_del_once(self): a = _alloc(1) @@ -385,7 +443,7 @@ class TestVizMemoryLayout(BaseTestViz): profile_ret = load_profile(Buffer.profile_events) ret = profile_ret["layout"][f"{b.device} Memory"] self.assertEqual(ret["peak"], 1) - self.assertEqual(len(ret["shapes"]), 3) + self.assertEqual(len(ret["events"]), 3) def test_alloc_free(self): a = _alloc(1) @@ -395,7 +453,7 @@ class TestVizMemoryLayout(BaseTestViz): profile_ret = load_profile(Buffer.profile_events) ret = profile_ret["layout"][f"{c.device} Memory"] self.assertEqual(ret["peak"], 2) - self.assertEqual(len(ret["shapes"]), 4) + self.assertEqual(len(ret["events"]), 4) if __name__ == "__main__": unittest.main() diff --git a/tinygrad_repo/test/test_winograd.py b/tinygrad_repo/test/unit/test_winograd.py similarity index 76% rename from tinygrad_repo/test/test_winograd.py rename to tinygrad_repo/test/unit/test_winograd.py index 453e910f..d1ddfc8b 100644 --- a/tinygrad_repo/test/test_winograd.py +++ b/tinygrad_repo/test/unit/test_winograd.py @@ -1,8 +1,9 @@ -import unittest +import unittest, sys import numpy as np from tinygrad import Tensor, GlobalCounters, dtypes, Context, nn -from tinygrad.helpers import CI, Profiling, WINO, getenv +from tinygrad.helpers import CI, Profiling, WINO +@unittest.skipIf(sys.platform.startswith("win"), "flaky on Windows") class TestWinogradClose(unittest.TestCase): def test_close(self): inp = Tensor.rand(1, 16, 16, 16) @@ -18,6 +19,7 @@ class TestWinogradClose(unittest.TestCase): test = 
conv(inp).realize() np.testing.assert_allclose(cmp.numpy(), test.numpy(), atol=1e-5) +@unittest.skipIf(sys.platform.startswith("win"), "flaky on Windows") class TestWinograd(unittest.TestCase): def setUp(self): self.old = WINO.value @@ -28,17 +30,20 @@ class TestWinograd(unittest.TestCase): def test_profile(self): x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() with Profiling(enabled=not CI, sort='time'): - out = Tensor.conv2d(x,w).realize() - out.numpy() + Tensor.conv2d(x,w).realize() - def test_four_kernels(self): + def test_forward_kernels(self): x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() - GlobalCounters.reset() - out = Tensor.conv2d(x,w).realize() - assert GlobalCounters.kernel_count == 4 - out.numpy() + out = Tensor.conv2d(x,w) + self.assertEqual(len(out.schedule()), 4) + + def test_backward_kernels(self): + x,w = Tensor.empty(1,4,9,9,requires_grad=True).realize(), Tensor.empty(4,4,3,3,requires_grad=True).realize() + out = Tensor.conv2d(x,w, padding=1) + out.mean().backward() + backward_schedule = Tensor.schedule(x.grad, w.grad) + self.assertEqual(len(backward_schedule), 9) - @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX") def test_counters(self): IC, OC, X, Y = 4,4,9,9 #OC, IC, X, Y = 512, 256, 8, 8 diff --git a/tinygrad_repo/tinygrad/apps/llm.py b/tinygrad_repo/tinygrad/apps/llm.py index 4f80d4bf..5d79e69b 100644 --- a/tinygrad_repo/tinygrad/apps/llm.py +++ b/tinygrad_repo/tinygrad/apps/llm.py @@ -53,17 +53,15 @@ class SimpleTokenizer: try: return [ self._normal_tokens[p] for p in parts ] except KeyError: raise RuntimeError("token not found") -def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000): +def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor: B, H, T, Hd = x.shape - # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel - # TODO: make it do that - freq = base ** (-Tensor.arange(0, 1, 2/Hd, dtype='float32')) - angles = Tensor.arange(start_pos, start_pos+T, dtype='float32')[None, None, :, None] * freq - cos, sin = angles.cos(), angles.sin() - x = x.reshape(B, H, T, Hd // 2, 2) # split into pairs - y1 = x[..., 0] * cos - x[..., 1] * sin - y2 = x[..., 0] * sin + x[..., 1] * cos - return Tensor.stack(y1, y2, dim=-1).reshape(B, H, T, Hd) + assert (Hd & 1) == 0, "RoPE requires an even head dimension" + half = Hd // 2 + angles = (Tensor.arange(T, dtype="float32") + start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :] + cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype), angles.sin().reshape(1, 1, T, half).cast(x.dtype) + x_pairs = x.reshape(B, H, T, half, 2) + return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin, + x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd) class TransformerBlock: def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0): diff --git a/tinygrad_repo/tinygrad/codegen/__init__.py b/tinygrad_repo/tinygrad/codegen/__init__.py index 6a38d1dd..01b51555 100644 --- a/tinygrad_repo/tinygrad/codegen/__init__.py +++ b/tinygrad_repo/tinygrad/codegen/__init__.py @@ -2,7 +2,7 @@ from typing import Any, Callable import functools from dataclasses import dataclass from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL -from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp +from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype from tinygrad.uop.spec import 
type_verify from tinygrad.renderer import Renderer @@ -12,12 +12,14 @@ from tinygrad.codegen.quantize import pm_quant from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing from tinygrad.uop.decompositions import get_late_rewrite_patterns -from tinygrad.codegen.late.expander import migrate_indexing, expander +from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext -from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops +from tinygrad.codegen.opt.postrange import pm_postrange_opt +from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range +from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen @dataclass class RewriteStep: @@ -42,38 +44,47 @@ rewrites_for_linearizer = [ RewriteStep(block_merge, name="Linearizer: Merge Blocks"), RewriteStep(pm_finalize, name="Linearizer: Finalize")] -def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[RewriteStep]: +def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]: # cache with the values of the context vars - return _get_rewrites_for_renderer(opts, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value) + return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value) @functools.cache -def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]: +def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]: # ** lowerer (rewrite_shapetracker_with_index) ** ret: list[RewriteStep] = [] - # view pushing - ret.extend(rewrites_for_views) + if optimize: + # view pushing + ret.extend(rewrites_for_views) - # this is kernel.py - ret.append(RewriteStep(pm_get_optimization, ctx=lambda _: opts, name="get optimization")) - ret.append(RewriteStep(pm_do_optimize, ctx=lambda _: opts, name="optimize ast")) + # lowerer first + if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) + ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True)) - if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) - ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True)) + # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct) + ret.append(RewriteStep(sym+pm_flatten_range, name="initial symbolic")) + + # optimize (schedule) the AST + ret.append(RewriteStep(pm_simplify_ranges, name="simplify ranges")) + ret.append(RewriteStep(pm_reduce_simplify, name="simplify reduces")) + ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast")) # ** expander (expand_rewrite) ** - ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic")) - - # add gpu dims (late). 
this also handles UNROLL range - ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims")) + ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic")) # expand - ret.append(RewriteStep(sym+expander, name="expander")) + ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander")) + + # add locals + ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers")) # ** devectorizer (full_graph_rewrite) ** # remove reduce ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce")) + # add gpu dims (late). this works after devectorize, but it's faster here + ret.append(RewriteStep(pm_add_gpudims, lambda _: opts, name="add gpudims")) + # devectorize (TODO: does this need opts?) if _DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing elif _DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing @@ -83,6 +94,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC supported_ops = tuple(opts.code_for_op.keys()) extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([]) + # lower the index dtype to a concrete int + ret.append(RewriteStep(pm_lower_index_dtype+load_store_indexing, lambda _: opts.device, name="lower all index dtypes")) + # optional pre matcher if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher")) @@ -97,8 +111,8 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC # return the list (with optional linearizer) return ret + (rewrites_for_linearizer if linearizer else []) -def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, linearizer:bool=False) -> UOp: - return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), linearizer)) +def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp: + return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer)) def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: """ @@ -112,6 +126,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: Linear program in UOps. 
""" - lst = list(full_rewrite_to_sink(sink, opts, linearizer=True).arg.lst) + lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst) if __debug__: type_verify(lst) return lst diff --git a/tinygrad_repo/tinygrad/codegen/gpudims.py b/tinygrad_repo/tinygrad/codegen/gpudims.py index fd4e1d84..3bc67e51 100644 --- a/tinygrad_repo/tinygrad/codegen/gpudims.py +++ b/tinygrad_repo/tinygrad/codegen/gpudims.py @@ -1,6 +1,6 @@ import math -from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType -from tinygrad.helpers import all_int, partition, flatten, prod, dedup +from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop +from tinygrad.helpers import all_int, dedup from tinygrad.dtype import dtypes from tinygrad.shape.view import get_contraction from tinygrad.renderer import Renderer @@ -34,7 +34,7 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}") # try to split up dims: (a,) -> (b, c) if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims - ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)] + ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.index, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)] if len(limited) < len(dims): ret = [] if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}") @@ -56,17 +56,17 @@ def add_gpudims(ctx:Renderer, s:UOp): if any(x.op is Ops.SPECIAL for x in s_topo): return None # get ranges - all_ranges = {x.arg[0]%1000:x for x in s_topo if x.op is Ops.RANGE} + all_ranges = {x.arg[0:-1]:x for x in s_topo if x.op is Ops.RANGE} # extract global/local dims - global_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] is AxisType.GLOBAL])) - local_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] in (AxisType.LOCAL, AxisType.GROUP_REDUCE)])) + global_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.GLOBAL, AxisType.THREAD)])) + local_dims = sorted(dedup([x.arg[0:-1] for x in all_ranges.values() if x.arg[-1] in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])) if not global_dims and not local_dims: return None # get global and local shape ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges] - global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims]) - local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims]) + global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in global_dims]) + local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0:-1] in local_dims]) # get the idxs ki: KernelInfo = s.arg @@ -82,35 +82,13 @@ def add_gpudims(ctx:Renderer, s:UOp): for r in s_topo: if r.op is not Ops.RANGE: continue try: - ii = (global_dims+local_dims).index(r.arg[0]%1000) - if r.arg[0] < 2000 and r.arg[1] == AxisType.GROUP_REDUCE: continue + ii = (global_dims+local_dims).index(r.arg[0:-1]) + if r.arg[1] == AxisType.REDUCE: continue subs[r] = idxs[ii] except ValueError: continue return s.substitute(subs) -def fix_reduce_unroll(x:UOp): - reduce_range, reduce_expand = partition(x.src[1:], lambda y: y.op is Ops.RANGE) 
- if len(reduce_expand) == 0: return None - reduce_expand = [x for x in reduce_expand if x.op is not Ops.CONST] - assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}" - ret = x.src[0] - if len(contract_axis:=flatten(x.arg for x in reduce_expand)): - ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1) - # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group - return x.replace(src=(ret,)+tuple(reduce_range)) - -def fix_store_unroll(x:UOp): - store_expand, store_range = partition(x.src[2:], lambda y: y.op is Ops.UNROLL) - if len(store_expand) == 0: return None - return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1) - pm_add_gpudims = PatternMatcher([ + # add gpudims must be last (UPat(Ops.SINK, name="s"), add_gpudims), - # rewrite UPCAST/UNROLL range to something to be expanded - (UPat(Ops.RANGE, name="r"), - lambda r: UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \ - if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None), - # fix REDUCEs with UNROLLs - (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll), - (UPat(Ops.STORE, name="x"), fix_store_unroll), ]) diff --git a/tinygrad_repo/tinygrad/codegen/late/devectorizer.py b/tinygrad_repo/tinygrad/codegen/late/devectorizer.py index 418301e7..50133226 100644 --- a/tinygrad_repo/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad_repo/tinygrad/codegen/late/devectorizer.py @@ -2,16 +2,16 @@ from typing import Any, cast import functools, operator, itertools from collections import defaultdict from dataclasses import dataclass -from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace +from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element -from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat -from tinygrad.helpers import getenv, flatten, AMX, prod, partition +from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate +from tinygrad.helpers import getenv, flatten, AMX, prod from tinygrad.renderer import Renderer # ***** image load valid simplification ***** def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: - if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0) + if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid()) if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid) # wait for it to be image indexed before running simplification @@ -19,13 +19,13 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: # can drop valid if idx is out of bound when valid is False drop_stmt = [] - for stmt in split_uop(valid, Ops.AND): + for stmt in valid.split_uop(Ops.AND): try: X, is_upper_bound, c = parse_valid(stmt) except ValueError: return None # for X0 + X1 + ... 
>= 1, check if it's out of bound when Xi = 0 for all i - if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)): - testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx) + if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)): + testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx) testidx = testidx.simplify() if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0: drop_stmt.append(stmt) @@ -42,7 +42,7 @@ break if not drop_stmt and idx is start_idx: return None - new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None + new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None return buf.index(idx, new_valid) def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None: @@ -51,12 +51,15 @@ return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:]) load_store_indexing = PatternMatcher([ - # simplify valid - (UPat(Ops.AND, name="valid"), simplify_valid), # image load valid idx simplification (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), - # index True is just Index - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)), + # lower: turn the invalid into a gate; must come before index dtype lowering + (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)), + # drop true gate + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)), + # remove hanging cast + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)), # delete_redundant_gates (after expand) (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates), @@ -75,14 +78,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): idx: Any = midx.src[i].src[1] if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg + elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0 elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src) offsets_rootsrc[root_src].setdefault(arg, []).append(i) - # the buf.dtype is always a pointer - ptrdtype = cast(PtrDType, buf.dtype) - # then rewrite everything we can into groups ret = [] idxs: list[int|None] = [None]*vec.dtype.count @@ -92,7 +93,7 @@ for grp in grouped_offsets: # get the index 
offset for this element. using [0] is okay, because they are the same lidx = midx.src[offsets[grp[0]][0]] - if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace)) + if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) # set the idxs of the output for i,g in enumerate(grp): for oo in offsets[g]: idxs[oo] = global_offset+i @@ -101,7 +102,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): global_offset += len(grp) assert None not in idxs, f"some idxs are missing {idxs}" # this base thing is for image, we want the CAT to be a normal pointer - post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret)) + post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret)) return post_cat.gep(tuple(cast(list[int], idxs))) def cat_after_store(cat:UOp, data:UOp, sto:UOp): @@ -154,7 +155,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): must_divide = False elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): pass - elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG: + elif buf.ptrdtype.addrspace == AddrSpace.REG: pass elif isinstance(buf.dtype, ImageDType): lengths = [4] @@ -169,13 +170,12 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): # split based on the fold lengths global_offset = 0 ret = [] - ptrdtype = cast(PtrDType, buf.dtype) while global_offset < sz: # with 1 at the end of the lengths list, this will always hit for fold_length in lengths: if global_offset+fold_length > sz: continue lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None) - if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace)) + if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:])) else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length))) global_offset += fold_length @@ -232,17 +232,20 @@ def no_vectorized_alu(alu:UOp): alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount)) return UOp(Ops.VECTORIZE, alu.dtype, alus) -def no_vectorized_acc(acc:UOp, c:UOp): - if acc.dtype.count == 1: return None - assert c.arg == 0, "this only supports index 0" - new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace)) - return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)])) +def no_vectorized_buf(buf:UOp): + return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype) + +def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): + cnt = cast.dtype.count + assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}" + return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt)))) devectorize = PatternMatcher([ # no ALU on vectorized dtypes (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu), 
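For intuition, here is a minimal standalone sketch of what the `no_vectorized_alu` rewrite above does: split a width-n vector ALU op into n scalar ops (one GEP per lane) and stitch the lanes back together with a VECTORIZE. The `Node` class and helpers below are toy stand-ins, not tinygrad's real `UOp` graph.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str                 # e.g. "ADD", "GEP", "VECTORIZE" (toy op names)
    lanes: int              # vector width (1 = scalar)
    src: tuple = ()
    arg: object = None

def gep(x: Node, i: int) -> Node:
    """Extract lane i of a vector node."""
    return Node("GEP", 1, (x,), i)

def no_vectorized_alu(alu: Node) -> Node:
    """Rewrite a width-n ALU op into n scalar ops plus one VECTORIZE."""
    if alu.lanes == 1:
        return alu  # already scalar, nothing to do
    scalar_ops = tuple(
        Node(alu.op, 1, tuple(gep(s, i) for s in alu.src), alu.arg)
        for i in range(alu.lanes)
    )
    return Node("VECTORIZE", alu.lanes, scalar_ops)

# a float4-style add becomes four scalar adds feeding one VECTORIZE
a, b = Node("LOAD", 4), Node("LOAD", 4)
out = no_vectorized_alu(Node("ADD", 4, (a, b)))
assert out.op == "VECTORIZE" and len(out.src) == 4
```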
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), - (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc), + (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf), + (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index), ]) pm_render = PatternMatcher([ @@ -255,7 +258,8 @@ pm_render = PatternMatcher([ (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # give any loads that are masked an alt value (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"), - lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None), + lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) + if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None), # gate any stores that aren't gated with ifs (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True), lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \ @@ -293,94 +297,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret -def no_vectorized_reduce(inp:UOp, red:UOp): - if inp.dtype != red.dtype: - red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:]) - if red.dtype.vcount == 1: return red - # no_vectorize_alu ignoring ranges - if red.dtype.vcount == 1: return None - alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount)) - return UOp(Ops.VECTORIZE, red.dtype, alus) - -def reduce_rangeless(red:UOp): - # TODO: share code with reduce_unparented - if red.arg not in {Ops.ADD, Ops.MAX}: return None - if red.src[0].dtype != red.dtype: return None - if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None - ret = red.src[0] - if red.arg is Ops.ADD: - for r in red.src[1:]: - ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) - return ret - -def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents) - -pm_reduce_collapse = PatternMatcher([ - # lift x+y out of reduce on lt - ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None), - # lift x*y out of reduce - ((UPat.var("x")*UPat.var("y")) < UPat.var("c"), - lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None), - # lift x+y out of reduce on ne - ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None), - # fold the range - ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), - lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val), - ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), - lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val), - # REDUCE on ADD - ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), - lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)), - # MUL casted bool - 
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")), - lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)), - # WHERE on LOAD (works on max too) - (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True), - lambda buf,idx,gate: buf.index(idx, gate).load()), - (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True), - lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()), - # INDEX on RANGE / gated RANGE - (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())), - lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))), - # AND on WHERE - ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \ - .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), - lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), - # remove REDUCEs that no longer have a RANGE in the src - (UPat(Ops.REDUCE, name="red"), reduce_rangeless), - # devectorize REDUCE - (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce), - # index/load/where. TODO: this is more aggressive than needed - (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu), -])+sym - -def reduce_collapse(red:UOp): - included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:])) - if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None - replaces: dict[UOp, UOp] = {} - for u in included: - for s in u.src: - if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: - replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax)) - collapse_fxn = red.substitute(replaces) - sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse") - if any(x.op is Ops.RANGE for x in sink.toposort()): return None - return sink.substitute({v:k for k,v in replaces.items()}) - -def reduce_unparented(red:UOp): - if red.arg not in {Ops.ADD, Ops.MAX}: return None - reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents) - if len(reduce_unparented) == 0: return None - ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0] - if red.arg is Ops.ADD: - for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) - return ret - pm_reduce = PatternMatcher([ - # remove any ranges from a REDUCE that aren't referenced in the reduce source - (UPat(Ops.REDUCE, name="red"), reduce_unparented), - # remove REDUCE without loads (generic arange opt / indexing). 
TODO: support multi range - (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse), # REDUCE -> DEFINE_ACC+ASSIGN (UPat(Ops.REDUCE, name="red"), reduce_to_acc), # tensor core built in accumulate diff --git a/tinygrad_repo/tinygrad/codegen/late/expander.py b/tinygrad_repo/tinygrad/codegen/late/expander.py index 2d036add..d2d0a411 100644 --- a/tinygrad_repo/tinygrad/codegen/late/expander.py +++ b/tinygrad_repo/tinygrad/codegen/late/expander.py @@ -1,9 +1,9 @@ # this converts a lowerer program into a vectorized program - import functools, itertools, operator -from tinygrad.dtype import dtypes -from tinygrad.helpers import AMX, dedup, flatten, all_same, prod -from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp +from tinygrad.dtype import dtypes, PtrDType, AddrSpace +from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, partition +from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, AxisType, range_start +from tinygrad.schedule.rangeify import BufferizeOpts def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int: idx, mul = 0, 1 @@ -50,9 +50,11 @@ def do_expand(root:UOp): if root.op is Ops.IF or src.op is Ops.IF: # for the first arg of IF, just pass them through ignoring UNROLLS new_srcs.append(src) - elif (root.op is Ops.STORE and i >= 2) or (root.op is Ops.REDUCE and i >= 1): + elif root.op in range_start and i >= range_start[root.op]: # for any range args of STORE/REDUCE, pass them through new_srcs.append(src) + elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType): + new_srcs.append(src) elif src.dtype.count > 1: # put any input dtype > 1 grouped together new_srcs.append(UOp(Ops.CAT, src.dtype.scalar().vec(expand_sz*src.dtype.count), (src,)*expand_sz)) @@ -84,7 +86,7 @@ expander = PatternMatcher([ (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)), lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE, Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand), (UPat(Ops.CONTRACT, name="con"), do_contract), # BARRIERs aren't actually expanded @@ -112,3 +114,49 @@ migrate_indexing = PatternMatcher([ # create gate MUST BE BEFORE expander (UPat(Ops.STORE, name="root"), create_gate), ]) + +# **** + +def fix_reduce_unroll(x:UOp): + reduce_range, reduce_expand = partition(x.src[1:], lambda y: y.op is Ops.RANGE) + if len(reduce_expand) == 0: return None + reduce_expand = [x for x in reduce_expand if x.op is not Ops.CONST] + assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand}" + ret = x.src[0] + if len(contract_axis:=flatten(x.arg for x in reduce_expand)): + ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis), tag=1) + # REDUCE supports both "horizontal" reduction and range reduction. 
the horizontal elements are taken in the nearest group + return x.replace(src=(ret,)+tuple(reduce_range)) + +def fix_store_unroll(x:UOp): + store_expand, store_range = partition(x.src[2:], lambda y: y.op is Ops.UNROLL) + if len(store_expand) == 0: return None + return UOp(Ops.CONTRACT, dtypes.void, (x.replace(src=x.src[:2]+tuple(store_range)),), tuple(flatten(x.arg for x in store_expand)), tag=1) + +def fix_group_for_reduce(x:UOp): + reduce_gfr, reduce_r = partition(x.src[1:], lambda u: u.op is Ops.RANGE and u.arg[1] == AxisType.GROUP_REDUCE) + if len(reduce_gfr) == 0: return None + + # NOTE: if there's other locals here, we need them in the buffer too + upstream_locals = [u for u in x.toposort() if u.op is Ops.RANGE and u.arg[1] == AxisType.LOCAL] + + # do only the non grouped reduces early + ret = x.replace(src=(x.src[0],)+tuple(reduce_r)) + reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr] + buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop) + + # gate with an if on the store + do the final reduce + buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf)) + return buf.reduce(*reduce_loop, arg=x.arg) + +pm_pre_expander = PatternMatcher([ + # rewrite UPCAST/UNROLL range to something to be expanded + (UPat(Ops.RANGE, name="r"), + lambda r: UOp(Ops.UNROLL, r.dtype, (UOp.const(r.dtype.vec(s:=r.vmax+1), tuple(range(s))),), ((r.arg[0],s),)) \ + if r.arg[1] in {AxisType.UNROLL, AxisType.UPCAST} else None), + # fix REDUCEs with UNROLLs + (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll), + (UPat(Ops.STORE, name="x"), fix_store_unroll), + # fix group for reduce + (UPat(Ops.REDUCE, name="x"), fix_group_for_reduce), +]) diff --git a/tinygrad_repo/tinygrad/codegen/late/linearize.py b/tinygrad_repo/tinygrad/codegen/late/linearize.py index 54ff2cbf..d860125a 100644 --- a/tinygrad_repo/tinygrad/codegen/late/linearize.py +++ b/tinygrad_repo/tinygrad/codegen/late/linearize.py @@ -2,7 +2,7 @@ from __future__ import annotations import heapq from collections import defaultdict from dataclasses import dataclass, replace -from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp +from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed @@ -76,12 +76,13 @@ class BlockContext: def from_sink(sink:UOp) -> BlockContext: # get children and all block contexts ctx = BlockContext({}, {}, {}) - for u in sink.toposort(): + for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL): this_block_ctx: list[UOp] = [] ctx.child_count[u] = 0 # get children and accumulate the last_ctx for s in u.src: + if s.op is Ops.SPECIAL: continue # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child ctx.child_count[s] += 1 this_block_ctx += ctx.last_ctx(s) @@ -142,7 +143,7 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp): # add unmergables to sources srcs = [] - for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)]*cnt + for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt # add blockseeds, with blockends as needed for (new_ctx, new_child_ctx), v in blockseeds.items(): @@ -154,8 +155,12 @@ def 
make_block_bottom_up(ctx:BlockContext, x:UOp): bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx) return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb) +# we prevent the source of the SPECIAL from being linearized since its not part of the kernel +def raise_bottom_up_gate(): raise BottomUpGate() + block_create = PatternMatcher([ (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up), + (UPat(Ops.SPECIAL), raise_bottom_up_gate) ]) # ***** blockend merging **** @@ -217,6 +222,8 @@ def remove_blockend(x:UOp): if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP) arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt) return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg) + # else the whole context ended by the blockend is already in this block and we can safely turn it into a block + return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)) block_merge = PatternMatcher([ (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block), diff --git a/tinygrad_repo/tinygrad/codegen/lowerer.py b/tinygrad_repo/tinygrad/codegen/lowerer.py index 1da53e98..236aff36 100644 --- a/tinygrad_repo/tinygrad/codegen/lowerer.py +++ b/tinygrad_repo/tinygrad/codegen/lowerer.py @@ -1,9 +1,6 @@ # the job of the lowerer is to do indexing -import functools, operator -from typing import cast from dataclasses import dataclass -from tinygrad.dtype import dtypes, AddrSpace, PtrDType -from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite +from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve # ***** indexing ***** @@ -14,12 +11,12 @@ class IndexContext: start: int = 0 def shape_to_idx(s, axis_types, start=0): - return [UOp.range(dtypes.int, sint_to_uop(s), start+i, axistype=at) for i, (s, at) in enumerate(zip(s, axis_types))] + return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))] def get_index(ast:UOp) -> IndexContext: axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else () - if len(ast.full_shape) != len(axis_types): - axis_types = tuple([AxisType.REDUCE if s is not fs else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)]) + if len(ast.full_shape) != len(axis_types) and ast.st is not None: + axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)]) return IndexContext(axis_types, [], 0) # ***** lowering (given index) ***** @@ -41,8 +38,8 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp): #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}" new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start) - idx, valid = x.st_arg.to_indexed_uops(new_idxs) - used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs] + idx = x.st_arg.to_valid_uop(new_idxs) + used_idxs = [x for x in idx.toposort() if x in new_idxs] real_new_idxs = [] for i in range(len(x.src[0].shape)): if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i]) @@ -50,15 +47,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp): stored = subblock(ctx, real_new_idxs, x.src[1]) used_ranges = [x for x in used_idxs if 
x.op is Ops.RANGE] - ret = buf.index(idx, valid).store(stored, *used_ranges) - - # insert BARRIER if we are ending a LOCAL, IF if we are ending a GROUP_REDUCE - if cast(PtrDType, buf.dtype).addrspace == AddrSpace.LOCAL and \ - any(ctx.axis_types[x.arg[0]%1000] in {AxisType.GROUP_REDUCE, AxisType.LOCAL} for x in used_ranges): - ret = ret.barrier() - range_gates = [x.eq(0) for x in used_ranges if ctx.axis_types[x.arg[0]%1000] == AxisType.GROUP_REDUCE] - if len(range_gates): ret = UOp(Ops.IF, src=(functools.reduce(operator.and_, range_gates), ret)) - return ret + return buf.index(idx).store(stored, *used_ranges) def fixup_wmma(ctx:IndexContext, x:UOp): if x.tag is not None: return None @@ -82,9 +71,9 @@ pm_lowerer = PatternMatcher([ # consts and loads (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"), - lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))), + lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))), (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), - lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])), + lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])), # reduce/view_const (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis), diff --git a/tinygrad_repo/tinygrad/codegen/opt/__init__.py b/tinygrad_repo/tinygrad/codegen/opt/__init__.py index 0c507c56..a47d785b 100644 --- a/tinygrad_repo/tinygrad/codegen/opt/__init__.py +++ b/tinygrad_repo/tinygrad/codegen/opt/__init__.py @@ -1,46 +1,26 @@ # opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search +from __future__ import annotations +from enum import Enum, auto +from dataclasses import dataclass +from tinygrad.uop.ops import AxisType -from tinygrad.codegen.opt.kernel import Kernel -from tinygrad.codegen.opt.heuristic import hand_coded_optimizations -from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, KernelInfo -from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv -from tinygrad.renderer import Renderer -from tinygrad.uop.spec import type_verify +class OptOps(Enum): + TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702 + GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702 + def __lt__(self, x:OptOps): return self.value < x.value -def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp: - """ - Optimize an AST based on heuristics or BEAM search. 
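A minimal usage sketch of the `OptOps` enum above and the frozen `Opt` dataclass added just below it (the axis/arg values here are illustrative, not taken from a real search log). Because `Opt` is declared with `frozen=True, order=True` and `OptOps` defines `__lt__`, opts are hashable and sortable, which gives deterministic iteration during BEAM-style search:

```python
from tinygrad.codegen.opt import Opt, OptOps

opts = [
    Opt(OptOps.UPCAST, axis=0, arg=4),   # upcast axis 0 by 4
    Opt(OptOps.LOCAL,  axis=1, arg=16),  # make axis 1 a local dim of size 16
    Opt(OptOps.UNROLL, axis=0, arg=4),   # unroll the first reduce axis by 4
]
# hashable (usable in sets / as dict keys) and sortable by (op, axis, arg)
print(sorted(opts)[0])   # -> Opt(op=OptOps.UPCAST, axis=0, arg=4)
```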
+@dataclass(frozen=True, order=True) +class Opt: + op: OptOps + axis: int|None = None + arg: int|tuple|None = None + def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})" - Args: - ast: The Ops.SINK rooted AST - renderer: The renderer used to generate the code +axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", + AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} +axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", + AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} - Returns: - The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg. - """ - - assert ast.arg is None, "no opt if there's an arg" - k = Kernel(ast, opts=renderer) - if not NOOPT: - if not k.apply_tensor_cores(USE_TC.value): k.apply_opts(hand_coded_optimizations(k)) - if BEAM >= 1: - from tinygrad.codegen.opt.search import beam_search, bufs_from_lin - kb = Kernel(ast, opts=renderer) - rawbufs = bufs_from_lin(kb, allocate=False) - k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) - return ast.replace(arg=KernelInfo(opts_to_apply=tuple(k.applied_opts))) - -pm_get_optimization = PatternMatcher([ - (UPat(Ops.SINK, name="ast"), lambda ctx,ast: get_optimized_ast(ast, ctx) if ast.arg is None and ast.src[0].st is not None else None), -]) - -def apply_opt(ast:UOp, renderer:Renderer): - k = Kernel(ast, opts=renderer) - k.apply_opts(ast.arg.opts_to_apply) - ret = k.get_optimized_ast() - if __debug__: type_verify(list(ret.toposort())) - return ret - -pm_do_optimize = PatternMatcher([ - (UPat(Ops.SINK, name="ast"), lambda ctx,ast: apply_opt(ast, ctx) if ast.arg is not None and ast.arg.opts_to_apply is not None else None), -]) +class KernelOptError(Exception): pass +def check(cond:bool, msg:str=""): + if not cond: raise KernelOptError(msg) diff --git a/tinygrad_repo/tinygrad/codegen/opt/heuristic.py b/tinygrad_repo/tinygrad/codegen/opt/heuristic.py index 40b3a9d3..a73234c4 100644 --- a/tinygrad_repo/tinygrad/codegen/opt/heuristic.py +++ b/tinygrad_repo/tinygrad/codegen/opt/heuristic.py @@ -1,10 +1,50 @@ import itertools -from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType -from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError +from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX from tinygrad.dtype import ImageDType -from tinygrad.uop.ops import Ops, resolve +from tinygrad.uop.ops import Ops, resolve, AxisType +from tinygrad.codegen.opt.postrange import Scheduler + +def hand_coded_optimizations(k:Scheduler) -> Scheduler: + # first try the tensor cores + """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. + Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). 
+ + Keyword arguments: + use_tensor_cores -- controls how tensor cores are applied (default 1) + 0: will disable any tensor core matching + 1: enable tensor cores + 2: apply tensor core shape but don't use UOp.WMMA + extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) + tc_select -- specifies which tensor core(s) to use for optimization (default -1) + -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes) + [0-N]: uses only the n'th tensor core available; useful for search + tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) + 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL + 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers + 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed + """ + # NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis + if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)): + good_tc_opt = False + try: # check TC first and apply hand-coded opts if successful + tk = k.copy() + rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value))) + good_tc_opt = True + except KernelOptError: + pass + if good_tc_opt: + # skip hand-coded TC opts if AMX, upcasting will make kernel slower + if rngs is not None and not AMX: + for tc_dim in [1,0]: # attempt to upcast M and N + szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None] + if szs: + # set it to the replaced range + rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0] + if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N + tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0])) + return tk -def hand_coded_optimizations(k:Kernel) -> list[Opt]: # make a copy so it does not mutate the input k = k.copy() @@ -13,19 +53,17 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \ (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: - st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])] - strides0, strides1 = st0.real_strides(), st1.real_strides() - def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides)) - if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \ - not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)): + idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx() + first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0] + if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges): for global_idx in k.axes_of(AxisType.GLOBAL): - if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: + if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and 
k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: if DEBUG >= 3: - print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}") + print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}") if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) - return k.applied_opts + return k # are we grouping? (requires local shape support) if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False): @@ -38,14 +76,17 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: # upcast float4 images for buf_index,buf in enumerate(k.bufs): if isinstance(buf.src[0].dtype, ImageDType): - if (unit_stride_axes_mul_4 := [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]): + # part of real_strides + unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if + c.op is Ops.RANGE and (c.vmax+1)%4 == 0] + if len(unit_stride_axes_mul_4): if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims: k.apply_opt(Opt(OptOps.UPCAST, axis, 4)) elif axis in k.unrollable_dims: k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4)) # no more opt if we are grouping - if k.group_for_reduces: return k.applied_opts + if k.group_for_reduces: return k # **** below this line need to be optional and benchmarked **** @@ -53,8 +94,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: to_upcast: list[int] = [] # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) for axis in k.upcastable_dims: - if k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \ - prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: + # for Schedule, we check if the range is used in INDEX gates or WHERE gates + is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE) + if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: if DEBUG >= 4: print(f"upcasting masked axis : {axis}") to_upcast.append(axis) for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0)) @@ -68,10 +110,18 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]): # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue - if any(st.views[-1].strides[axis] == 0 and \ - all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts): - xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), - sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount)) + rng = k.rngs[axis] + if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents + for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs): + num_strides, sum_strides = 0, 0 + for b in k.bufs: + idx = b.src[1].get_idx() + if rng in idx.parents: num_strides += 1 + for c in 
idx.split_uop(Ops.ADD): + if c is rng: sum_strides += 1 + if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg + if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg + xb_choices.append((num_strides, sum_strides, axis, upcast_amount)) if xb_choices: xb_choices = sorted(xb_choices) if DEBUG >= 4: print(f"more upcast axis : {xb_choices}") @@ -109,7 +159,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: k.apply_opt(Opt(OptOps.NOLOCALS)) else: # prioritize making expand axes local - local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)] + local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \ + for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST] to_local: list[tuple[int, int]] = [] for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): local_size = prod(sz for _, sz in to_local) @@ -122,4 +173,16 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) if will_delete_shape: deleted_shape += 1 - return k.applied_opts + # **** threading **** + + if k.opts.has_threads and k.opts.global_max is not None: + for threads in [32,16,12,8,6,5,4,3,2]: + # Skip is too many threads. Heuristic: use about 128K ops per thread + if threads > k.opts.global_max[0] or resolve(prod(k.full_shape) // (128 << 10) < threads): continue + for axis in k.axes_of(AxisType.LOOP): + if k.full_shape[axis] % threads == 0: + k.apply_opt(Opt(OptOps.THREAD, axis, threads)) + break + if k.applied_opts and k.applied_opts[-1].op is OptOps.THREAD: break + + return k diff --git a/tinygrad_repo/tinygrad/codegen/opt/kernel.py b/tinygrad_repo/tinygrad/codegen/opt/kernel.py deleted file mode 100644 index 570ad17c..00000000 --- a/tinygrad_repo/tinygrad/codegen/opt/kernel.py +++ /dev/null @@ -1,514 +0,0 @@ -from __future__ import annotations -import itertools, functools, math -from dataclasses import dataclass -from collections import defaultdict -from typing import cast, Final, Callable, Sequence -from enum import Enum, auto - -from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType -from tinygrad.uop.spec import type_verify, ast_spec -from tinygrad.device import Device -from tinygrad.codegen.opt.tc import TensorCore -from tinygrad.renderer import Renderer -from tinygrad.dtype import ImageDType, AddrSpace -from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import strides_for_shape, get_contraction -from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load - -class OptOps(Enum): - TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702 - GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702 - def __lt__(self, x:OptOps): return self.value < x.value - -@dataclass(frozen=True, order=True) -class Opt: - op: OptOps - axis: int|None = None - arg: int|tuple|None = None - def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})" - -axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.LOOP: "L", AxisType.UPCAST: "u", - AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", 
AxisType.UNROLL: "r"} -axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow", - AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} - -class KernelOptError(Exception): pass -def check(cond:bool, msg:str=""): - if not cond: raise KernelOptError(msg) - -@dataclass -class TensorCoreOptions: - axes: tuple[int, ...] # the location of the original N and M axes if still in the shape - axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape - axis_pads: tuple[tuple[int, int], ...] - def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed - axes, axes_exist = list(self.axes), list(self.axes_exist) - for tc_dim in [i for i in range(2) if axes_exist[i]]: - if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1 - elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False - self.axes, self.axes_exist = tuple(axes), tuple(axes_exist) - -class Kernel: - def __init__(self, ast:UOp, opts:Renderer|None=None): - assert ast.op is Ops.SINK, ast.op - self.ast = ast - - self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer - # verify AST matches the spec - if __debug__: type_verify(list(self.ast.toposort()), ast_spec) - - self.vars: list[Variable] = self.ast.variables() - # NOTE: this requires a specific order with the [::-1], this is likely a bug - self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.st is not None][::-1] - - # create new shapetrackers inside this kernel, we will permute them - self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs] - - # add the shapetrackers for each reduce - # we use this to track which axes are reduced in each reduce - self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS] - for x in self.reduceops: - self.sts.append(unwrap(x.st)) - self.sts.append(unwrap(x.src[0].st)) - - # add a shapetracker to the end to track the full shape, with 0 strides so it can merge - full_shape = ast.full_shape - self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape))) - - # parameters for optimization - self.tensor_core: TensorCore|None = None - self.tensor_core_opts: TensorCoreOptions|None = None - self.use_tensor_cores: int = 0 - self.applied_opts: list[Opt] = [] - self.dont_use_locals = False - self.finalized: bool = False - - # group simplifies - self.simplify_ones() - self.simplify_merge_adjacent() - - # axis types - global_loops = AxisType.GLOBAL if self.opts.has_local else AxisType.LOOP - self.axis_types: list[AxisType] = [AxisType.REDUCE if resolve(x!=y) else global_loops for x,y in zip(self.output_shape, self.full_shape)] - - # confirm all reduce axes are at the end - if (final_reduces := [x for x in self.axis_types if x == AxisType.REDUCE]) and final_reduces != self.axis_types[-len(final_reduces):]: - raise RuntimeError(f"reduces are not at the end of the shape {self.full_shape} -> {self.output_shape}") - - def copy(self): - ret = type(self).__new__(type(self)) - - # base linearizer params - ret.opts, ret.ast = self.opts, self.ast - - # things downstream of the AST - ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs - ret.sts = self.sts[:] - ret.axis_types = self.axis_types[:] - - # parameters for optimizations - ret.applied_opts, ret.dont_use_locals = self.applied_opts[:], self.dont_use_locals - ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, 
self.use_tensor_cores - ret.finalized = self.finalized - - return ret - - @property - def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None - @property - def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape - - @property - def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape - @property - def shape_len(self) -> int: return len(self.full_shape) - - def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)] - @property - def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL)) - @property - def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE)) - - # heuristic helpers - @property - def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \ - if isinstance(s:=self.full_shape[i], int) and s > 1] - @property - def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \ - if isinstance(s:=self.full_shape[i], int) and s > 1] - - # ******************** colors and names ******************** - - def colors(self) -> list[str]: - assert len(self.axis_types) == self.shape_len, "colors size mismatch" - return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types] - - def colored_shape(self, pad:int|None=None, dense=False) -> str: - shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape] - ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors())) - if pad: ret += ' '*(pad-ansilen(ret)) - return ret - - kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int) - @functools.cached_property - def name(self) -> str: - # kernel name (before late upcast) - kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()) else "E") - suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())]) - name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix - - # name the function something unique - Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1 - num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else "" - return name + colored(num, 'BLACK') - - # ******************** base simplifiers ******************** - - # apply reshape and permute to all shapetrackers - def reshape(self, new_shape_fxn:Callable[[tuple[sint, ...]], Sequence[sint]]): - self.sts = [st.reshape(tuple(new_shape_fxn(st.shape))) for st in self.sts] - def permute(self, new_axes:Sequence[int]): self.sts = [st.permute(tuple(new_axes)) for st in self.sts] - - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_at : place to insert the new stuff - def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None) -> int: - if insert_at is None: insert_at = self.shape_len - self.axis_types.insert(insert_at, new_type) - move_axis = axis if top else axis+1 - if move_axis < insert_at: insert_at += 1 - def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:] - new_axes = [i for i in 
range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis] - self.reshape(new_shape_fxn) - self.permute(new_axes) - return insert_at - - # ******************** complex simplifiers ******************** - - def simplify_ones(self) -> bool: - # remove places where the shape is all ones - if any(all_ones:=[s==1 for s in self.full_shape]): - if hasattr(self, 'axis_types'): - self.axis_types = [x for i,x in enumerate(self.axis_types) if not all_ones[i]] - self.reshape(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]]) - return True - return False - - def simplify_merge_adjacent(self): - assert not hasattr(self, 'axis_types'), "don't call this after init" - if self.shape_len == 0: return - shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] - # NOTE: we can't use self.first_reduce yet - first_reduce = [resolve(x!=y) for x,y in zip(self.output_shape+(0,), self.full_shape+(1,))].index(True) - - # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - # TODO: remove membufs - membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}]) - if isinstance(membufs[0].base.dtype, ImageDType): - base_shape = membufs[0].base.dtype.shape - if shape_idx_groups := get_contraction(self.output_shape, base_shape): - special_strides: tuple[sint, ...] = tuple() - for i,g in enumerate(shape_idx_groups): - shape_piece = tuple(self.output_shape[x] for x in g) - assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" - special_strides += strides_for_shape(shape_piece) - # adding the fake image shape - shapes.append(self.output_shape) - strides.append(special_strides) - - # merge dimensions if we can, multi _merge_dims - # NOTE: this does not always preserve the reduce dimension - # TODO: move this into shapetracker, with tests! - # TODO: how does this work with multi-reduce? - rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)] - for i in range(1, len(shapes[0])): - can_merge = [] - for s,st,ret in zip(shapes, strides, rets): - # TODO: added the always mergeability of 1s, is this right? 
if so, add to shapetracker in the 1 case - si, sti, last_st = s[i], st[i], ret[-1][1] - can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0))) - # more can merge than this - mergeable = all(can_merge) and i != first_reduce - for j,(s,st) in enumerate(zip(shapes, strides)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i]) - else: rets[j].append((s[i], st[i])) - - # do the reshapes - for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) - - # ******************** apply optimizations ******************** - - def real_axis(self, op:OptOps, axis:int|None): - try: - if axis is None: return -1 - if op is OptOps.UNROLL: return self.unrollable_dims[axis] - if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis] - check(axis < self.shape_len, "invalid axis") - return axis - except IndexError as e: raise KernelOptError from e - - def apply_opt(self, opt:Opt, append_opt:bool=True) -> int|None: - if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized") - if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals") - - if opt.op is OptOps.TC: - check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine - check(len(self.opts.tensor_cores) > 0, "must have tensor cores") - check(opt.axis is not None, "tensor core opts must have an axis") - check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg") - check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select") - check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt") - check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid") - check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available") - self.applied_opts.append(opt) - return None - - axis = self.real_axis(opt.op, opt.axis) - - if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs - elif opt.arg is not None: - check(isinstance(opt.arg, int), "arg should be int") - amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis] - check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless") - if opt.op is not OptOps.PADTO: - # we check both the full_shape and each shape - check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}") - for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}") - else: amt = -1 - - if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \ - (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): - acc_sz = self.reduceop.dtype.itemsize - upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)]) - local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)]) - smem_sz = amt*acc_sz*upcast_sz*local_sz - check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}") - - new_axis = None - if opt.op is OptOps.LOCAL: # cyan - # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 
cache) - # it's disabled for now since it makes BEAM slow for little gain - check(self.opts.has_local, "target does not support local") - check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals") - new_axis = self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1) - elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green - check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem") - check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group") - check(not self.tensor_core, "can't group with tensor cores") - check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces") - new_axis = self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_at=min(self.axes_of(AxisType.REDUCE))) - elif opt.op is OptOps.UNROLL: # purple - check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted") - check(amt <= 32, "don't unroll more than 32") - new_axis = self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None) - elif opt.op is OptOps.UPCAST: # yellow - check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}") - # NOTE: assume the first get_local_axes() LOCAL are for TC - check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals") - check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16") - new_axis = self.shift_to(axis, amt, AxisType.UPCAST, - insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1) - elif opt.op is OptOps.NOLOCALS: - check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals") - check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals") - self.dont_use_locals = True - elif opt.op is OptOps.SWAP: - check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}") - check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}") - permute = list(range(self.shape_len)) - permute[axis], permute[amt] = permute[amt], permute[axis] - self.permute(tuple(permute)) - elif opt.op is OptOps.PADTO: - check(not self.vars, "does not work with symbolic shape") - check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted") - # ok to pad SUM if all parent ALU ops have f(0) = 0 - if (r:=self.reduceop) is not None and self.axis_types[axis] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): - check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}") - padded = False - for i,st in enumerate(self.sts): - if (s:=st.shape[axis]) == 1: continue # reduced - check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}") - if (ru := round_up(cast(int, s), amt) - s): - # pad right seems to be faster - self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1)) - padded = True - check(padded, "nothing was padded") - - if append_opt: self.applied_opts.append(opt) - if self.simplify_ones() and self.tensor_core_opts: - self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones() - return new_axis - - def apply_opts(self, 
opts:Sequence[Opt]) -> Kernel: - for opt in opts: self.apply_opt(opt) - return self - - # **** kernel outputs, mostly tensor cores **** - - def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> TensorCoreOptions|None: - has_cast = tc.dtype_in != tc.dtype_out - if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None - - mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0] - if mul_op.op is not Ops.MUL: return None - - def buf_index(src:UOp) -> int|None: - # TODO: apply tc even if the sources are not from LOAD - if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src) - try: - if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0]) - except ValueError: return None - return None - if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None - - buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0] - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0] - if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None - - axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE))) - if not (axis < len(axis_choices)): return None - - s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k - axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0)) - if axis_pads and (opt_level < 2): return None - if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) - return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads) - - def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool: - if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD: - tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]] - for tc in tensor_cores: - tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops] - if tensor_core_opts[0] is None: continue - # can only fuse reduces with the same tc options - assert all_same(tensor_core_opts) - self.tensor_core_opts = tc_opts = tensor_core_opts[0] - - # attempt to pad the tensor axes that require it - try: - for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail - except KernelOptError: continue - # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M) - for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False) - for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0 - self.tensor_core = tc - self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA - return True - return False - - def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:list[Opt]|None=None, axis:int=0, tc_select:int|None=None, tc_opt:int|None=None) -> bool: - """ Attempts to apply a tensor 
core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. - Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). - - Keyword arguments: - use_tensor_cores -- controls how tensor cores are applied (default 1) - 0: will disable any tensor core matching - 1: enable tensor cores - 2: apply tensor core shape but don't use UOp.WMMA - extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) - tc_select -- specifies which tensor core(s) to use for optimization (default -1) - -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes) - [0-N]: uses only the n'th tensor core available; useful for search - tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) - 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL - 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers - 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed - """ - if tc_select is None: tc_select = TC_SELECT.value - if tc_opt is None: tc_opt = TC_OPT.value - if not self.opts.tensor_cores: return False - try: # check TC first and apply hand-coded opts if successful - self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt, use_tensor_cores))) - - if (tc_opts:=self.tensor_core_opts) is not None: - if extra_opts is not None: self.apply_opts(extra_opts) - else: - if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower - # hand-coded TC opts - for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N - szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0] - if szs: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0])) - - if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if self.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N - self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0])) - return True - except KernelOptError: - return False - - # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2'] - def shape_str(self) -> list[str]: - ret: list[str] = [] - cnt: dict[AxisType, int] = {} - for x in self.axis_types: - cnt[x] = (cnt[x] + 1) if x in cnt else 0 - ret.append(f"{axis_letters[x]}{cnt[x]}") - return ret - def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms]) - - def get_optimized_ast(self, name_override:str|None=None) -> UOp: - @functools.cache - def fixup_ast(op:UOp) -> UOp: - ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821 - if op.op in GroupOp.Buffer and op in self.bufs: - st = self.sts[self.bufs.index(op)] - # replace the VIEW source - return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:]) - if op.op is Ops.SINK: - # NOTE: should group_for_reduces be added to the local_dims? 
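A worked example of the `shape_str()` / `shape_str_to_axis()` helpers above (the same helpers reappear on the new `Scheduler` in postrange.py): each axis gets its `axis_letters` letter plus a zero-based per-type counter. The axis-type list below is illustrative, and the dict is a subset of the module-level `axis_letters`, inlined so the sketch runs standalone:

```python
from tinygrad.uop.ops import AxisType

axis_types = [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.LOCAL,
              AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL]

axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l",
                AxisType.REDUCE: "R", AxisType.UNROLL: "r"}

ret: list[str] = []
cnt: dict[AxisType, int] = {}
for x in axis_types:
    cnt[x] = (cnt[x] + 1) if x in cnt else 0   # first occurrence counts from 0
    ret.append(f"{axis_letters[x]}{cnt[x]}")

print(ret)  # ['g0', 'g1', 'l0', 'R0', 'r0', 'r1']
# shape_str_to_axis(['R0', 'r1']) would then return (3, 5): the positions of
# the reduce axis and the second unroll axis in the kernel's shape.
```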
- # TODO: arg.name should be able to be None - kernel_name = ret.arg.name if ret.arg is not None and ret.arg.name != "test" else self.name if name_override is None else name_override - return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts))) - if op.op is Ops.REDUCE_AXIS: - reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2 - changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i])) - axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.UNROLL) if i in changed) - grouped_axes = tuple(i for i in self.axes_of(AxisType.GROUP_REDUCE) if i in changed) - if (tc := self.tensor_core) and self.use_tensor_cores == 1: - # get reduce/upcast axes for the tensor cores - tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))]) - base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())]) - tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)]) - - # permute the srcs - srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src) - for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))): - src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg - srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(permaxis)) - - # construct the op - wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes) - wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=( - UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]), - UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]), - UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg) - tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2]) - - # preserve any other reduce - return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop - - ret = ret.replace(arg = (op.arg[0], axes)) - if self.group_for_reduces and grouped_axes: - local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes]) - slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes) - # NOTE: start with UPCAST at the end so it has stride 1 and can merge - base_shape = tuple([self.full_shape[i] for i in slocal] + [self.full_shape[i] for i in sgroup] + [self.full_shape[i] for i in supcast]) - permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast]) - local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)]) - local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)]) - st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape) - local_size = st.real_size() - local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}") - local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret)) - grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes)) - if op is self.reduceops[-1]: 
return grouped_reduce - st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)])) - return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce)) - - return ret - self.finalized = True - fixed_ast = fixup_ast(self.ast) - del fixup_ast - return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST") diff --git a/tinygrad_repo/tinygrad/codegen/opt/postrange.py b/tinygrad_repo/tinygrad/codegen/opt/postrange.py new file mode 100644 index 00000000..20193a05 --- /dev/null +++ b/tinygrad_repo/tinygrad/codegen/opt/postrange.py @@ -0,0 +1,334 @@ +from __future__ import annotations +import math, itertools +from collections import defaultdict +from typing import cast, Final +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp +from tinygrad.device import Buffer +from tinygrad.dtype import AddrSpace, dtypes, ImageDType +from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod +from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters +from tinygrad.codegen.simplify import pm_flatten_range +from tinygrad.renderer import Renderer + +remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) + +# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters +axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3, + AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5} + +class Scheduler: + def __init__(self, ast:UOp, opts:Renderer): + self.ast, self.opts = ast, opts + self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False + self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else [] + + @property + def rngs(self): + # always in order by axistype + return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1]) + @property + def shape_len(self): return len(self.rngs) + @property + def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs] + @property + def axis_types(self): return [x.arg[-1] for x in self.rngs] + @property + def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0) + + # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2'] + def shape_str(self) -> list[str]: + ret: list[str] = [] + cnt: dict[AxisType, int] = {} + for x in self.axis_types: + cnt[x] = (cnt[x] + 1) if x in cnt else 0 + ret.append(f"{axis_letters[x]}{cnt[x]}") + return ret + def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms]) + + def copy(self): + ret = Scheduler(self.ast, self.opts) + ret.dont_use_locals = self.dont_use_locals + ret.applied_opts = self.applied_opts[:] + return ret + + kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int) + def get_optimized_ast(self, name_override:str|None=None): + if name_override is not None: name = name_override + else: + kernel_type = "r" if self.reduceop is not None else "E" + name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())]) + Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1 + num = 
f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else "" + name += colored(num, 'BLACK') + self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range") + return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1) + + def _globalizable_rngs(self) -> list[UOp]: + store_rngs = self.ast.src[0].src[2:] + + # filter any not in local stores + local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \ + or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)] + for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls]) + + return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else [] + + def convert_loop_to_global(self): + if not self.opts.has_local: return None + + globalizible_rngs = self._globalizable_rngs() + rng = [x.replace(arg=(x.arg[0], AxisType.GLOBAL)) if x in globalizible_rngs else x for x in self.rngs] + + self.ast = self.ast.substitute(dict(zip(self.rngs, rng))) + + def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types] + def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())]) + + def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None): + if (old_sz:=rng.src[0].divides(amount)) is None: + raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}") + new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng + replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),)) + sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng) + self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}") + return replaced_rng, new_rng + + def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type] + def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type] + + # copied from kernel.py + @property + def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \ + if isinstance(s:=self.full_shape[i], int) and s > 1] + @property + def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \ + if isinstance(s:=self.full_shape[i], int) and s > 1] + + def real_axis(self, op:OptOps, axis:int|None): + try: + if axis is None or op is OptOps.TC: return -1 + if op is OptOps.UNROLL: return self.unrollable_dims[axis] + if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis] + check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}") + return axis + except IndexError as e: raise KernelOptError from e + + def apply_opt(self, opt:Opt, append_opt:bool=True): + if opt.op is OptOps.NOLOCALS: + check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals") + if append_opt: self.applied_opts.append(opt) + self.dont_use_locals = True + return + + if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}: + 
check(self.opts.has_local, "locals needed for opt") + + rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP) + + opt_to_at = { + OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST, + OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE, + OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD} + + ret = None + if opt.op in opt_to_at: + amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg) + + # copied from kernel.py. prevents METAL compiler hangs + if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \ + (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): + upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]) + smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize + check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}") + + if opt.op is OptOps.UNROLL: + check(amt <= 32, "don't unroll more than 32") + check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE") + if opt.op is OptOps.UPCAST: + check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16") + check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}") + if opt.op is OptOps.LOCAL: + check(not self.dont_use_locals, "can't use locals") + check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals") + if opt.op is OptOps.THREAD: + check(self.opts is not None and self.opts.has_threads, "target does not support threads") + check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads") + check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded") + check(rng in self._globalizable_rngs(), "can't apply range to this dim") + if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: + check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong? + check(not self.dont_use_locals, "can't use locals") + check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce") + ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD}) + elif opt.op is OptOps.TC: + check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps + check(opt.axis is not None, "tensor core opts must have an axis") + check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg") + check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select") + check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt") + check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid") + try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt) + except ValueError as e: raise KernelOptError(str(e)) + check(ret is not None, "no tensor core available") + elif opt.op is OptOps.PADTO: + check(rng.src[0].op is Ops.CONST, "only pad const axes") + check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong? 
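# Every branch above except TC/PADTO/SWAP bottoms out in shift_to: a range of
# size old_sz*amount is rewritten as an outer and an inner range whose indices
# recombine as outer*amount + inner, or as inner*old_sz + outer when top=True
# (the GROUPTOP/THREAD case). A runnable sketch with illustrative sizes:
amount, old_sz = 4, 3  # splitting a range of 12 by 4
bottom = [r0 * amount + r1 for r0 in range(old_sz) for r1 in range(amount)]
top    = [r1 * old_sz + r0 for r0 in range(old_sz) for r1 in range(amount)]
assert sorted(bottom) == sorted(top) == list(range(12))  # every index hit exactly once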
+ check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread") + # ok to pad SUM if all parent ALU ops have f(0) = 0 + if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): + check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}") + new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg)) + check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work") + replaced_rng = UOp.range(new_sz, *rng.arg) + replaces = {rng:replaced_rng} + valid = replaced_rng < rng.vmax+1 + for b in self.bufs: + if rng in (i:=b.src[1].get_idx()).sparents: + replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid()))) + self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}") + elif opt.op is OptOps.SWAP: + try: + altrng = self.rngs[opt.arg] + except IndexError: + raise KernelOptError + check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals") + self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1), + altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)}) + self.ast = graph_rewrite(self.ast, remove_tags) + else: + raise KernelOptError(f"unsupported opt {opt.op}") + + if append_opt: self.applied_opts.append(opt) + return ret + + def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]: + reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE] + if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore") + reduceop = reduceops[0] + if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD: + mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0] + if mul.op is not Ops.MUL: return None + in0, in1 = mul.src + try: + tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]] + except IndexError: + raise KernelOptError(f"invalid tensor core choice {tc_select}") + for tc in tensor_cores: + if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar(): + # tensor cores have three ranges. X, Y, and REDUCE + in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0]) + in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0]) + red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0]) + if DEBUG >= 3: + print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}", + f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}") + if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue + + # pick ranges + # NOTE: why are in1 and in0 switched? + axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges)) + if not (axis < len(axis_choices)): continue + axes = list(axis_choices[axis]) + + # do optimizations and save the ranges + try: + for i,a in enumerate(axes): + idx = self.rngs.index(a) + if (a.vmax+1) % tc.dims[i] != 0: + if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2") + # apply_opt should return the updated range? 
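# PADTO, as applied on the next line for each tensor-core dim, rounds the range
# up and masks the extra iterations: an index stays valid only while
# replaced_rng < rng.vmax+1, exactly as in the PADTO branch of apply_opt above.
# A runnable sketch with illustrative sizes:
def round_up(n: int, m: int) -> int: return ((n + m - 1) // m) * m  # same helper as tinygrad.helpers

old_n, pad = 100, 32
new_n = round_up(old_n, pad)                   # 128
live = [i for i in range(new_n) if i < old_n]  # the `valid` predicate gates these
assert new_n == 128 and len(live) == old_n
# apply_opt also checks rng.vmax+1 > new_sz//4: padding that would more than
# roughly quadruple the work is rejected.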
+ self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail + axes[i] = self.rngs[idx] + except KernelOptError: continue + + # we create the warp as a whole thing, in case some of these ranges are moved/removed later + warp = UOp.range(tc.threads, -1, AxisType.WARP) + ne: list[UOp] = [] + for opt in tc.opts: + if opt[0] == "l": + axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2) + warp //= 2 + elif opt[0] == "u": + axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST) + else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores") + ne.append(new_range) + + for _, amt in tc.get_reduce_axes(): + axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL) + ne.append(new_range) + + if use_tensor_cores != 2: + # fix the srcs + reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0] + tne = [x.replace(tag=1) for x in ne] + ret = reduceop.substitute(dict(zip(ne, tne))) + srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src) + srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))] + + # get reduce/upcast axes for the tensor cores + tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))]) + base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())]) + tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)]) + + # axes to range number (was done in lowerer) + tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes]) + tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes]) + + # construct the op + # TODO: remove tc_upcast_axes from the arg + # do the reduce_axes always disappear? 
i think they don't + # they need to be moved into the WMMA srcs + wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes) + wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=( + UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1), + UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1), + UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1) + tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1) + + # preserve extra reduces + reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes] + if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD) + self.ast = self.ast.substitute({reduceop: tc_uop}) + return axes + return None + + # helpers for hand_coded_optimizations + @property + def reduceop(self) -> UOp|None: + red = [x for x in self.ast.parents if x.op is Ops.REDUCE] + if not len(red): return None + return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ())) + @property + def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1] + @property + def output_shape(self): + return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)] + @property + def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL)) + @property + def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE)) + +def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]: + glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg) + return [Buffer(dname, x.ptrdtype.size, x.dtype.base if not isinstance(x.dtype, ImageDType) else x.dtype) for x in glbls] + +def apply_opts(ctx:Renderer, ast:UOp): + if ast.tag is not None: return None + k = Scheduler(ast, ctx) + k.convert_loop_to_global() + if ast.arg is not None and ast.arg.opts_to_apply is not None: + for opt in ast.arg.opts_to_apply: k.apply_opt(opt) + elif BEAM >= 1: + from tinygrad.codegen.opt.search import beam_search + rawbufs = bufs_from_ast(ast, ctx.device) + k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) + elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): + from tinygrad.codegen.opt.heuristic import hand_coded_optimizations + # NOTE: hand_coded_optimizations doesn't support multiblock opts yet + if all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD): + k = hand_coded_optimizations(k) + return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) + +pm_postrange_opt = PatternMatcher([ + (UPat(Ops.SINK, name="ast"), apply_opts), +]) diff --git a/tinygrad_repo/tinygrad/codegen/opt/search.py b/tinygrad_repo/tinygrad/codegen/opt/search.py index 8db9c8c9..b5c3a1eb 100644 --- a/tinygrad_repo/tinygrad/codegen/opt/search.py +++ b/tinygrad_repo/tinygrad/codegen/opt/search.py @@ -1,28 +1,28 @@ -from typing import cast, Callable -import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit -from collections import defaultdict +from typing import cast +import functools, math, time, multiprocessing, traceback, signal, atexit from dataclasses import replace -from 
tinygrad.uop.ops import UOp, Ops, Variable, sym_infer, AxisType +from tinygrad.uop.ops import sym_infer, AxisType, pyrender from tinygrad.device import Device, Buffer, Compiler from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str -from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE -from tinygrad.dtype import ImageDType, PtrDType -from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError +from tinygrad.helpers import IGNORE_BEAM_CACHE +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError from tinygrad.tensor import Tensor from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.renderer import ProgramSpec +from tinygrad.codegen.opt.postrange import Scheduler actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)] actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)] actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)] actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)] -if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)] +if getenv("BEAM_PADTO", 0): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)] actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)] actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))] # covers resnet kernels (3 global * 3 reduce) actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)] actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)] +actions += [Opt(op=OptOps.THREAD, axis=axis, arg=amt) for amt in [2,3,4,5,8,12,16,24,32,64] for axis in range(3)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] def get_test_global_size(global_size, max_global_size, var_vals): @@ -35,7 +35,7 @@ def get_test_global_size(global_size, max_global_size, var_vals): break return test_global_size, input_size / prod(test_global_size) -def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:float|None=None, +def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None, allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]: factor = 1 if allow_test_size and p.global_size is not None and max_global_size is not None: @@ -55,9 +55,11 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbuf return tms class TimeoutException(Exception): pass -def timeout_handler(signum, frame): raise TimeoutException() +def timeout_handler(signum, frame): + if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT") + raise TimeoutException() -def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]: +def _try_compile_linearized_w_idx(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]: if hasattr(signal, "alarm"): signal.signal(getattr(signal, 'SIGALRM'), timeout_handler) # set timeout @@ -90,35 +92,11 @@ def 
_ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_ # *** external API *** -# get (scrap) buffers for timing the linearizer -def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: - bufsts: defaultdict[int, list[UOp]] = defaultdict(list) - for x in lin.bufs: - if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x) - # TODO: Nones are staying in here if buffers are optimized out! - # TODO: add a test for this - rawbufs: list[Buffer|None] = [None]*(max(bufsts)+1) - for k,lx in bufsts.items(): - buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx) - assert isinstance(dtype, (PtrDType, ImageDType)) - if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case. - buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base - rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype) - #assert all(r is not None for r in rawbufs) - return cast(list[Buffer], rawbufs) - # get dictionary of all possible actions -def get_kernel_actions(lin:Kernel, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Kernel]: +def get_kernel_actions(lin:Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Scheduler]: acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) kernel_actions = (actions if candidates is None else candidates).copy() - if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first - for i, action in enumerate(kernel_actions): - if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1: - # replace every tc_action with default tc with one tc_action for each available tc - kernel_actions[i:i+1] = \ - [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1], tc_arg[2])) for tc_select,_ in enumerate(lin.opts.tensor_cores)] - for i,a in enumerate(kernel_actions): if a.axis is not None and a.op is not OptOps.TC: try: ax = lin.real_axis(a.op, a.axis) @@ -127,10 +105,10 @@ def get_kernel_actions(lin:Kernel, include_0=True, candidates:list[Opt]|None=Non lin2 = lin.copy() try: lin2.apply_opt(a) - up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1 + up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if hasattr(lin2, 'tensor_core') and (tc:=lin2.tensor_core) else 1 for s,c in zip(lin2.full_shape, lin2.axis_types): if c in (AxisType.UPCAST, AxisType.UNROLL): up *= s - elif c in (AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s + elif c in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s if up//tc_up > max_up or lcl > max_lcl: if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. 
{up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}") continue @@ -139,7 +117,7 @@ def get_kernel_actions(lin:Kernel, include_0=True, candidates:list[Opt]|None=Non return acted_lins beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG") -def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel: +def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value): global beam_pool key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None: @@ -147,7 +125,7 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, for o in val[len(lin.applied_opts):]: ret.apply_opt(o) return ret - beam: list[tuple[Kernel, float]] = [(lin, float("inf"))] + beam: list[tuple[Scheduler, float]] = [(lin, float("inf"))] seen_libs = set() default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0 @@ -157,17 +135,19 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, def close_pool(): beam_pool.close() min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6 - if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}") + if BEAM_DEBUG: + print("BEAM_SEARCH:") + print('\n'.join(pyrender(lin.ast.replace(arg=None)))) if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}") try: rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} + var_vals: dict[str, int] = {k.expr:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} exiting, st = False, time.perf_counter() dev = Device[lin.opts.device] while not exiting: - acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam]) - timed_lins: list[tuple[Kernel, float]] = [] + acted_lins: list[Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam]) + timed_lins: list[tuple[Scheduler, float]] = [] _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler) least_compute_ops = math.inf for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))): @@ -201,15 +181,3 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts) if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}") return beam[0][0] - -def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]: - test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs - MAX_WORKGROUP = 1024 - local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size] - local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice - def try_exec(local_size): - try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501 - except Exception: return float('inf') - ret = 
min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]) - assert not math.isinf(ret[0]), "all optimize_local_size exec failed" - return ret[1] diff --git a/tinygrad_repo/tinygrad/codegen/opt/swizzler.py b/tinygrad_repo/tinygrad/codegen/opt/swizzler.py index 4e4b6bce..75521b83 100644 --- a/tinygrad_repo/tinygrad/codegen/opt/swizzler.py +++ b/tinygrad_repo/tinygrad/codegen/opt/swizzler.py @@ -17,7 +17,7 @@ merge_views = PatternMatcher([ lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None), # only unmaksed VIEW on CONST replaces the ShapeTracker (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), - lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None), + lambda x,view: x.replace(src=(UOp(Ops.VIEW, x.dtype, x.src, view.arg),)) if all(v.mask is None for v in view.st.views) else None), ]) def reduce_push_add_ones(src:UOp, r:UOp, view:UOp): @@ -128,7 +128,8 @@ fix_kernel_ops = view_left_through_load+PatternMatcher([ (UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"), lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)), # no ImageDType after index - (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), + (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.INDEX}, name="x"), + lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # if this kernel also assigns to the loaded buffer, ensure we can index it correctly (UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st), ]) diff --git a/tinygrad_repo/tinygrad/codegen/simplify.py b/tinygrad_repo/tinygrad/codegen/simplify.py new file mode 100644 index 00000000..77c4d49b --- /dev/null +++ b/tinygrad_repo/tinygrad/codegen/simplify.py @@ -0,0 +1,120 @@ +from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start +from tinygrad.uop.symbolic import symbolic_flat, sym +from tinygrad.helpers import partition +from tinygrad.dtype import dtypes + +def flatten_range(r:UOp): + off = range_start[r.op] + rngs = r.src[off:] + if not len(rngs): return None + new_rngs = [x for x in UOp.sink(*rngs).toposort() if x.op is Ops.RANGE] + return r.replace(src=r.src[:off]+tuple(new_rngs)) + +pm_flatten_range = PatternMatcher([ + # real ranges only + (UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range), +]) + +def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}]) +def simplify_merge_adjacent(u:UOp) -> UOp|None: + i = range_start[u.op] + while i < len(u.src)-1: + r0, r1 = u.src[i], u.src[i+1] + # check same type + if r0.arg[-1] == r1.arg[-1]: + s0, s1 = r0.src[0], r1.src[0] + # do the merge + new_range = r0.replace(src=(s0*s1,)) + nidx = graph_rewrite(u, _substitute+symbolic_flat+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, + name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}") + # check if it simplifies + if count_divmod(nidx) <= count_divmod(u): + u = nidx + continue + i += 1 + return u + +pm_simplify_ranges = PatternMatcher([ + (UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent), +]) + +# **** reduce simplification **** + +def reduce_rangeless(red:UOp): + # TODO: share code with reduce_unparented + if red.arg not in {Ops.ADD, Ops.MAX}: 
return None + if red.src[0].dtype != red.dtype: return None + if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None + ret = red.src[0] + if red.arg is Ops.ADD: + for r in red.src[1:]: + ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) + return ret + +def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents) + +pm_reduce_collapse = PatternMatcher([ + # lift x+y out of reduce on lt + ((UPat.var("x")+UPat.var("y")).or_casted() < UPat.var("c"), lambda x,y,c: (x < (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None), + # lift x*y out of reduce + ((UPat.var("x")*UPat.var("y")) < UPat.var("c"), + lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None), + # lift x+y out of reduce on ne + ((UPat.var("x")+UPat.var("y")).or_casted() != UPat.var("c"), lambda x,y,c: (x != (c.cast(y.dtype)-y)) if no_range(y) and no_range(c) else None), + # fold the range + ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), + lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val), + ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), + lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val), + # REDUCE on ADD + ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), + lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)), + # MUL casted bool + ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")), + lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)), + # WHERE on LOAD (works on max too) + (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True), + lambda buf,idx,gate: buf.index(idx, gate).load()), + (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True), + lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()), + # INDEX on RANGE / gated RANGE + (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())), + lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))), + # AND on WHERE + ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \ + .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), + lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), + # remove REDUCEs that no longer have a RANGE in the src + (UPat(Ops.REDUCE, name="red"), reduce_rangeless), +])+sym + +def reduce_collapse(red:UOp): + included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:])) + if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None + replaces: dict[UOp, UOp] = {} + for u in included: + for s in u.src: + if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: + replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax)) + collapse_fxn = red.substitute(replaces) + sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse") + if any(x.op is Ops.RANGE for x in sink.toposort()): 
return None + return sink.substitute({v:k for k,v in replaces.items()}) + +def reduce_unparented(red:UOp): + if red.arg not in {Ops.ADD, Ops.MAX, Ops.MUL}: return None + reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents) + if len(reduce_unparented) == 0: return None + ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0] + if red.arg is Ops.ADD: + for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) + if red.arg is Ops.MUL: + for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) + return ret + +pm_reduce_simplify = PatternMatcher([ + # remove any ranges from a REDUCE that aren't referenced in the reduce source + (UPat(Ops.REDUCE, name="red"), reduce_unparented), + # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range + (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse), +]) diff --git a/tinygrad_repo/tinygrad/device.py b/tinygrad_repo/tinygrad/device.py index c7e5cabd..64fb3bb0 100644 --- a/tinygrad_repo/tinygrad/device.py +++ b/tinygrad_repo/tinygrad/device.py @@ -1,16 +1,17 @@ from __future__ import annotations from dataclasses import dataclass, replace from collections import defaultdict -from typing import Any, Generic, TypeVar, Iterator -import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time -from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \ - Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup +from typing import Any, Generic, TypeVar, Iterator, Sequence, cast +import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal +from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM +from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup +from tinygrad.helpers import unwrap_class_type from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer # **************** Device **************** -ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"] +ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "CL", "CPU", "DSP", "WEBGPU"] class _Device: def __init__(self) -> None: self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] @@ -138,15 +139,14 @@ class Buffer: if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes if PROFILE: self._prof_num = num = len(Buffer.profile_events) - ts = decimal.Decimal(time.perf_counter_ns())/1000 - Buffer.profile_events.append(ProfilePointEvent(self.device, "alloc", ts, num, {"dtype":self.dtype, "sz":self.size})) + Buffer.profile_events.append(ProfilePointEvent(self.device, "alloc", num, {"dtype":self.dtype, "sz":self.size})) return self def deallocate(self): assert hasattr(self, '_buf'), "buffer must be allocated to deallocate" if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}") if self._base is None and 
(self.options is None or self.options.external_ptr is None): if GlobalCounters is not None and not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes - if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", decimal.Decimal(time.perf_counter_ns())/1000, self._prof_num)) + if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", self._prof_num)) self.allocator.free(self._buf, self.nbytes, self.options) elif self._base is not None: self._base.allocated_views -= 1 del self._buf @@ -273,12 +273,34 @@ class Compiler: return lib def disassemble(self, lib:bytes): pass +CompilerPairT = tuple[functools.partial|type[Renderer], functools.partial|type[Compiler]] class Compiled: profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device. - def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None): - self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph - self.renderer, self.group_id = renderer or Renderer(), group_id + def __init__(self, device:str, allocator:Allocator, compilers:Sequence[CompilerPairT]|None, runtime, graph=None, group_id=None): + self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id + self.compilers = cast(list[CompilerPairT], compilers or [(Renderer, Compiler)]) + + envnames = [self._get_compiler_envvar(c) for r,c in self.compilers] + enable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, self.compilers) if en is not None and getenv(en, -1) == 1) + disable_comps = set((en, comp_pair) for en, comp_pair in zip(envnames, self.compilers) if en is not None and getenv(en, -1) == 0) + + if len(enable_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {enable_comps}") + for _, comp_pair in disable_comps: self.compilers.remove(comp_pair) + + try: self.renderer, self.compiler = next(self._get_available_compilers([list(enable_comps)[0][1]] if len(enable_comps) == 1 else self.compilers)) + except StopIteration as exc: raise RuntimeError(f"no usable compilers for {self.device}") from exc + + if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}") + + def _get_compiler_envvar(self, c): + compiler_name = f"{unwrap_class_type(c).__name__.upper().removesuffix('COMPILER').removeprefix(devname:=self.device.split(':')[0].upper())}" + return f"{devname}_{compiler_name if len(compiler_name) > 0 else unwrap_class_type(c).__name__.upper()}" + + def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]: + for renderer, compiler in compilers: + with contextlib.suppress(Exception): yield renderer(), compiler() + def synchronize(self): """ Synchronize all pending operations on the device. 
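The compiler pair a device ends up with is chosen purely from class names and environment variables: _get_compiler_envvar above strips the COMPILER suffix and the device prefix from the class name, and getenv(...) == 1 forces a pair while == 0 removes it. A runnable sketch of the naming rule, assuming hypothetical ClangCompiler and LLVMCompiler classes on the CPU device:

def compiler_envvar(device: str, cls_name: str) -> str:
  # mirrors _get_compiler_envvar: strip the "COMPILER" suffix, then the device prefix
  devname = device.split(":")[0].upper()
  name = cls_name.upper().removesuffix("COMPILER").removeprefix(devname)
  return f"{devname}_{name}" if name else f"{devname}_{cls_name.upper()}"

assert compiler_envvar("CPU", "LLVMCompiler") == "CPU_LLVM"     # CPU_LLVM=1 forces this pair
assert compiler_envvar("CPU:1", "ClangCompiler") == "CPU_CLANG" # CPU_CLANG=0 removes it
# forcing two pairs with "=1" on one device raises RuntimeError in Compiled.__init__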
@@ -299,15 +321,14 @@ class Compiled: # TODO: move this to each Device def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: + if dtype == dtypes.index: return False if device is None: device = Device.DEFAULT if dtype == dtypes.bfloat16: if device == "METAL": return not CI - if device in {"CUDA", "NV"}: return not CI and not getenv("PTX") - if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} - return device == "AMD" - if dtype in dtypes.fp8s: - # not supported yet - in progress - return False + if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") + if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} + return device in {"AMD", "PYTHON"} + if dtype in dtypes.fp8s: return device == "PYTHON" if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half] # for CI GPU and OSX, cl_khr_fp16 isn't supported @@ -315,11 +336,11 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751 if dtype == dtypes.half: - if device == "GPU": return not CI and not OSX + if device == "CL": return not CI and not OSX if device in ["CUDA", "NV"]: return not CI - if device == "LLVM": return OSX + if device == "CPU" and CPU_LLVM: return OSX if device == "PYTHON": return sys.version_info >= (3, 12) - if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") + if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "CL") return True if PROFILE: @@ -331,21 +352,26 @@ if PROFILE: with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f) - if not getenv("SQTT", 0): - from tinygrad.uop.ops import launch_viz - launch_viz(PROFILE, fn) + from tinygrad.uop.ops import launch_viz + launch_viz("PROFILE", fn) if __name__ == "__main__": + from tinygrad import Tensor, Device + for device in ALL_DEVICES: + compilers_results, any_works = [], False try: - _ = Device[device].device - try: - from tinygrad import Tensor - with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() - if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") - result = colored("PASS", "green") - except Exception as e: - result = f"{colored('FAIL', 'yellow')} {e}" + default_compiler = (d:=Device[device]).compiler + for i,(r,c) in enumerate(d.compilers): + try: + d.renderer, d.compiler = r(), c() + with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() + if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") + default_text = '(default)' if type(default_compiler) is type(d.compiler) else f'({d._get_compiler_envvar(c)}=1 to make default)' + compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}") + any_works = True + except Exception as e: compilers_results.append(f"{colored('-', 'yellow')} {unwrap_class_type(c).__name__}: {e}") + result = (colored('PASS', 'green') if any_works else f"{colored('FAIL', 'yellow')}") + ''.join([f'\n{" "*16} {x}' for x in compilers_results]) except Exception as e: result = f"{colored('FAIL', 'red')} {e}" print(f"{'*' if device == Device.DEFAULT else ' 
'} {device:10s}: {result}") diff --git a/tinygrad_repo/tinygrad/dtype.py b/tinygrad_repo/tinygrad/dtype.py index 2fb65176..11373bb3 100644 --- a/tinygrad_repo/tinygrad/dtype.py +++ b/tinygrad_repo/tinygrad/dtype.py @@ -5,6 +5,21 @@ from dataclasses import dataclass, fields from tinygrad.helpers import getenv, prod from enum import Enum, auto +class InvalidTypeMetaClass(type): + instance:None|InvalidType = None + def __call__(cls, *args, **kwargs): + if (ret:=InvalidTypeMetaClass.instance) is not None: return ret + InvalidTypeMetaClass.instance = ret = super().__call__() + return ret + +class InvalidType(metaclass=InvalidTypeMetaClass): + def __eq__(self, other): return self is other + def __hash__(self): return id(self) + def __repr__(self): return "Invalid" + def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance + +Invalid = InvalidType() + ConstType = float|int|bool FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd'] @@ -17,7 +32,9 @@ class DTypeMetaClass(type): DTypeMetaClass.dcache[args] = ret = super().__call__(*args) return ret -class AddrSpace(Enum): GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702 +class AddrSpace(Enum): + def __repr__(self): return str(self) + GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702 @dataclass(frozen=True, eq=False) class DType(metaclass=DTypeMetaClass): @@ -67,7 +84,7 @@ class PtrDType(DType): return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size) def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL): raise RuntimeError("can't make a pointer from a pointer") def nbytes(self) -> int: - if self.size == -1: return 0 # TODO: this should be an exception + if self.size == -1: raise RuntimeError("can't get nbytes of a pointer with unlimited size") return self.size*self.itemsize @property def vcount(self): return self.v @@ -89,7 +106,7 @@ class dtypes: def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType) @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool @functools.cache - def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints + (dtypes.index,) @staticmethod @functools.cache def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints @@ -104,21 +121,21 @@ class dtypes: if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}") @staticmethod - def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType): + def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType): if isinstance(val, tuple): assert len(val) == dtype.count, f"mismatch {val} {dtype}" return tuple(dtypes.as_const(x, dtype) for x in val) - # TODO: should truncate here + if isinstance(val, InvalidType): return val return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val) @staticmethod @functools.cache def min(dtype:DType): - if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1) + if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.scalar().itemsize*8-1) return -float("inf") if dtypes.is_float(dtype) else False @staticmethod @functools.cache def max(dtype:DType): - if dtypes.is_int(dtype): return 
2**(dtype.itemsize*8)-1+dtypes.min(dtype) + if dtypes.is_int(dtype): return 2**(dtype.scalar().itemsize*8)-1+dtypes.min(dtype) return float("inf") if dtypes.is_float(dtype) else True @staticmethod def finfo(dtype:DType) -> tuple[int, int]: @@ -129,6 +146,7 @@ class dtypes: @staticmethod def fields() -> dict[str, DType]: return DTYPES_DICT void: Final[DType] = DType.new(-1, 0, "void", None) + index: Final[DType] = DType.new(-1,100, "index", None) bool: Final[DType] = DType.new(0, 1, "bool", '?') int8: Final[DType] = DType.new(1, 1, "signed char", 'b') uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B') @@ -165,7 +183,7 @@ class dtypes: uints = (uint8, uint16, uint32, uint64) sints = (int8, int16, int32, int64) ints = uints + sints - all = floats + ints + (bool,) + all = floats + ints + (bool, index) if (env_default_float := getenv("DEFAULT_FLOAT", "")): dtypes.default_float = getattr(dtypes, env_default_float.lower()) @@ -177,8 +195,8 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], - dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], - dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], + dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], + dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } @@ -187,11 +205,12 @@ def _get_recursive_parents(dtype:DType) -> set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} @functools.cache def least_upper_dtype(*ds:DType) -> DType: - return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] + return min(set.intersection(*[_get_recursive_parents(d.scalar()) for d in ds])) \ + if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float) -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))} -INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"} +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void", "index"))} +INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"} @functools.cache def can_safe_cast(dt0:DType, dt1:DType) -> bool: @@ -199,8 +218,10 @@ def can_safe_cast(dt0:DType, dt1:DType) -> bool: # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html if dt0 == dt1 or dt0 == dtypes.bool: return True match dt1: - case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16) - case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16) 
 @functools.cache
 def can_safe_cast(dt0:DType, dt1:DType) -> bool:
@@ -199,8 +218,10 @@ def can_safe_cast(dt0:DType, dt1:DType) -> bool:
   # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
   if dt0 == dt1 or dt0 == dtypes.bool: return True
   match dt1:
-    case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16)
-    case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16)
+    case dtypes.index: return dt0 in dtypes.ints
+    case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16,
+                                       dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
+    case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
     case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
     case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
     case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
@@ -214,20 +235,23 @@ def sum_acc_dtype(dt:DType):
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
   return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))

-def truncate_fp16(x):
-  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+def float_to_fp16(x):
+  try: return struct.unpack('e', struct.pack('e', float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)

-def truncate_bf16(x):
-  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
-  if abs(x) > max_bf16: return math.copysign(math.inf, x)
-  f32_int = struct.unpack('I', struct.pack('f', x))[0]
-  bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
-  return bf
+def float_to_bf16(x):
+  if not math.isfinite(x): return x
+  u = struct.unpack('I', struct.pack('f', x))[0]
+  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
+  return struct.unpack('f', struct.pack('I', u))[0]

 # fp8-float conversions based on https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp
 def float_to_fp8(x: float, dtype: DType) -> int:
   assert dtype in dtypes.fp8s, "Only for fp8s"
+  # e4m3 don't support inf, return 0x7f(+NaN) and 0xff(-NaN) to match jax
+  # NaN is unordered, can't compare with zero, use math.copysign to get sign
+  if dtype == dtypes.fp8e4m3 and not math.isfinite(x): return 0x7f if math.copysign(1, x) > 0 else 0xff
+  if dtype == dtypes.fp8e5m2 and math.isinf(x): return 0x7c if math.copysign(1, x) > 0 else 0xfc
   config = {
     dtypes.fp8e4m3: {"EXP_BIAS": 7, "SIGNIFICAND_BITS": 4, "MANTISSA_MASK": 0x7, "MINDENORM_O2": 0x3F50000000000000, "OVERFLOW_THRESHOLD": 0x407D000000000000, "MAXNORM": 0x7E, "MINNORM": 0x3F90000000000000, "INF_VALUE": 0x7F},
@@ -288,7 +312,7 @@ def fp8_to_float(x: int, dtype: DType) -> float:
   return float(float32_val)

 truncate: dict[DType, Callable] = {dtypes.bool: bool,
-  dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+  dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
   **{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
   dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
@@ -300,6 +324,7 @@ truncate: dict[DType, Callable] = {dtypes.bool: bool,

 def _to_np_dtype(dtype:DType) -> type|None:
   import numpy as np
+  if dtype in { dtypes.bfloat16, *dtypes.fp8s }: return np.float32
   return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
 def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   import numpy as np
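`float_to_bf16` above replaces the old truncate-toward-zero bf16 conversion with round-to-nearest-even. A standalone copy of the bit trick, with the tie behavior checked (bf16 keeps 8 mantissa bits, so the spacing near 1.0 is 2^-8):

```python
import math, struct

def float_to_bf16(x: float) -> float:
    # round-to-nearest-even: add 0x7FFF plus the LSB of the surviving mantissa,
    # then clear the 16 float32 bits that bfloat16 drops
    if not math.isfinite(x): return x
    u = struct.unpack("I", struct.pack("f", x))[0]
    u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("f", struct.pack("I", u))[0]

assert float_to_bf16(1.0 + 2**-8) == 1.0                # exact tie rounds to the even mantissa
assert float_to_bf16(1.0 + 3 * 2**-9) == 1.0 + 2**-7    # past the tie rounds up
```

The old version simply masked the low bits, which always rounded toward zero; the carry out of bit 16 is what makes this version round up when it should.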
@@ -308,9 +333,12 @@ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] #
 @functools.cache
 def _to_torch_dtype(dtype:DType) -> 'torch.dtype'|None: # type: ignore [name-defined] # noqa: F821
   import numpy as np, torch
+  if dtype == dtypes.uint64: return torch.uint64
+  if dtype == dtypes.bfloat16: return torch.bfloat16
+  if dtype in dtypes.fp8s: return torch.uint8
   # NOTE: torch doesn't expose this mapping with a stable API
   try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
   except TypeError: return None
 @functools.cache
 def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
-  return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
\ No newline at end of file
+  return {v:k for k in DTYPES_DICT.values() if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
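`_to_torch_dtype` keeps the numpy round-trip as its fallback; the new explicit cases cover types numpy cannot represent (bfloat16, the fp8s) or maps lossily. A sketch of the round-trip idea, assuming numpy and torch are installed:

```python
# Fallback mapping sketch: build an empty array of the struct-format dtype and
# let torch infer the matching torch.dtype from it.
import numpy as np, torch

def np_fmt_to_torch(fmt: str) -> torch.dtype:
    return torch.from_numpy(np.array([], dtype=np.dtype(fmt))).dtype

assert np_fmt_to_torch("f") is torch.float32
assert np_fmt_to_torch("b") is torch.int8
```

Since numpy has no fp8, those dtypes get carried as `torch.uint8` raw bytes instead of going through this path.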
@@ -91,7 +91,7 @@ class GraphRunner(Runner):
     for j,ji in enumerate(jit_cache):
       estimates += ji.prg.estimates
       if isinstance(ji.prg, CompiledRunner):
-        if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
+        if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]

         global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
         if global_dim_idx is not None or local_dim_idx is not None:
@@ -105,12 +105,12 @@ class GraphRunner(Runner):

     super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

-  def updated_vars(self, var_vals: dict[Variable, int]):
+  def updated_vars(self, var_vals: dict[str, int]):
     vals = [var_vals[v] for v in self.vars]
     for j, vidxs in self.var_vals_replace.items():
       for i, v in vidxs: yield j, i, vals[v]

-  def updated_launch_dims(self, var_vals: dict[Variable, int]):
+  def updated_launch_dims(self, var_vals: dict[str, int]):
     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
     for j, (gl, lc) in self.launch_dims_replace.items():
       yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
@@ -192,7 +192,7 @@ class CapturedJit(Generic[ReturnType]):
     self.__post_init__()

   # jit exec
-  def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
+  def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
     # assign inputs
     for idx, offset, device, size, dtype in self.extra_view_inputs:
       input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
@@ -225,7 +225,8 @@ def _prepare_jit_inputs(args, kwargs):
                   for lb in lbs if lb.base.realized is not None])
   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
   st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
-  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+  _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+  var_vals = {k.expr:v for k,v in _var_vals.items()}
   st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
   return input_buffers, var_vals, names, st_vars_dtype_device
@@ -317,7 +318,7 @@ class TinyJit(Generic[ReturnType]):

       # memory planning (optional)
       # Exclude buffers involved in transfer ops to preserve parallelism.
-      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy)) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None], item.metadata, item.fixedvars) for item in jit_cache]
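The running theme of this diff is that `var_vals` is re-keyed from `Variable` UOps to the variable's `.expr` name string, done once in `_prepare_jit_inputs` above. A minimal illustration with a stand-in `Var` class (hypothetical, not tinygrad's UOp):

```python
# Standalone sketch of the var_vals re-keying: downstream consumers (graph
# runners, launch-dim inference) only need name -> value.
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    expr: str   # variable name, e.g. "start_pos"

bound = {Var("start_pos"): 17, Var("seq_len"): 128}   # dict[Var, int]
var_vals = {k.expr: v for k, v in bound.items()}      # dict[str, int]
assert var_vals == {"start_pos": 17, "seq_len": 128}
```

Strings sort and hash without touching UOp internals, which is why `self.vars = sorted(var_vals.keys())` above no longer needs a key function.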
diff --git a/tinygrad_repo/tinygrad/engine/memory.py b/tinygrad_repo/tinygrad/engine/memory.py
index 2c62b3f6..36a4e3b0 100644
--- a/tinygrad_repo/tinygrad/engine/memory.py
+++ b/tinygrad_repo/tinygrad/engine/memory.py
@@ -23,12 +23,13 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
   # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
   buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
                            [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
+  total_memory = sum(round_up(buf.nbytes, min_block_size:=0x1000) for buf in first_appearance.keys()) * 2 # *2 for fragmentation (which is about 15%)

   # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
   # Also track buffer replacements for buffers that do not support suballocation.
   buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
   reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
-  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(1 << 44, block_size=0x1000, lv2_cnt=32)))
+  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(total_memory, block_size=min_block_size, lv2_cnt=32)))
   for (_, is_open_ev), buf in buffer_requests:
     # Check if suballocation is possible for the given buffer and device.
     if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
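The planner now sizes its shared TLSF arena from the buffers it actually saw rather than a fixed `1 << 44` reservation. The sizing math in isolation, with hypothetical buffer sizes:

```python
# Arena sizing sketch: round every buffer up to the 4 KiB block size, sum, then
# double as fragmentation headroom (the diff's comment cites ~15% observed).
def round_up(n: int, block: int) -> int:
    return ((n + block - 1) // block) * block

nbytes = [10, 5000, 65536]          # hypothetical buffer sizes in bytes
min_block_size = 0x1000             # 4 KiB
total_memory = sum(round_up(n, min_block_size) for n in nbytes) * 2
assert total_memory == (0x1000 + 0x2000 + 0x10000) * 2
```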
diff --git a/tinygrad_repo/tinygrad/engine/realize.py b/tinygrad_repo/tinygrad/engine/realize.py
index 8b100008..50474a62 100644
--- a/tinygrad_repo/tinygrad/engine/realize.py
+++ b/tinygrad_repo/tinygrad/engine/realize.py
@@ -1,14 +1,14 @@
-from typing import cast, Generator
-import time, pprint, decimal
+from typing import cast, Generator, Callable
+import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.engine.schedule import ScheduleItem
 from tinygrad.codegen import full_rewrite
-from tinygrad.codegen.opt.kernel import Opt
+from tinygrad.codegen.opt import Opt

 # **************** Program Creation ****************
@@ -26,6 +26,7 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
   """
   if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+  if DEBUG >= 5: print('\n'.join(pyrender(ast)))

   # linearize
   if renderer is None: renderer = Device.default.renderer
@@ -34,9 +35,10 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
     ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
   try: uops = full_rewrite(ast, renderer)
-  except RuntimeError:
+  except RuntimeError as e:
     print("***** LINEARIZE FAILURE *****")
-    print(f"ast = {ast}")
+    print(e)
+    print('\n'.join(pyrender(ast)))
     raise
   assert uops[-1].op is Ops.SINK, "last uop must be sink"
@@ -45,7 +47,8 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
   src = renderer.render(uops)
   return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
-                     global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
+                     global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
+                     local_size=[1,1,1] if renderer.has_local else None)

 # **************** Runners ****************

@@ -54,18 +57,32 @@ class Runner:
     self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
   def dev(self): return Device[self.device]
-  def exec(self, rawbufs:list[Buffer], var_vals:dict[Variable, int]|None=None) -> float|None:
+  def exec(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None) -> float|None:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
     raise NotImplementedError("override this")

+def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
+  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
+  MAX_WORKGROUP = 1024
+  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+  def try_exec(local_size):
+    try:
+      return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)],
+                  local_size=local_size, wait=True)
+    except Exception: return float('inf')
+  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
+  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
+  return ret[1]
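`optimize_local_size` (moved here from the BEAM-search module) brute-forces workgroup shapes: per-dimension candidates, a Cartesian product filtered by the 1024-thread cap, each shape timed twice, best kept. The candidate enumeration in isolation, without the device timing:

```python
# Candidate enumeration only: mirrors the shape search above for a hypothetical grid.
import itertools
from math import prod

MAX_WORKGROUP = 1024
global_size = [512, 7, 1]
local_dims = [[x for x in {sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP} if x <= sz]
              for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP]
assert all(prod(ls) <= MAX_WORKGROUP for ls in local_sizes)
```

Running each candidate twice and taking the `min` is a cheap defense against one-off timing noise; failed launches score `inf` and drop out naturally.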
 class CompiledRunner(Runner):
   def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
     if DEBUG >= 4: print(p.src)
     self.p:ProgramSpec = p
     if precompiled is not None: self.lib = precompiled
     else:
-      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,), cat="compiler"), "TINY"):
+      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
         self.lib = Device[p.device].compiler.compile_cached(p.src)
     if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
     self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
@@ -73,11 +90,10 @@ class CompiledRunner(Runner):
   def __reduce__(self): return self.__class__, (self.p, self.lib)

-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
+    has_local = Device[self.p.device].renderer.has_local
     global_size, local_size = self.p.launch_dims(var_vals)
-    if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
-      # TODO: this is copied from get_program
-      from tinygrad.codegen.opt.search import optimize_local_size
+    if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
@@ -88,11 +104,11 @@ class CompiledRunner(Runner):
     if local_size:
       lra['local_size'] = tuple(local_size)
       assert len(local_size) == 3, "local size must have len 3"
-    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
+    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)

 class ViewOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
     assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"

 class BufferCopy(Runner):
@@ -110,7 +126,7 @@ class BufferCopy(Runner):
       src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
     else:
       dest.copyin(src.as_buffer(allow_zero_copy=True))  # may allocate a CPU buffer depending on allow_zero_copy
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
     dest, src = rawbufs[0:2]
     assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
     st = time.perf_counter()
@@ -124,13 +140,13 @@ class BufferXfer(BufferCopy):

 # **************** method cache ****************

-method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
+method_cache: dict[tuple[str, type, bytes, tuple[int, ...], bool], CompiledRunner] = {}
 def get_runner(device:str, ast:UOp) -> CompiledRunner:
   # TODO: this should be all context relevant to rendering
   context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
-  ckey = (device, ast.key, context, False)
+  ckey = (device, type(Device[device].compiler), ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (device.split(":")[0], ast.key, context, True)
+  bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
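`get_runner`'s cache keys now include the compiler class, so one device name backed by two toolchains (relevant now that the CPU backend can be Clang or LLVM, per the renderer changes later in this diff) cannot serve stale binaries. A hypothetical mini-cache showing the failure the extra key component prevents:

```python
# Hypothetical stand-ins; only the key shape matters here.
method_cache: dict[tuple, bytes] = {}

class ClangCompiler: pass
class LLVMCompiler: pass

def put(device: str, compiler: type, ast_key: bytes, lib: bytes) -> None:
    method_cache[(device, compiler, ast_key, False)] = lib

def get(device: str, compiler: type, ast_key: bytes) -> bytes | None:
    return method_cache.get((device, compiler, ast_key, False))

put("CPU", ClangCompiler, b"k1", b"clang-binary")
assert get("CPU", LLVMCompiler, b"k1") is None          # same device+AST, different compiler: miss
assert get("CPU", ClangCompiler, b"k1") == b"clang-binary"
```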
@@ -145,12 +161,11 @@ class ExecItem:
   prg: Runner
   bufs: list[Buffer|None]
   metadata: tuple[Metadata, ...]|None = None
-  fixedvars: dict[Variable, int] = field(default_factory=dict)
-  def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
+  fixedvars: dict[str, int] = field(default_factory=dict)
+  def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
     var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
-    if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", decimal.Decimal(time.perf_counter_ns())/1000, self.prg.display_name,
-                                                    {"metadata":self.metadata, "var_vals":var_vals}))
+    if PROFILE: cpu_events.append(ProfilePointEvent(self.prg.device, "exec", self.prg.display_name, {"metadata":self.metadata, "var_vals":var_vals}))
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
@@ -160,10 +175,15 @@ class ExecItem:
       if DEBUG >= 2:
         lds_est = sym_infer(self.prg.estimates.lds, var_vals)
         mem_est = min(mem_est, lds_est)   # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
+        header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
         ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
-        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
-              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" +  # noqa: E501
-              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
+        flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
+        flops_str = f"{flops*1e-9:9.2f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:9.2f} TFLOPS", 'green')
+        mem_str = f"{membw*1e-9:6.1f}|{ldsbw*1e-9:<7.1f} GB/s" if membw < 1e13 else colored(f"{membw*1e-12:6.1f}|{ldsbw*1e-12:<7.1f} TB/s", 'green')
+        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
+              f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB"+
+              ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
+              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")
     self.prg.first_run = False
     return et
@@ -193,7 +213,7 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem,

 capturing: list = []  # put classes with an add method in here

-def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=None, do_update_stats=True):
+def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
   for si, ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
     if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
@@ -206,9 +226,8 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=
       ei.run(var_vals, do_update_stats=do_update_stats)

       # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
-      lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
+      with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
       import numpy as np
       np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
     else: ei.run(var_vals, do_update_stats=do_update_stats)
-
diff --git a/tinygrad_repo/tinygrad/engine/schedule.py b/tinygrad_repo/tinygrad/engine/schedule.py
index 1e6b5259..e56c9083 100644
--- a/tinygrad_repo/tinygrad/engine/schedule.py
+++ b/tinygrad_repo/tinygrad/engine/schedule.py
@@ -1,7 +1,7 @@
 from typing import cast
 from dataclasses import dataclass, field
 from collections import deque, defaultdict
-from tinygrad.uop.ops import UOp, Variable, Ops, buffers
+from tinygrad.uop.ops import UOp, Ops, buffers
 from tinygrad.device import Device, Buffer, MultiBuffer
 from tinygrad.helpers import Metadata, all_same
@@ -12,15 +12,15 @@ class ScheduleItem:
   ast: UOp
   bufs: tuple[Buffer, ...]
   metadata: tuple[Metadata, ...] = ()
-  fixedvars: dict[Variable, int] = field(default_factory=dict)
+  fixedvars: dict[str, int] = field(default_factory=dict)

 # **** schedule linearizer

-def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
   # construct the KERNEL children graph based on assigns
   children: defaultdict[UOp, list[UOp]] = defaultdict(list)
   in_degree: dict[UOp, int] = {}
-  var_vals: dict[Variable, int] = {}
+  var_vals: dict[str, int] = {}
   for u in sched_sink.toposort():
     if u.op is not Ops.ASSIGN: continue  # anything that's not an ASSIGN doesn't write a kernel, so we can skip
     k = u.src[1]
@@ -40,8 +40,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
         pass  # a BUFFER is already realized, nothing to do here
       elif s.op is Ops.BIND:
         var, val = s.unbind()
-        assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
-        var_vals[var] = val
+        assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
+        var_vals[var.expr] = val
       else: raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
@@ -72,7 +72,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
       dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
       for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
+        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
     else:  # ONE -> ONE
       schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
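`create_schedule_with_vars` is a Kahn's-algorithm walk over the `children`/`in_degree` graph built above: a kernel becomes ready once all kernels it depends on have been emitted. The same idea on a toy DAG of strings:

```python
# Generic Kahn's algorithm sketch of the linearizer's walk (strings stand in for
# kernel UOps; "a" feeds "b" and "c", which both feed "d").
from collections import defaultdict, deque

children = defaultdict(list, {"a": ["b", "c"], "b": ["d"], "c": ["d"]})
in_degree = {"a": 0, "b": 1, "c": 1, "d": 2}

queue = deque(k for k, d in in_degree.items() if d == 0)
schedule = []
while queue:
    k = queue.popleft()
    schedule.append(k)
    for child in children[k]:
        in_degree[child] -= 1
        if in_degree[child] == 0: queue.append(child)

assert schedule[0] == "a" and schedule[-1] == "d"
```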
diff --git a/tinygrad_repo/tinygrad/frontend/onnx.py b/tinygrad_repo/tinygrad/frontend/onnx.py
index 25b377d9..33d54086 100644
--- a/tinygrad_repo/tinygrad/frontend/onnx.py
+++ b/tinygrad_repo/tinygrad/frontend/onnx.py
@@ -21,9 +21,9 @@ class AttributeType(enum.IntEnum):
   ONNX attribute type identifiers.
   Reference: https://github.com/onnx/onnx/blob/rel-1.18.0/onnx/onnx.proto3#L128-L145
   """
-  FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; FLOATS = 6; INTS = 7; STRINGS = 8 # noqa: E702
+  FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; GRAPH = 5; FLOATS = 6; INTS = 7; STRINGS = 8 # noqa: E702

-  def to_field_name(self) -> str: return {1: "f", 2: "i", 3: "s", 4: "t", 6: "floats", 7: "ints", 8: "strings"}[self.value]
+  def to_field_name(self) -> str: return {1: "f", 2: "i", 3: "s", 4: "t", 5: "g", 6: "floats", 7: "ints", 8: "strings"}[self.value]

 class OnnxDataType(enum.IntEnum):
   """
@@ -266,6 +266,7 @@ class OnnxPBParser:
         case 3: obj["i"] = self.reader.read_int64()
         case 4: obj["s"] = self.reader.read_bytes().data().tobytes().decode("utf8")
         case 5: obj["t"] = self._parse_TensorProto()['parsed_tensor']
+        case 6: obj["g"] = OnnxRunner._from_subgraph(self._parse_GraphProto())
         case 7: obj["floats"].append(self.reader.read_float())
         case 8: obj["ints"].append(self.reader.read_int64())
         case 9: obj["strings"].append(self.reader.read_bytes().data().tobytes().decode("utf8"))
@@ -401,8 +402,11 @@ class OnnxRunner:
   """
   def __init__(self, model_path: Tensor | str | pathlib.Path):
     model = OnnxPBParser(model_path, load_external_data=True).parse()
-    graph = model["graph"]
+    self._init_from_graph(model["graph"])
+
+  def _init_from_graph(self, graph: dict, is_subgraph: bool = False):
     self.is_training = any(n['parsed_node'].opset_id.domain in {Domain.AI_ONNX_TRAINING, Domain.AI_ONNX_PREVIEW_TRAINING} for n in graph["node"])
+    self.graph_name = graph["name"] if is_subgraph else ""
     self.graph_values = {"": None, **{i["name"]: i["parsed_tensor"] for i in graph["initializer"]}}
     self.graph_inputs = {i["name"]: i["parsed_type"] for i in graph["input"] if i["name"] not in self.graph_values}
     self.graph_outputs = tuple(o["name"] for o in graph["output"])
@@ -414,6 +418,12 @@ class OnnxRunner:
     self.variable_dims: dict[str, int] = {}
     self.onnx_ops = onnx_ops

+  @classmethod
+  def _from_subgraph(cls, graph: dict) -> "OnnxRunner":
+    subgraph = cls.__new__(cls)
+    subgraph._init_from_graph(graph, is_subgraph=True)
+    return subgraph
+
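`_from_subgraph` builds a runner without re-running `__init__`, since a nested GRAPH attribute has no model file to parse. The `cls.__new__(cls)` pattern with a hypothetical class standing in for `OnnxRunner`:

```python
# Sketch of the subgraph-constructor pattern: __new__ allocates the instance,
# then the shared graph-init path populates it, skipping file parsing entirely.
class Runner:
    def __init__(self, path: str):                     # normal path: parse a file
        self._init_from_graph({"name": ""}, is_subgraph=False)

    def _init_from_graph(self, graph: dict, is_subgraph: bool):
        self.graph_name = graph["name"] if is_subgraph else ""

    @classmethod
    def from_subgraph(cls, graph: dict) -> "Runner":
        sub = cls.__new__(cls)                          # bypass __init__
        sub._init_from_graph(graph, is_subgraph=True)
        return sub

sub = Runner.from_subgraph({"name": "then_branch"})
assert sub.graph_name == "then_branch"
```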
   def _parse_input(self, name: str, value: Any, spec: OnnxValue):
     if spec.is_optional and value is None: return None
     if spec.is_sequence:
@@ -445,9 +455,10 @@ class OnnxRunner:
     return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()}

   def to(self, device:str|None):
-    self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()}
+    self.graph_values = {k: (v.to(device) if isinstance(v, Tensor) else v) for k,v in self.graph_values.items()}
     self.graph_nodes = tuple(OnnxNode(n.op, n.opset_id, tuple(n.inputs), tuple(n.outputs),
-                                      {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes)
+                                      {k: (v.to(device) if isinstance(v, (Tensor, OnnxRunner)) else v) for k,v in n.opts.items()})
+                             for n in self.graph_nodes)
     return self

   def __call__(self, inputs:dict[str, Any], debug=debug):
@@ -461,9 +472,9 @@ class OnnxRunner:
       # provide additional opts
       if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
-      if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values
+      if node.op in {"Gradient", "If"}: opts['intermediate_tensors'] = self.graph_values

-      if debug >= 1: print(f"{num}: op '{node.op}' opt {opts}")
+      if debug >= 1: print((f"[{self.graph_name}] " if self.graph_name else "") + f"{num}: op '{node.op}' opt {opts}")
       if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
       ret = self._select_op(node.op, node.opset_id)(*inps, **opts)
       ret = ret if isinstance(ret, tuple) else (ret,)
@@ -543,6 +554,23 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
     return __decorator

   # ***** Property/Graph Ops *****
+  def If(condition:Tensor, else_branch:OnnxRunner, then_branch:OnnxRunner, intermediate_tensors:dict[str, Tensor]):
+    def run_branch(branch:OnnxRunner):
+      branch.graph_values.update(intermediate_tensors)
+      out = branch({k:intermediate_tensors[k] for k in branch.graph_inputs.keys()})
+      # dereference intermediate tensors so Buffer can be deallocated
+      for k in intermediate_tensors: del branch.graph_values[k]
+      return out
+    # both branch must be ran before the condition can be evaluated
+    else_out, then_out = run_branch(else_branch), run_branch(then_branch)
+    assert len(else_out) == len(then_out), f"else_out and then_out must have the same number of outputs: {len(else_out)} != {len(then_out)}"
+    # can use where op when output shape is the same
+    if all(t.shape == e.shape for t,e in zip(then_out.values(), else_out.values())):
+      return tuple(condition.where(t,e) for t,e in zip(then_out.values(), else_out.values()))
+    # otherwise, use condition to select the output in python
+    cond = _resolve_const(_cached_to_python_const(condition))
+    return tuple(t if cond else e for t,e in zip(then_out.values(), else_out.values()))
+
   def Identity(x:Tensor): return x
   def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
                value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
diff --git a/tinygrad_repo/tinygrad/gradient.py b/tinygrad_repo/tinygrad/gradient.py
index 77979148..c538555a 100644
--- a/tinygrad_repo/tinygrad/gradient.py
+++ b/tinygrad_repo/tinygrad/gradient.py
@@ -22,11 +22,10 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
-  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
-    (ctx*(ret.src[0].eq(0) & ret.src[1].eq(0)).where(ret.src[1], ret.src[1]*ret.src[0].pow(ret.src[1]-1)),
-     ctx*ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ret*ret.src[0].log2()*math.log(2.0)))),
-  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
-                                                (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
+  (UPat(Ops.POW, name="ret", src=(UPat.var("x"), UPat.var("y"))), lambda ctx, ret, x, y: (ctx*(x.eq(0) & y.eq(0)).where(y, y*x.pow(y-1)),
+    ctx*x.eq(0).where((y<0).where(ret.const_like(-math.inf), ret.const_like(0)), ret*x.log2()*math.log(2.0)))),
+  (UPat(Ops.MAX, src=(UPat.var("x"), UPat.var("y"))), lambda ctx, x, y: ((x>y).where(ctx, (x.eq(y)).where(ctx * 0.5, 0)),
+    (x<y).where(ctx, (x.eq(y)).where(ctx * 0.5, 0)))),
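The rewritten MAX gradient keeps the tie-splitting rule from the removed lines: the upstream gradient flows to the strictly larger operand, and exact ties get half each. The same rule in plain floats, outside the UOp pattern machinery:

```python
# Sketch of the MAX gradient rule: d/dx max(x, y) and d/dy max(x, y) for scalars.
def max_grads(ctx: float, x: float, y: float) -> tuple[float, float]:
    gx = ctx if x > y else (ctx * 0.5 if x == y else 0.0)
    gy = ctx if y > x else (ctx * 0.5 if x == y else 0.0)
    return gx, gy

assert max_grads(1.0, 2.0, 1.0) == (1.0, 0.0)   # clear winner takes it all
assert max_grads(1.0, 3.0, 3.0) == (0.5, 0.5)   # tie splits the gradient
```

Splitting ties keeps the sum of the two partials equal to the incoming gradient, so `max(x, x)` backpropagates like the identity.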
diff --git a/tinygrad_repo/tinygrad/helpers.py b/tinygrad_repo/tinygrad/helpers.py
--- a/tinygrad_repo/tinygrad/helpers.py
+++ b/tinygrad_repo/tinygrad/helpers.py
 def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
 # NOTE: helpers is not allowed to import from anything else in tinygrad
-OSX = platform.system() == "Darwin"
+OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
 CI = os.getenv("CI", "") != ""
+ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))

 # fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
-if sys.platform == "win32": os.system("")
+if WIN: os.system("")

 def dedup(x:Iterable[T]): return list(dict.fromkeys(x))  # retains list order
 def argfix(*x):
@@ -56,7 +57,7 @@ def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
 def is_numpy_ndarray(x) -> bool: return str(type(x)) == "<class 'numpy.ndarray'>"
 def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
-  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
+  if len(kvs) != len(set(kv[0] for kv in kvs)): raise RuntimeError(f"{kvs} contains different values for the same key")
   return {k:v for d in ds for k,v in d.items()}
 def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
   ret:tuple[list[T], list[T]] = ([], [])
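`merge_dicts` now raises `RuntimeError` instead of asserting, so conflicting keys are caught even when Python runs with assertions disabled (`python -O`). A standalone copy of the check:

```python
from typing import Iterable

def merge_dicts(ds: Iterable[dict]) -> dict:
    # a (key, value) pair set has more entries than the key set iff some key
    # appears with two different values
    kvs = set((k, v) for d in ds for k, v in d.items())
    if len(kvs) != len(set(k for k, _ in kvs)):
        raise RuntimeError(f"{kvs} contains different values for the same key")
    return {k: v for d in ds for k, v in d.items()}

assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}
try:
    merge_dicts([{"a": 1}, {"a": 2}])
except RuntimeError:
    pass  # conflicting values for "a" now raise unconditionally
```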
ContextVar("VIZ", 0) @dataclass(frozen=True) class Metadata: @@ -192,12 +198,12 @@ class Profiling(contextlib.ContextDecorator): colored(_format_fcn(fcn).ljust(50), "yellow"), colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '') +def perf_counter_us() -> decimal.Decimal: return decimal.Decimal(time.perf_counter_ns())/1000 @dataclass(frozen=True) class TracingKey: display_name:str # display name of this trace event - keys:tuple[str, ...]=() # optional keys to search for related traces - cat:str|None=None # optional category to color this by + keys:tuple[Any, ...]=() # optional keys to search for related traces ret:Any=None class ProfileEvent: pass @@ -206,17 +212,21 @@ class ProfileEvent: pass class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702 @dataclass(frozen=True) -class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:Any; arg:dict=field(default_factory=dict) # noqa: E702 +class ProfilePointEvent(ProfileEvent): device:str; name:str; key:Any; arg:dict=field(default_factory=dict); \ + ts:decimal.Decimal=field(default_factory=perf_counter_us) # noqa: E702 cpu_events:list[ProfileEvent] = [] @contextlib.contextmanager def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]: - res = ProfileRangeEvent(device, name, decimal.Decimal(time.perf_counter_ns()) / 1000, is_copy=is_copy) + res = ProfileRangeEvent(device, name, perf_counter_us(), is_copy=is_copy) try: yield res finally: - res.en = decimal.Decimal(time.perf_counter_ns()) / 1000 + res.en = perf_counter_us() if PROFILE and display: cpu_events.append(res) +def profile_marker(name:str, color="gray") -> None: + cpu_events.append(ProfilePointEvent("TINY", "marker", None, {"name":name, "color":color})) + # *** universal database cache *** cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad") @@ -315,7 +325,10 @@ def cpu_objdump(lib, objdump_tool='objdump'): print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8')) def capstone_flatdump(lib: bytes): - import capstone + try: import capstone + except ImportError: + print("Disassembler Error: Capstone not installed.") + return match platform.machine(): case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64) case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM) diff --git a/tinygrad_repo/tinygrad/nn/__init__.py b/tinygrad_repo/tinygrad/nn/__init__.py index bf6dca75..d32a3d5e 100644 --- a/tinygrad_repo/tinygrad/nn/__init__.py +++ b/tinygrad_repo/tinygrad/nn/__init__.py @@ -320,6 +320,7 @@ class Embedding: def __call__(self, idx:Tensor) -> Tensor: if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1) + if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}") big_shp = idx.shape+(self.vocab_sz, self.embed_sz) arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp) return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype) diff --git a/tinygrad_repo/tinygrad/nn/state.py b/tinygrad_repo/tinygrad/nn/state.py index aa8c1b1a..110da5ec 100644 --- a/tinygrad_repo/tinygrad/nn/state.py +++ 
diff --git a/tinygrad_repo/tinygrad/nn/state.py b/tinygrad_repo/tinygrad/nn/state.py
index aa8c1b1a..110da5ec 100644
--- a/tinygrad_repo/tinygrad/nn/state.py
+++ b/tinygrad_repo/tinygrad/nn/state.py
@@ -274,9 +274,9 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   Converts ggml tensor data to a tinygrad tensor.

   Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
-  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
+  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14), MXFP4 (id: 39)
   """
-  # https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
+  # https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356

   # native types
   if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
@@ -288,7 +288,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
     return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)

   # map to (number of elements, number of bytes)
-  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
+  if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34), 39: (32, 17) }.get(ggml_type)) is not None:
     blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
     if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
     if ggml_type == 3:
@@ -300,6 +300,17 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
       scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
       d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
       return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
+    if ggml_type == 39:
+      e_int = blocks[:, 0].cast(dtypes.int32)
+      d = ((e_int >= 2).cast(dtypes.float32) * (e_int.cast(dtypes.float32) - 128).exp2() +
+           (e_int == 1).cast(dtypes.float32) * 2.0**(-127) +
+           (e_int == 0).cast(dtypes.float32) * 2.0**(-128)).unsqueeze(-1)
+      codes = q_to_uint8(blocks[:, 1:17], 4)
+      sign = 1.0 - codes.rshift(3).cast(dtypes.float32) * 2.0
+      exp, mant = codes.rshift(1).bitwise_and(0x3).cast(dtypes.float32), codes.bitwise_and(0x1).cast(dtypes.float32)
+      fp4_val = sign * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
+                        (exp == 0).cast(dtypes.float32) * 0.5 * mant)
+      return (fp4_val * d).flatten(-2)[:n]
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")

 @accept_filename
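The MXFP4 branch decodes 17-byte blocks of 32 elements: one shared power-of-two scale byte, then sixteen bytes holding 4-bit e2m1 codes (1 sign, 2 exponent, 1 mantissa bit). The 8-entry magnitude table those tensor ops reproduce, in plain Python:

```python
# Pure-python sketch of the FP4 (e2m1) code table the ggml_type == 39 branch decodes.
def fp4_to_float(code: int) -> float:
    sign = -1.0 if code & 0x8 else 1.0
    exp, mant = (code >> 1) & 0x3, code & 0x1
    if exp == 0: return sign * 0.5 * mant              # subnormal: 0 or +-0.5
    return sign * (1.0 + 0.5 * mant) * 2.0 ** (exp - 1)

assert [fp4_to_float(c) for c in range(8)] == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# each 32-element block is then multiplied by the shared scale from its first byte
```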
diff --git a/tinygrad_repo/tinygrad/renderer/__init__.py b/tinygrad_repo/tinygrad/renderer/__init__.py
index ed53f039..068e8afb 100644
--- a/tinygrad_repo/tinygrad/renderer/__init__.py
+++ b/tinygrad_repo/tinygrad/renderer/__init__.py
@@ -39,14 +39,14 @@ class Estimates:
         buf = u
         while len(buf.src): buf = buf.src[0]
         if buf.op is Ops.DEFINE_GLOBAL:  # assume all DEFINE_GLOBAL memory is accessed
-          mem[(buf, u.op)] = cast(PtrDType, buf.dtype).size * buf.dtype.itemsize
+          mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
       if u.op is Ops.RANGE:
         mult_stack.append(mults)
         mults *= cast(sint, u.src[0].ssimplify())
         # SPECIAL are already counted in mults
         mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
       elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
-      elif u.op is Ops.SPECIAL: mults *= u.arg[1]  # NOTE: we don't push to the mult_stack here, you can't end these
+      elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify())  # NOTE: we don't push to the mult_stack here, you can't end these
       elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): lds += u.dtype.itemsize * mults
       elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
@@ -82,9 +82,9 @@ class ProgramSpec:
       if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
       if u.op is Ops.SPECIAL:
         # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
-        if u.arg[0][0] == 'i': self.local_size = None
-        special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
-        if special_size is not None: special_size[int(u.arg[0][-1])] = u.arg[1]
+        if u.arg[0] == 'i': self.local_size = None
+        special_size = self.local_size if u.arg[0] == 'l' else self.global_size
+        if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
     self.vars = sorted(self.vars, key=lambda v: v.arg)
     self.outs = sorted(dedup(self.outs))
     self.ins = sorted(dedup(self.ins))
@@ -101,7 +101,7 @@ class ProgramSpec:
   def applied_opts(self) -> tuple[Opt, ...]|None: return self.uops[-1].arg.applied_opts if \
     self.uops is not None and self.uops[-1].op is Ops.SINK and self.uops[-1].arg is not None else None

-  def launch_dims(self, var_vals:dict[Variable, int]):
+  def launch_dims(self, var_vals:dict[str, int]):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size
@@ -112,6 +112,7 @@ class Renderer:
   # TODO: make this generic with a list of supported types
   supports_float4: bool = True
   has_local: bool = True
+  has_threads: bool = False
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3)  # TODO: Ops.SPECIAL int32 indexes right now
diff --git a/tinygrad_repo/tinygrad/renderer/cstyle.py b/tinygrad_repo/tinygrad/renderer/cstyle.py
index c1993544..09d25c08 100644
--- a/tinygrad_repo/tinygrad/renderer/cstyle.py
+++ b/tinygrad_repo/tinygrad/renderer/cstyle.py
@@ -2,8 +2,8 @@ from typing import Literal, Callable, cast
 import os, math, sys
 from collections import defaultdict, Counter
 from tinygrad.codegen.opt import tc
-from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
-from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
+from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, sint_to_uop
+from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
 from tinygrad.renderer import Renderer
 from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -26,7 +26,7 @@ base_rewrite = PatternMatcher([
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
  (UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
-  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
  # const
  (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
  (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"), # const (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"), (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"), @@ -111,7 +111,8 @@ class CStyleLanguage(Renderer): tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] - launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l") + local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"] + launch_bounds = sint_to_uop(prod(local_dims)).vmax prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) @@ -145,7 +146,7 @@ class CStyleLanguage(Renderer): if u.arg is not None: name = u.arg.function_name continue if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): - r[u] = (f"data{u.arg}_{sz}" if (sz:=cast(PtrDType, u.dtype).size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0] + r[u] = (f"data{u.arg}_{sz}" if (sz:=u.ptrdtype.size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0] bufs[u] = (r[u], (u.dtype, False)) continue @@ -156,8 +157,8 @@ class CStyleLanguage(Renderer): # naming prefix = None - if u.op is Ops.SPECIAL: r[u] = u.arg[0] - elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg[0]}" if u.arg[0] >= 0 else f"ridxm{-u.arg[0]}" + if u.op is Ops.SPECIAL: r[u] = u.arg + elif u.op is Ops.RANGE: r[u] = "ridx"+'_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]]) else: prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast", @@ -169,7 +170,7 @@ class CStyleLanguage(Renderer): if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1 if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ - (u.op is Ops.LOAD and cast(PtrDType, u.src[0].dtype).addrspace == AddrSpace.REG) or \ + (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \ (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): r[u] = l @@ -191,9 +192,12 @@ class ClangRenderer(CStyleLanguage): float4_style = ('{', '}') gep_arr_threshold = 0 has_local = False - global_max = None + has_threads = bool(getenv("THREADS", 1)) + global_max = (CPU_COUNT.value, 0, 0) infinity = "__builtin_inff()" nan = '__builtin_nanf("")' + code_for_workitem = {"g": lambda _: "core_id"} + extra_args = ['int core_id'] if AMX: tensor_cores = tc.amx # language options @@ -238,7 +242,7 @@ class ClangRenderer(CStyleLanguage): return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs) class OpenCLRenderer(CStyleLanguage): - device = "GPU" + device = "CL" # language options kernel_typedef = "__kernel void" @@ -267,7 +271,7 @@ class 
@@ -238,7 +242,7 @@ class ClangRenderer(CStyleLanguage):
     return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)

 class OpenCLRenderer(CStyleLanguage):
-  device = "GPU"
+  device = "CL"

   # language options
   kernel_typedef = "__kernel void"
@@ -267,7 +271,7 @@ class OpenCLRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 class IntelRenderer(OpenCLRenderer):
-  device, suffix, kernel_typedef = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
+  device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
   tensor_cores = tc.intel

   string_rewrite = PatternMatcher([
diff --git a/tinygrad_repo/tinygrad/renderer/llvmir.py b/tinygrad_repo/tinygrad/renderer/llvmir.py
index a7580871..f19f4dc2 100644
--- a/tinygrad_repo/tinygrad/renderer/llvmir.py
+++ b/tinygrad_repo/tinygrad/renderer/llvmir.py
@@ -3,7 +3,8 @@ import math, struct, sys
 from tinygrad.codegen.opt import tc
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import AMDRenderer
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.uop.decompositions import xexp2, xlog2
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, sint_to_uop
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
 from tinygrad.helpers import prod, AMX
@@ -117,7 +118,7 @@ base_rewrite = PatternMatcher([
 ])

 class LLVMRenderer(Renderer):
-  device = "LLVM"
+  device = "CPU"
   abi = 'win64cc' if sys.platform == 'win32' else None
   supports_float4 = True
   has_local = False
@@ -173,7 +174,7 @@ class LLVMRenderer(Renderer):
       elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
         r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
         assert isinstance(u.dtype, PtrDType)
-        if self.device == "LLVM" or u.op is Ops.DEFINE_REG:
+        if self.device == "CPU" or u.op is Ops.DEFINE_REG:
           kernel.append(f"  {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
         else: local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
@@ -196,14 +197,19 @@ class LLVMRenderer(Renderer):
 barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
 code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
                      "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
+# https://rocm.docs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPUUsage.html#llvm-ir-intrinsics
+llvm_intrinsics = {Ops.SQRT: "sqrt", Ops.LOG2: "log2", Ops.EXP2: "exp2"}
 class AMDLLVMRenderer(LLVMRenderer):
   device = "AMD"
   has_local = True
   shared_max = AMDRenderer.shared_max
   global_max = AMDRenderer.global_max
   abi = "amdgpu_kernel"
+  code_for_op = {**LLVMRenderer.code_for_op, **{op: lambda: None for op in llvm_intrinsics}}
   string_rewrite = PatternMatcher([
-    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f"  {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
+    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f"  {ctx[x]} = " + f"{ code_for_workitem[x.arg[0]](x.arg[-1])}; "),
+    (UPat(tuple(llvm_intrinsics), name="x"),
+     lambda ctx, x: f"  {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
     (UPat(Ops.BARRIER), lambda ctx: barrier),
   ]) + base_rewrite
   extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
@@ -211,10 +217,14 @@ class AMDLLVMRenderer(LLVMRenderer):
     (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
      lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
    (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
      lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
+    # amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
+    (UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
+    (UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
   ])
   def _render_footer(self, uops: list[UOp]) -> str:
     # TODO: this is copied from cstyle
-    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
+    requiredMaxThreadsPerBlock = sint_to_uop(prod(local_dims)).vmax
     attributes = ["alwaysinline", "nounwind", '"no-builtins"', f'"amdgpu-flat-work-group-size"="1,{requiredMaxThreadsPerBlock}"', '"no-trapping-math"="true"']
     return 'attributes #0 = { ' + ' '.join(attributes) + ' }'
diff --git a/tinygrad_repo/tinygrad/renderer/ptx.py b/tinygrad_repo/tinygrad/renderer/ptx.py
index 926bcc3f..cc764d1b 100644
--- a/tinygrad_repo/tinygrad/renderer/ptx.py
+++ b/tinygrad_repo/tinygrad/renderer/ptx.py
@@ -2,7 +2,7 @@ from typing import cast, Callable
 import struct
 from collections import defaultdict
 from tinygrad.codegen.opt import tc
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, sint_to_uop
 from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
@@ -91,7 +91,7 @@ string_rewrite = PatternMatcher([
  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
    f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
    f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
-  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg}, %{'ctaid' if x.arg[0] == 'g' else 'tid'}.{chr(120+int(x.arg[-1]))};"),
  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
   lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
@@ -119,7 +119,7 @@ string_rewrite = PatternMatcher([
    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
  (UPat(Ops.DEFINE_LOCAL, name="x"),
-   lambda ctx, x: [f".shared .align 16 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
+   lambda ctx, x: [f".shared .align 16 .b8 local{x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, local{x.arg}[0];"]),
  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
"\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1) kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"])) - launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l") + local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"] + launch_bounds = sint_to_uop(prod(local_dims)).vmax params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) return f"{self.kernel_prefix.format(launch_bounds=launch_bounds)} {function_name} (\n\t{params}\n)\n.maxntid {launch_bounds}\n{{\n{kernel}\n}}" @@ -190,7 +191,7 @@ class PTXRenderer(Renderer): r[u] = r[u.src[0]] continue if u.op is Ops.DEFINE_REG: - r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(cast(PtrDType, u.dtype).size)] + r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(u.ptrdtype.size)] continue if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG: if u.op is Ops.INDEX: @@ -202,7 +203,7 @@ class PTXRenderer(Renderer): typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:]) kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};") continue - if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0] + if u.op is Ops.SPECIAL: r[u] = "%" + u.arg elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype)) elif u.op is Ops.LOAD: assert u.src[0].dtype == dtypes.int64, "load isn't int64" @@ -215,7 +216,7 @@ class PTXRenderer(Renderer): [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), - Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]), + Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]), Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None)) if prefix: r[u] = ssa(prefix, u, dtype) @@ -223,5 +224,5 @@ class PTXRenderer(Renderer): raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") kernel.extend([l] if isinstance(l, str) else l) - if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel + if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg};"] + kernel return self.render_kernel(kernel, name, bufs, c.items(), uops) diff --git a/tinygrad_repo/tinygrad/renderer/wgsl.py b/tinygrad_repo/tinygrad/renderer/wgsl.py index 4063a23c..49a952ba 100644 --- a/tinygrad_repo/tinygrad/renderer/wgsl.py +++ b/tinygrad_repo/tinygrad/renderer/wgsl.py @@ -84,7 +84,7 @@ class WGSLRenderer(CStyleLanguage): def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x def buf_map(self, dt:DType) -> str: return "atomic" if is_packed(dt) else self.type_map[dt.base] def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str: - local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])] + local_size = [u.src[0].ssimplify() for u in 
sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)] if not local_size: local_size = [1] bind_it = iter(range(len(bufs))) external_local_bufs = [line.lstrip() for line in kernel if "var" in line] diff --git a/tinygrad_repo/tinygrad/runtime/graph/cuda.py b/tinygrad_repo/tinygrad/runtime/graph/cuda.py index 566b3005..056b0b67 100644 --- a/tinygrad_repo/tinygrad/runtime/graph/cuda.py +++ b/tinygrad_repo/tinygrad/runtime/graph/cuda.py @@ -4,12 +4,11 @@ import tinygrad.runtime.autogen.cuda as cuda from tinygrad.helpers import init_c_var, dedup from tinygrad.device import Buffer, Device from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution -from tinygrad.uop.ops import Variable from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner from tinygrad.engine.jit import MultiGraphRunner, GraphException class CUDAGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]): + def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) # Check all jit items are compatible. @@ -28,7 +27,7 @@ class CUDAGraph(MultiGraphRunner): deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None - c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars]) + c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars]) kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs) check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params))) @@ -48,7 +47,7 @@ class CUDAGraph(MultiGraphRunner): self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None: + def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: # Update rawbuffers in the c_args struct. 
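# The graph runners in this diff now key var_vals by the variable's string name (x.expr)
# instead of by the Variable UOp itself, which is why the dict types change to dict[str, int].
# A minimal sketch of that lookup pattern; `Var` is a hypothetical stand-in for tinygrad's
# Variable and everything below is illustrative, not the real API.
class Var:
  def __init__(self, expr: str): self.expr = expr

start_pos = Var("start_pos")                      # hypothetical JIT variable
var_vals: dict[str, int] = {start_pos.expr: 42}   # keyed by name, not by object identity
fixedvars: dict[str, int] = {}
value = var_vals.get(start_pos.expr, fixedvars.get(start_pos.expr))  # mirrors the hunk above
assert value == 42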
for (j,i),input_idx in self.input_replace.items(): if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf) diff --git a/tinygrad_repo/tinygrad/runtime/graph/hcq.py b/tinygrad_repo/tinygrad/runtime/graph/hcq.py index d41c56fe..a025cf46 100644 --- a/tinygrad_repo/tinygrad/runtime/graph/hcq.py +++ b/tinygrad_repo/tinygrad/runtime/graph/hcq.py @@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer from tinygrad.engine.jit import MultiGraphRunner class HCQGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]): + def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) @@ -69,7 +69,7 @@ class HCQGraph(MultiGraphRunner): for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev) self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set) - self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {} + self.fixedvars: dict[HCQCompiled, dict[str, int]] = {} for j,ji in enumerate(jit_cache): if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev @@ -183,7 +183,7 @@ class HCQGraph(MultiGraphRunner): self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices} self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals] - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None: + def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: # Wait and restore signals self.kickoff_value += 1 for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) @@ -195,12 +195,13 @@ class HCQGraph(MultiGraphRunner): if PROFILE and self.kickoff_value > 1: self.collect_timestamps() - hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals, - **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()}, - **{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}} + hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals, + **{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()}, + **{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}} # Update rawbuffers - for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr + for (j,i),input_idx in self.input_replace.items(): + hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr for dev in self.devices: self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {})) diff --git a/tinygrad_repo/tinygrad/runtime/graph/metal.py b/tinygrad_repo/tinygrad/runtime/graph/metal.py index 112f7005..b181f105 100644 --- a/tinygrad_repo/tinygrad/runtime/graph/metal.py +++ b/tinygrad_repo/tinygrad/runtime/graph/metal.py @@ -5,7 +5,6 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE from tinygrad.device 
import Buffer, ProfileGraphEntry, ProfileGraphEvent from tinygrad.engine.realize import ExecItem, CompiledRunner from tinygrad.engine.jit import GraphRunner, GraphException -from tinygrad.uop.ops import Variable from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\ MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str @@ -17,7 +16,7 @@ class MTLResourceUsage: MTLResourceUsageWrite = 0b10 class MetalGraph(GraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]): + def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException @@ -48,7 +47,8 @@ class MetalGraph(GraphRunner): if b is not None and b not in input_rawbuffers: msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i) all_resources.append(b._buf.buf) - for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i) + for i,v in enumerate(prg.p.vars): + msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i) global_size, local_size = prg.p.launch_dims(var_vals) msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size)) @@ -61,7 +61,7 @@ class MetalGraph(GraphRunner): for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var] self.range = to_struct(0, len(jit_cache)) - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None: + def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer) # NOTE: old command buffer may not be inflight anymore if self.command_buffer is not None and PROFILE: self.collect_timestamps() diff --git a/tinygrad_repo/tinygrad/runtime/graph/remote.py b/tinygrad_repo/tinygrad/runtime/graph/remote.py index a0d9e20e..8a9c5516 100644 --- a/tinygrad_repo/tinygrad/runtime/graph/remote.py +++ b/tinygrad_repo/tinygrad/runtime/graph/remote.py @@ -1,5 +1,4 @@ import time, itertools -from tinygrad.uop.ops import Variable from tinygrad.engine.jit import MultiGraphRunner from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem from tinygrad.device import Device, Compiled, Buffer @@ -18,7 +17,7 @@ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_ def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf) class RemoteGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]): + def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]): super().__init__(jit_cache, rawbufs, var_vals) devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache])) c2d = {device.conn: device for device in devices} @@ -93,7 +92,7 @@ class RemoteGraph(MultiGraphRunner): for req in self.template: match req: case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session)) - def __call__(self, rawbufs: 
list[Buffer], var_vals: dict[Variable, int], wait=False): + def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False): if wait: st = time.perf_counter() rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()} for req in self.template: diff --git a/tinygrad_repo/tinygrad/runtime/ops_amd.py b/tinygrad_repo/tinygrad/runtime/ops_amd.py index a862f359..b6b17767 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_amd.py +++ b/tinygrad_repo/tinygrad/runtime/ops_amd.py @@ -6,9 +6,8 @@ from dataclasses import dataclass from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQSignal, HCQProgram, FileIOInterface from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator from tinygrad.uop.ops import sint -from tinygrad.device import Compiled, DMAFdRef, BufferSpec -from tinygrad.helpers import getenv, to_mv, round_up, data64_le, all_same, flatten, DEBUG, AMD_LLVM, PROFILE, ProfileEvent, suppress_finalizing -from tinygrad.helpers import lo32, hi32 +from tinygrad.device import Compiled, DMAFdRef, BufferSpec, CompilerPairT +from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, suppress_finalizing, lo32, hi32 from tinygrad.renderer.cstyle import AMDRenderer from tinygrad.renderer.llvmir import AMDLLVMRenderer from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt @@ -109,17 +108,6 @@ class AMDComputeQueue(HWQueue): self.pkt3(self.pm4.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid) - def xcc_barrier(self): - if self.dev.xcc_sync is None: return self - assert self.dev.xccs == 8, 'only 8 XCCs supported' - a, b = self.dev.xcc_sync - mem_eq = self.pm4.WAIT_REG_MEM_FUNCTION(WAIT_REG_MEM_FUNCTION_EQ) | self.pm4.WAIT_REG_MEM_MEM_SPACE(1) - self.pkt3(self.pm4.PACKET3_ATOMIC_MEM, self.soc.TC_OP_ATOMIC_ADD_RTN_32, *data64_le(a.value_addr), *data64_le(1), *data64_le(0), 0x10) # a += 1 - self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, mem_eq, *data64_le(a.value_addr), 0, 0b111, 0x80) # a == 0 (mod 8) via bitmask - self.pkt3(self.pm4.PACKET3_ATOMIC_MEM, self.soc.TC_OP_ATOMIC_ADD_RTN_32, *data64_le(b.value_addr), *data64_le(1), *data64_le(0), 0x10) # b += 1 - self.pkt3(self.pm4.PACKET3_WAIT_REG_MEM, mem_eq, *data64_le(b.value_addr), 0, 0b111, 0x80) # b == 0 (mod 8) via bitmask - return self - def memory_barrier(self): pf = '' if self.nbio.version[0] == 2 else '0' if self.nbio.version[:2] != (7, 11) else '1' self.wait_reg_mem(reg_req=getattr(self.nbio, f'regBIF_BX_PF{pf}_GPU_HDP_FLUSH_REQ').addr[0], @@ -127,13 +115,6 @@ class AMDComputeQueue(HWQueue): self.acquire_mem() return self - def xcc_config(self): - self.wreg(self.gc.regCOMPUTE_TG_CHUNK_SIZE, 1) - for xcc_id in range(self.dev.xccs): - with self.pred_exec(xcc_mask=1 << xcc_id): - self.wreg(self.dev.regCOMPUTE_CURRENT_LOGIC_XCC_ID, xcc_id) - return self - def spi_config(self, tracing:bool): self.wreg(self.gc.regSPI_CONFIG_CNTL, ps_pkr_priority_cntl=3, exp_priority_order=3, gpr_write_priority=0x2c688, enable_sqg_bop_events=int(tracing), enable_sqg_top_events=int(tracing)) @@ -278,16 +259,10 @@ class AMDComputeQueue(HWQueue): if prg.dev.sqtt_enabled: self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.THREAD_TRACE_MARKER) | self.pm4.EVENT_INDEX(0)) self.pkt3(self.pm4.PACKET3_EVENT_WRITE, self.pm4.EVENT_TYPE(self.soc.CS_PARTIAL_FLUSH) | self.pm4.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH)) - - if self.dev.xccs > 1: - 
self.release_mem(cache_flush=True) - self.acquire_mem(gli=0) - self.xcc_barrier() return self def wait(self, signal:AMDSignal, value:sint=0): self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff) - if self.dev.xccs > 1 and not self.dev.is_aql: self.xcc_barrier() return self def timestamp(self, signal:AMDSignal): @@ -464,14 +439,14 @@ class AMDProgram(HCQProgram): # TODO; this API needs the type signature of the function and global_size/local_size self.dev, self.name, self.lib = dev, name, lib - image, sections, _ = elf_loader(self.lib) + image, sections, relocs = elf_loader(self.lib) rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1) - text_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".text"), -1) - assert rodata_entry >= 0 and text_entry >= 0, ".text or .rodata section not found" + assert rodata_entry >= 0, ".rodata section not found" - # Relo for kernel_code_entry_byte_offset for AMD_LLVM. Comgr doesn't need that, but keep shared code path. - image[rodata_entry+0x10:rodata_entry+0x10+8] = struct.pack(' 1)) if self.is_aql: self.pm4_ibs = self.iface.alloc(0x2000 if self.is_usb() else (16 << 20), uncached=True, cpu_access=True) self.pm4_ib_alloc = BumpAllocator(self.pm4_ibs.size, wrap=True) @@ -819,9 +788,11 @@ class AMDDevice(HCQCompiled): max_copy_size = 0x40000000 if self.iface.ip_versions[am.SDMA0_HWIP][0] >= 5 else 0x400000 self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x200 if self.is_usb() else (16 << 20)) - super().__init__(device, AMDAllocator(self), AMDLLVMRenderer(self.arch) if AMD_LLVM else AMDRenderer(self.arch), - AMDLLVMCompiler(self.arch) if AMD_LLVM else HIPCompiler(self.arch), functools.partial(AMDProgram, self), - AMDSignal, functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self), + compilers:list[CompilerPairT] = [(functools.partial(AMDLLVMRenderer, self.arch), functools.partial(AMDLLVMCompiler, self.arch)), + (functools.partial(AMDRenderer, self.arch), functools.partial(HIPCompiler, self.arch))] + + super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal, + functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self), functools.partial(AMDCopyQueue, self, max_copy_size=max_copy_size), kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000) @@ -829,13 +800,6 @@ class AMDDevice(HCQCompiled): self.max_private_segment_size = 0 self._ensure_has_local_memory(128) # set default scratch size to 128 bytes per thread - # XCC setup - self.xcc_sync: tuple[AMDSignal, AMDSignal]|None = None - if self.xccs > 1 and not self.is_aql: - self.xcc_sync_area = self.allocator.alloc(0x1000, BufferSpec(nolru=True, cpu_access=True)) - self.xcc_sync = (AMDSignal(base_buf=self.xcc_sync_area), AMDSignal(base_buf=self.xcc_sync_area.offset(256))) - cast(AMDComputeQueue, self.hw_compute_queue_t()).xcc_config().submit(self) - # SQTT is disabled by default because of runtime overhead and big file sizes (~200mb to Tensor.full() two 4096x4096 tensors and matmul them) self.sqtt_enabled = PROFILE and bool(getenv("SQTT", 0)) if self.sqtt_enabled: @@ -866,10 +830,9 @@ class AMDDevice(HCQCompiled): cwsr_buffer = self.iface.alloc(cwsr_buffer_size) if ctx_save_restore_size else None eop_buffer = self.iface.alloc(eop_buffer_size) if eop_buffer_size else None - return AMDQueueDesc.multi(*(self.iface.create_queue(queue_type, ring, gart, rptr=getattr(hsa.amd_queue_t, 
'read_dispatch_id').offset, - wptr=getattr(hsa.amd_queue_t, 'write_dispatch_id').offset, eop_buffer=eop_buffer, cwsr_buffer=cwsr_buffer, - xcc_id=xcc_id, ctx_save_restore_size=ctx_save_restore_size, ctl_stack_size=ctl_stack_size) - for xcc_id in range(self.xccs if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE else 1))) + return (self.iface.create_queue(queue_type, ring, gart, rptr=getattr(hsa.amd_queue_t, 'read_dispatch_id').offset, + wptr=getattr(hsa.amd_queue_t, 'write_dispatch_id').offset, eop_buffer=eop_buffer, cwsr_buffer=cwsr_buffer, + ctx_save_restore_size=ctx_save_restore_size, ctl_stack_size=ctl_stack_size)) def _ensure_has_local_memory(self, required): if self.max_private_segment_size >= required: return diff --git a/tinygrad_repo/tinygrad/runtime/ops_gpu.py b/tinygrad_repo/tinygrad/runtime/ops_cl.py similarity index 92% rename from tinygrad_repo/tinygrad/runtime/ops_gpu.py rename to tinygrad_repo/tinygrad/runtime/ops_cl.py index c9ebb338..8887c97f 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_gpu.py +++ b/tinygrad_repo/tinygrad/runtime/ops_cl.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import cast import ctypes, functools, hashlib from tinygrad.runtime.autogen import opencl as cl -from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address, suppress_finalizing +from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError @@ -48,8 +48,8 @@ class CLProgram: def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None, vals:tuple[int, ...]=(), wait=False) -> float|None: - for i,(b,_) in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)) - for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))) + for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))) + for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))) if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size))) event = cl.cl_event() if wait else None check(cl.clEnqueueNDRangeKernel(self.dev.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), @@ -108,11 +108,9 @@ class CLDevice(Compiled): self.pending_copyin: list[memoryview] = [] self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501 - compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest() - renderer = IntelRenderer() if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts and getenv("INTEL") else OpenCLRenderer() - super().__init__(device, CLAllocator(self), renderer, CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self)) + compilers = [(IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer, + functools.partial(CLCompiler, self, f"compile_cl_{hashlib.md5(self.device_name.encode() + 
self.driver_version.encode()).hexdigest()}"))] + super().__init__(device, CLAllocator(self), compilers, functools.partial(CLProgram, self)) def synchronize(self): check(cl.clFinish(self.queue)) self.pending_copyin.clear() - -GPUDevice = CLDevice # for legacy reasons diff --git a/tinygrad_repo/tinygrad/runtime/ops_cpu.py b/tinygrad_repo/tinygrad/runtime/ops_cpu.py index 5ebf1f8c..012a9e72 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_cpu.py +++ b/tinygrad_repo/tinygrad/runtime/ops_cpu.py @@ -1,58 +1,52 @@ from __future__ import annotations -import platform, subprocess, sys, ctypes, functools, time, mmap, threading, queue -from tinygrad.helpers import capstone_flatdump, getenv, from_mv, to_mv, OSX, mv_address, wait_cond, cpu_profile -from tinygrad.device import Compiler, BufferSpec, DMACPURef +import platform, sys, ctypes, functools, time, mmap, threading, queue +from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing, unwrap +from tinygrad.device import BufferSpec, DMACPURef from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface -from tinygrad.runtime.support.elf import jit_loader from tinygrad.renderer.cstyle import ClangRenderer +from tinygrad.renderer.llvmir import LLVMRenderer +from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler from tinygrad.uop.ops import sint class CPUSignal(HCQSignal): def _sleep(self, time_spent_waiting_ms:int): if self.is_timeline and self.owner is not None: self.owner.tasks.join() -class ClangJITCompiler(Compiler): - def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) - - def compile(self, src:str) -> bytes: - # -fno-math-errno is required for __builtin_sqrt to become an instruction instead of a function call - # x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm, don't use it - target = 'x86_64' if sys.platform == 'win32' else platform.machine() - # on arm march means "runs on this arch and superset" instead of "optimize for this arch". 
x86 march == arm mcpu - arch = '-march=native' if platform.machine() in ('x86_64', 'AMD64') else '-mcpu=native' - args = [arch, f'--target={target}-none-unknown-elf', '-O2', '-fPIC', '-ffreestanding', '-fno-math-errno', '-nostdlib', '-fno-ident'] - arch_args = ['-ffixed-x18'] if target == 'arm64' else [] - obj = subprocess.check_output([getenv("CC", 'clang'), '-c', '-x', 'c', *args, *arch_args, '-', '-o', '-'], input=src.encode('utf-8')) - return jit_loader(obj) - - def disassemble(self, lib:bytes): return capstone_flatdump(lib) - class CPUWorker(threading.Thread): - def __init__(self, dev): + def __init__(self, dev, tasks, thread_id): super().__init__() - self.dev, self.tasks, self.daemon = dev, dev.tasks, True + self.dev, self.tasks, self.thread_id, self.pool, self.daemon = dev, tasks, thread_id, [], True + + def push_task(self, tid, cmd, args): + if len(self.pool) <= tid: + self.pool.append(queue.Queue()) + CPUWorker(self, self.pool[tid], thread_id=tid+1).start() + self.pool[tid].put([cmd, 1, len(args)] + args) def run(self): while True: cmd_iter = iter(self.tasks.get()) for cmd in cmd_iter: - args_cnt = next(cmd_iter) - cmd(*[next(cmd_iter) for _ in range(args_cnt)]) + threads, args_cnt = next(cmd_iter), next(cmd_iter) + args = [next(cmd_iter) for _ in range(args_cnt)] + for th in range(threads - 1): self.push_task(th, cmd, args) + cmd(self.thread_id, *args) + for th in range(threads - 1): self.pool[th].join() self.tasks.task_done() class CPUComputeQueue(HWQueue): - def _exec(self, prg, bufs, *args): - prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:])) - def _signal(self, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value - def _wait(self, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000) - def _timestamp(self, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns() - def cmd(self, cmd, *args): - self.q(cmd, len(args), *args) + def _exec(self, tid, prg, bufs, *args): + prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, args[bufs:]), tid) + def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value + def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000) + def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns() + def cmd(self, cmd, *args, threads=1): + self.q(cmd, threads, len(args), *args) return self def memory_barrier(self): return self def exec(self, prg:CPUProgram, args_state:HCQArgsState, global_size, local_size): - return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals) + return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0]) def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value) def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr) def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value) @@ -62,10 +56,12 @@ class CPUComputeQueue(HWQueue): MAP_JIT = 0x0800 class CPUProgram(HCQProgram): - rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1') + rt_lib = None + try: rt_lib = 
ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or WIN else 'libgcc_s.so.1') + except OSError: pass def __init__(self, dev, name:str, lib:bytes): - if sys.platform == "win32": + if sys.platform == "win32": # mypy doesn't understand when WIN is used here PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000 ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE) @@ -79,28 +75,33 @@ class CPUProgram(HCQProgram): # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np) self.mem = mmap.mmap(-1, len(lib), mmap.MAP_ANON|mmap.MAP_PRIVATE|(MAP_JIT if OSX else 0), mmap.PROT_READ|mmap.PROT_WRITE|mmap.PROT_EXEC) - if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False) + if OSX: unwrap(CPUProgram.rt_lib).pthread_jit_write_protect_np(False) self.mem.write(lib) - if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True) + if OSX: unwrap(CPUProgram.rt_lib).pthread_jit_write_protect_np(True) # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang. # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5 - CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib))) + if CPUProgram.rt_lib is not None: + CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib))) + else: + # msync should be a universal POSIX way to do this + from tinygrad.runtime.autogen import libc + libc.msync(ctypes.c_void_p(mv_address(self.mem)), len(lib), libc.MS_SYNC | libc.MS_INVALIDATE) self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem)) super().__init__(HCQArgsState, dev, name, kernargs_alloc_size=0) + @suppress_finalizing def __del__(self): - if getattr(sys, 'is_finalizing', lambda: True)(): return if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE class CPUAllocator(HCQAllocatorBase): def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: if options.external_ptr: addr, buf = options.external_ptr, None - elif sys.platform == "win32": addr = mv_address(buf:=mmap.mmap(-1, size, access=mmap.ACCESS_WRITE)) + elif WIN: addr = mv_address(buf:=mmap.mmap(-1, size, access=mmap.ACCESS_WRITE)) else: addr = mv_address(buf:=mmap.mmap(-1, size, mmap.MAP_ANON | mmap.MAP_PRIVATE, mmap.PROT_READ | mmap.PROT_WRITE)) return HCQBuffer(va:=addr, sz:=size, meta=buf, view=MMIOInterface(va, sz, fmt='B'), owner=self.dev) def _as_buffer(self, src) -> memoryview: @@ -121,5 +122,6 @@ class CPUAllocator(HCQAllocatorBase): class CPUDevice(HCQCompiled): def __init__(self, device:str=""): self.tasks:queue.Queue = queue.Queue() - CPUWorker(self).start() - super().__init__(device, CPUAllocator(self), ClangRenderer(), ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) + CPUWorker(self, self.tasks, 
thread_id=0).start() + compilers = [(ClangRenderer, ClangJITCompiler), (LLVMRenderer, CPULLVMCompiler)] + super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) diff --git a/tinygrad_repo/tinygrad/runtime/ops_cuda.py b/tinygrad_repo/tinygrad/runtime/ops_cuda.py index 326ae84f..440f68b5 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_cuda.py +++ b/tinygrad_repo/tinygrad/runtime/ops_cuda.py @@ -1,11 +1,11 @@ from __future__ import annotations import ctypes, ctypes.util, functools from tinygrad.helpers import DEBUG, getenv, mv_address, init_c_var, init_c_struct_t, suppress_finalizing -from tinygrad.device import Compiled, BufferSpec, LRUAllocator +from tinygrad.device import Compiled, BufferSpec, LRUAllocator, CompilerPairT from tinygrad.renderer.cstyle import CUDARenderer from tinygrad.renderer.ptx import PTXRenderer from tinygrad.runtime.autogen import cuda -from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler, PTX +from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported @@ -115,8 +115,9 @@ class CUDADevice(Compiled): CUDADevice.devices.append(self) from tinygrad.runtime.graph.cuda import CUDAGraph - super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch), - PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph) + compilers:list[CompilerPairT] = [(functools.partial(CUDARenderer, self.arch), functools.partial(CUDACompiler, self.arch)), + (functools.partial(PTXRenderer, self.arch), functools.partial(PTXCompiler, self.arch))] + super().__init__(device, CUDAAllocator(self), compilers, functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph) def synchronize(self): check(cuda.cuCtxSetCurrent(self.context)) diff --git a/tinygrad_repo/tinygrad/runtime/ops_disk.py b/tinygrad_repo/tinygrad/runtime/ops_disk.py index 1f810ec9..56809daf 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_disk.py +++ b/tinygrad_repo/tinygrad/runtime/ops_disk.py @@ -15,7 +15,7 @@ class DiskDevice(Compiled): self.size: int|None = None self.fd: int|None = None self.count = 0 - super().__init__(device, DiskAllocator(self), None, None, None) + super().__init__(device, DiskAllocator(self), None, None) def _might_open(self, size:int): assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}" if self.size is not None and hasattr(self.device, "mem"): @@ -39,7 +39,9 @@ class DiskDevice(Compiled): def _might_close(self): self.count -= 1 if self.count == 0: - if self.fd is not None: os.close(self.fd) + if self.fd is not None: + os.close(self.fd) + if hasattr(self, "mem"): self.mem.close() self.size = None def _iouring_setup(self): DiskDevice._tried_io_uring_init = True diff --git a/tinygrad_repo/tinygrad/runtime/ops_dsp.py b/tinygrad_repo/tinygrad/runtime/ops_dsp.py index 640d992e..d93f14af 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_dsp.py +++ b/tinygrad_repo/tinygrad/runtime/ops_dsp.py @@ -36,8 +36,10 @@ dsp_string = PatternMatcher([ class DSPRenderer(ClangRenderer): device = "DSP" supports_float4 = True + has_threads = False buffer_suffix = " restrict 
__attribute__((align_value(128)))" kernel_typedef = "__attribute__((noinline)) void" + extra_args = [] pre_matcher = dsp_pm extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher string_rewrite = dsp_string+ClangRenderer.string_rewrite @@ -132,8 +134,8 @@ class DSPDevice(Compiled): def __init__(self, device:str=""): compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b"] if getenv("MOCKDSP"): - super().__init__(device, CPUAllocator(self), MockDSPRenderer(), - ClangCompiler(None, ["-static"] + compiler_args, 'llvm-objdump'), MockDSPProgram) + mock_compilers = [(MockDSPRenderer, functools.partial(ClangCompiler, None, ["-static"] + compiler_args, 'llvm-objdump'))] + super().__init__(device, CPUAllocator(self), mock_compilers, MockDSPProgram) else: self.ion_fd = os.open('/dev/ion', os.O_RDONLY) # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem. @@ -144,8 +146,9 @@ class DSPDevice(Compiled): self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode()) self.link_ld.flush() - super().__init__(device, DSPAllocator(self), DSPRenderer(), - ClangCompiler("compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"], 'llvm-objdump'), functools.partial(DSPProgram, self)) + compilers = [(DSPRenderer, functools.partial(ClangCompiler, "compile_dsp", ["-shared"] + compiler_args + [f"-T{self.link_ld.name}"], + 'llvm-objdump'))] + super().__init__(device, DSPAllocator(self), compilers, functools.partial(DSPProgram, self)) fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes())) self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferSpec(nolru=True)) ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes) diff --git a/tinygrad_repo/tinygrad/runtime/ops_hip.py b/tinygrad_repo/tinygrad/runtime/ops_hip.py index a15e3668..6bd97600 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_hip.py +++ b/tinygrad_repo/tinygrad/runtime/ops_hip.py @@ -14,7 +14,9 @@ class HIPDevice(Compiled): self.device_id = int(device.split(":")[1]) if ":" in device else 0 self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode() self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)] - super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), HIPCompiler(self.arch), functools.partial(HIPProgram, self)) + + compilers = [(functools.partial(HIPRenderer, self.arch), functools.partial(HIPCompiler, self.arch))] + super().__init__(device, HIPAllocator(self), compilers, functools.partial(HIPProgram, self)) def synchronize(self): check(hip.hipSetDevice(self.device_id)) check(hip.hipDeviceSynchronize()) diff --git a/tinygrad_repo/tinygrad/runtime/ops_metal.py b/tinygrad_repo/tinygrad/runtime/ops_metal.py index 9e453eb7..df9d136d 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_metal.py +++ b/tinygrad_repo/tinygrad/runtime/ops_metal.py @@ -76,7 +76,7 @@ class MetalDevice(Compiled): from tinygrad.runtime.graph.metal import MetalGraph # NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph. # This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework. 
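# Throughout this diff, devices now hand Compiled a list of (renderer, compiler) factory
# pairs (CompilerPairT) instead of one pre-selected renderer and compiler, e.g. the Metal
# change just below passes [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)].
# How Compiled chooses between pairs is not shown in these hunks; the loop below is only
# a plausible sketch under that assumption, and select_pair is a made-up name.
def select_pair(pairs):
  errors = []
  for make_renderer, make_compiler in pairs:
    try: return make_renderer(), make_compiler()   # first pair that constructs wins
    except Exception as e: errors.append(e)        # e.g. a toolchain missing on this machine
  raise RuntimeError(f"no usable renderer/compiler pair: {errors}")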
- super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(), + super().__init__(device, MetalAllocator(self), [(MetalRenderer, MetalCompiler), (MetalRenderer, Compiler)], functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(msg('name')(self.sysdevice)).lower() else None) def synchronize(self): diff --git a/tinygrad_repo/tinygrad/runtime/ops_npy.py b/tinygrad_repo/tinygrad/runtime/ops_npy.py index f40bfcde..d92309ba 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_npy.py +++ b/tinygrad_repo/tinygrad/runtime/ops_npy.py @@ -8,4 +8,4 @@ class NpyAllocator(Allocator['NpyDevice']): def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = self._as_buffer(src) class NpyDevice(Compiled): - def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None, None) + def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None) diff --git a/tinygrad_repo/tinygrad/runtime/ops_null.py b/tinygrad_repo/tinygrad/runtime/ops_null.py index 10ae40f4..c8f5a6b5 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_null.py +++ b/tinygrad_repo/tinygrad/runtime/ops_null.py @@ -1,28 +1,33 @@ +import functools from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.engine.jit import MultiGraphRunner from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.uop.ops import Ops +from tinygrad.helpers import cpu_profile class NullRenderer(CStyleLanguage): device = "NULL" has_local = False float4 = "float4" + barrier = "// BARRIER" code_for_op = {**CStyleLanguage.code_for_op, Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"} class NullProgram: - def __init__(self, name:str, lib:bytes): pass + def __init__(self, device:str, name:str, lib:bytes): self.device, self.name = device, name def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): - return 1e-4 + with cpu_profile(self.name, self.device): return 1e-4 class NullAllocator(Allocator['NullDevice']): def _alloc(self, size, options): pass def _copyin(self, dest, src:memoryview): pass def _copyout(self, dest:memoryview, src): pass - def _transfer(self, dest, src, sz:int, src_dev, dest_dev): pass + def _transfer(self, dest, src, sz:int, src_dev, dest_dev): + with cpu_profile(f"{src_dev.device} -> {dest_dev.device}", self.dev.device): pass def _offset(self, buf, offset:int, size:int): pass class NullGraph(MultiGraphRunner): def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3 class NullDevice(Compiled): - def __init__(self, device:str): super().__init__(device, NullAllocator(self), NullRenderer(), Compiler(), NullProgram, NullGraph) + def __init__(self, device:str): super().__init__(device, NullAllocator(self), [(NullRenderer, Compiler)], functools.partial(NullProgram, device), + NullGraph) diff --git a/tinygrad_repo/tinygrad/runtime/ops_nv.py b/tinygrad_repo/tinygrad/runtime/ops_nv.py index 81dbbd37..38dbb450 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_nv.py +++ b/tinygrad_repo/tinygrad/runtime/ops_nv.py @@ -6,11 +6,11 @@ from dataclasses import dataclass from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator from tinygrad.runtime.support.hcq import MMIOInterface, FileIOInterface, MOCKGPU from tinygrad.uop.ops import sint -from tinygrad.device import 
BufferSpec +from tinygrad.device import BufferSpec, CompilerPairT from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, to_mv, hi32, lo32, suppress_finalizing from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.cstyle import NVRenderer -from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, PTX, NVPTXCompiler, NVCompiler +from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler from tinygrad.runtime.autogen import nv_gpu, pci from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager @@ -525,9 +525,9 @@ class NVDevice(HCQCompiled[HCQSignal]): self.arch: str = "sm_120" if self.sm_version==0xa04 else f"sm_{(self.sm_version>>8)&0xff}{(val>>4) if (val:=self.sm_version&0xff) > 0xf else val}" self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf) - compiler_t = (PTXCompiler if PTX else CUDACompiler) if MOCKGPU else (NVPTXCompiler if PTX else NVCompiler) - super().__init__(device, NVAllocator(self), PTXRenderer(self.arch, device="NV") if PTX else NVRenderer(self.arch), compiler_t(self.arch), - functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue) + compilers:list[CompilerPairT] = [(functools.partial(NVRenderer, self.arch),functools.partial(CUDACompiler if MOCKGPU else NVCompiler, self.arch)), + (functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(PTXCompiler if MOCKGPU else NVPTXCompiler, self.arch))] + super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue) self._setup_gpfifos() diff --git a/tinygrad_repo/tinygrad/runtime/ops_python.py b/tinygrad_repo/tinygrad/runtime/ops_python.py index 991a3387..9dd145d2 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_python.py +++ b/tinygrad_repo/tinygrad/runtime/ops_python.py @@ -2,27 +2,39 @@ # a python uops emulator # works to test the tensor cores, and all the uops in general # this is the (living) definition of uops -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, cast import pickle, base64, itertools, time, struct, sys -from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate -from tinygrad.helpers import all_same, getenv, flatten, get_single_element +from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float +from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.codegen.opt import tc -from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp +from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer -def _load(m, i): +def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt + +def to_storage_scalar(x, dtype: DType): + if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF + if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype) + return x + +def from_storage_scalar(x, dtype: DType): + if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0] + if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype) + return x + +def _load(m, i, dtype: DType): if i is None: return 0.0 if i < 0 or i >= len(m): raise 
IndexError(f"load out of bounds, size is {len(m)} and access is {i}") - return m[i] + return from_storage_scalar(m[i], dtype) -def load(inp, j=0): - if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)] - return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]] +def load(inp, j, dtype: DType): + if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)] + return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]] -def _store(m, i, v): +def _store(m, i, v, dtype: DType): if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}") - m[i] = v + m[i] = to_storage_scalar(v, dtype) class PythonProgram: def __init__(self, name:str, lib:bytes): @@ -57,24 +69,25 @@ class PythonProgram: if uop is Ops.STORE: for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]): for (m,o,g),v in zip(inp[0], val): - if g: _store(m, o+j, v) + if g: _store(m, o+j, v, dtp[1].scalar()) i += 1 continue if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: assert isinstance(dtype, PtrDType), dtype - if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported") - if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e" + storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) + if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported") + if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" if uop is Ops.DEFINE_REG: # REGs are per thread - ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)] + ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] else: buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0) - ul[i] = [buf.cast(dtype.fmt)] * warp_size + ul[i] = [buf.cast(storage_fmt)] * warp_size elif uop is Ops.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size elif uop is Ops.SPECIAL: - if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size - elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp] + if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size + elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp] elif uop is Ops.CONST: ul[i] = [arg] * warp_size elif uop is Ops.INDEX: ret:list = [] @@ -98,16 +111,17 @@ class PythonProgram: continue elif uop is Ops.VECTORIZE: ul[i] = inp elif uop is Ops.BITCAST: - assert dtp[0].fmt and dtype.fmt - pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt - ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0]))) + packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]]) + ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed)) + ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]] elif uop is Ops.CAST: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]] elif uop is Ops.LOAD: if dtype.count > 1: - ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)] + ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \ + for j in range(dtype.count)] else: - ul[i] = load(inp) + 
ul[i] = load(inp, 0, dtype) elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)] elif uop is Ops.WMMA: # here are the models for the WMMA instruction on the different hardware @@ -188,7 +202,7 @@ class PythonProgram: else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif uop in GroupOp.ALU: assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" - assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}" + assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}" ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] assert i in ul, (uop, dtype, idp, arg) i += 1 @@ -196,18 +210,23 @@ class PythonProgram: class PythonRenderer(Renderer): device = "PYTHON" + code_for_op = python_alu def __init__(self): - if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal - if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3 - if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna - if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4 - if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80 - if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75 - if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel - if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx + match cast(str, EMULATE.value): + case "METAL": self.device, self.tensor_cores = "METAL", tc.metal + case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3 + case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna + case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4 + case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80 + case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75 + case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel + case "AMX": self.device, self.tensor_cores = "CPU", tc.amx + case "": pass + case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}") def render(self, uops:list[UOp]) -> str: - lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops] + # the value of SPECIAL comes from local/global_size, not from its source + lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops] return base64.b64encode(pickle.dumps(lops)).decode() class PythonCompiler(Compiler): @@ -219,4 +238,4 @@ class PythonAllocator(Allocator['PythonDevice']): def _copyout(self, dest:memoryview, src): dest[:] = src class PythonDevice(Compiled): - def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram) + def __init__(self, device:str): super().__init__(device, PythonAllocator(self), [(PythonRenderer, PythonCompiler)], PythonProgram) diff --git a/tinygrad_repo/tinygrad/runtime/ops_qcom.py b/tinygrad_repo/tinygrad/runtime/ops_qcom.py index bc0857a6..39312cfa 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_qcom.py +++ b/tinygrad_repo/tinygrad/runtime/ops_qcom.py @@ -7,7 +7,7 @@ from tinygrad.device import BufferSpec from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface from
tinygrad.runtime.autogen import kgsl, adreno -from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice +from tinygrad.runtime.ops_cl import CLCompiler, CLDevice from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import @@ -321,7 +321,7 @@ class QCOMDevice(HCQCompiled): QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr) flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \ - | kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN) + | kgsl.KGSL_CONTEXT_PRIORITY(getenv("QCOM_PRIORITY", 8)) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN) self.ctx = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=flags).drawctxt_id self.cmd_buf = self._gpu_alloc(16 << 20) @@ -341,8 +341,8 @@ class QCOMDevice(HCQCompiled): QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF) if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}") - super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self), - QCOMSignal, QCOMComputeQueue, None) + compilers = [(QCOMRenderer, functools.partial(QCOMCompiler, device))] + super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, QCOMComputeQueue, None) def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer: flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP diff --git a/tinygrad_repo/tinygrad/runtime/ops_remote.py b/tinygrad_repo/tinygrad/runtime/ops_remote.py index 27ca50d9..147063f3 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_remote.py +++ b/tinygrad_repo/tinygrad/runtime/ops_remote.py @@ -100,7 +100,7 @@ class GraphComputeItem: datahash: str bufs: tuple[int, ...] vars: tuple[Variable, ...] - fixedvars: dict[Variable, int] + fixedvars: dict[str, int] ins: tuple[int, ...] outs: tuple[int, ...] global_size: tuple[sint, ...]|None @@ -111,7 +111,7 @@ class GraphAlloc(RemoteRequest): graph_num: int jit_cache: tuple[GraphComputeItem|Transfer, ...] bufs: tuple[tuple[SessionKey, int], ...] - var_vals: dict[Variable, int] + var_vals: dict[str, int] @dataclass(frozen=True) class GraphFree(RemoteRequest): @@ -121,7 +121,7 @@ class GraphFree(RemoteRequest): class GraphExec(RemoteRequest): graph_num: int bufs: tuple[tuple[SessionKey, int], ...] - var_vals: dict[Variable, int] + var_vals: dict[str, int] wait: bool # for safe deserialization @@ -471,10 +471,11 @@ class RemoteDevice(Compiled): if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}") renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure? 
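# The allow-list above (module must start with "tinygrad.", class name must end in
# "Renderer", plus the issubclass check in the line that follows) is the standard
# defensive-deserialization pattern: never import or instantiate an arbitrary
# remote-supplied dotted path. A compact sketch with stdlib importlib;
# load_renderer_class is an illustrative name, fromimport is tinygrad's equivalent helper.
import importlib

def load_renderer_class(mod: str, name: str):
  if not mod.startswith("tinygrad.") or not name.endswith("Renderer"):
    raise RuntimeError(f"bad renderer {(mod, name)}")
  return getattr(importlib.import_module(mod), name)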
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}") - renderer_instance = renderer_class(*renderer[2]) - renderer_instance.device = device + graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None - super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn)) + compilers = [(functools.partial(renderer_class, *renderer[2]), Compiler)] + super().__init__(device, RemoteAllocator(self), compilers, functools.partial(RemoteProgram, self), graph, id(self.conn)) + self.renderer.device = device def finalize(self): with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True) diff --git a/tinygrad_repo/tinygrad/runtime/ops_webgpu.py b/tinygrad_repo/tinygrad/runtime/ops_webgpu.py index eef0a5cc..cc070fd2 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad_repo/tinygrad/runtime/ops_webgpu.py @@ -217,7 +217,7 @@ class WebGpuDevice(Compiled): device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback, webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc) - super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(), + super().__init__(device, WebGpuAllocator(device_res), [(WGSLRenderer, Compiler)], functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported))) def synchronize(self): diff --git a/tinygrad_repo/tinygrad/runtime/support/am/amdev.py b/tinygrad_repo/tinygrad/runtime/support/am/amdev.py index fa9ba793..7ec9b4ae 100644 --- a/tinygrad_repo/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad_repo/tinygrad/runtime/support/am/amdev.py @@ -84,7 +84,7 @@ class AMFirmware: self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)] def load_fw(self, fname:str, *headers, versioned_header:str|None=None): - fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/45f59212aebd226c7630aff4b58598967c0c8c91/amdgpu/{fname}", subdir="fw") + fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/a9f26799247aa60fbaa3b64267a18f20b72b5235/amdgpu/{fname}", subdir="fw") blob = memoryview(bytearray(fpath.read_bytes())) if AM_DEBUG >= 1: print(f"am {self.adev.devfmt}: loading firmware {fname}: {hashlib.sha256(blob).hexdigest()}") if versioned_header: diff --git a/tinygrad_repo/tinygrad/runtime/support/am/ip.py b/tinygrad_repo/tinygrad/runtime/support/am/ip.py index e702e3e2..e6ff7a24 100644 --- a/tinygrad_repo/tinygrad/runtime/support/am/ip.py +++ b/tinygrad_repo/tinygrad/runtime/support/am/ip.py @@ -224,7 +224,8 @@ class AM_GFX(AM_IP): self._grbm_select() self.adev.regGCVM_CONTEXT0_CNTL.write(0) - def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int): + def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int, + aql:bool): mqd = self.adev.mm.valloc(0x1000, uncached=True, contiguous=True) struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}_compute_mqd") @@ -235,9 +236,10 @@ class AM_GFX(AM_IP): cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr), cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), 
cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr), cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1), - cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2), + cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2, + **({'queue_full_en':1, 'slot_based_wptr':2, 'no_update_rptr':1} if aql else {})), cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000, - cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0, + cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0, cp_hqd_aql_control=int(aql), cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8), cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2)) diff --git a/tinygrad_repo/tinygrad/runtime/support/compiler_amd.py b/tinygrad_repo/tinygrad/runtime/support/compiler_amd.py index fd00ea03..88608c71 100644 --- a/tinygrad_repo/tinygrad/runtime/support/compiler_amd.py +++ b/tinygrad_repo/tinygrad/runtime/support/compiler_amd.py @@ -9,7 +9,7 @@ try: assert comgr.AMD_COMGR_LANGUAGE_HIP == 3 except AttributeError: pass # ignore if ROCm isn't installed from tinygrad.device import Compiler, CompileError -from tinygrad.runtime.ops_llvm import LLVMCompiler +from tinygrad.runtime.support.compiler_cpu import LLVMCompiler from tinygrad.helpers import OSX, to_char_p_p def amdgpu_disassemble(lib:bytes): diff --git a/tinygrad_repo/tinygrad/runtime/ops_llvm.py b/tinygrad_repo/tinygrad/runtime/support/compiler_cpu.py similarity index 73% rename from tinygrad_repo/tinygrad/runtime/ops_llvm.py rename to tinygrad_repo/tinygrad/runtime/support/compiler_cpu.py index e5fabb84..f9ec8d10 100644 --- a/tinygrad_repo/tinygrad/runtime/ops_llvm.py +++ b/tinygrad_repo/tinygrad/runtime/support/compiler_cpu.py @@ -1,11 +1,25 @@ -import ctypes, platform, functools, queue +import ctypes, platform, sys, subprocess from tinygrad.device import Compiler -from tinygrad.runtime.support.hcq import HCQCompiled, HCQSignal -from tinygrad.runtime.ops_cpu import CPUAllocator, CPUProgram, CPUComputeQueue, CPUWorker from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG -from tinygrad.renderer.llvmir import LLVMRenderer -import tinygrad.runtime.autogen.llvm as llvm from tinygrad.runtime.support.elf import jit_loader +try: import tinygrad.runtime.autogen.llvm as llvm +except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment] + +class ClangJITCompiler(Compiler): + def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey) + + def compile(self, src:str) -> bytes: + # -fno-math-errno is required for __builtin_sqrt to become an instruction instead of a function call + # x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm, don't use it + target = 'x86_64' if sys.platform == 'win32' else platform.machine() + # on arm march means "runs on this arch and superset" instead of "optimize for this arch". 
x86 march == arm mcpu + arch = {'x86_64': '-march=native', 'AMD64': '-march=native', 'riscv64': '-march=rv64g'}.get(platform.machine(), "-mcpu=native") + args = [arch, f'--target={target}-none-unknown-elf', '-O2', '-fPIC', '-ffreestanding', '-fno-math-errno', '-nostdlib', '-fno-ident'] + arch_args = ['-ffixed-x18'] if target == 'arm64' else [] + obj = subprocess.check_output([getenv("CC", 'clang'), '-c', '-x', 'c', *args, *arch_args, '-', '-o', '-'], input=src.encode('utf-8')) + return jit_loader(obj) + + def disassemble(self, lib:bytes): return capstone_flatdump(lib) def cerr(): return ctypes.pointer(ctypes.pointer(ctypes.c_char())) @@ -15,7 +29,7 @@ def expect(x, err, ret=None): class LLVMCompiler(Compiler): jit = True - target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86'}[platform.machine()] + target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86', 'riscv64': 'riscv64'}[platform.machine()] def __init__(self, processor:str, feats:str): for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{self.target_arch}{component}')() @@ -65,14 +79,8 @@ class LLVMCompiler(Compiler): def disassemble(self, lib:bytes): capstone_flatdump(lib) -class HostLLVMCompiler(LLVMCompiler): +class CPULLVMCompiler(LLVMCompiler): def __init__(self): # +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()) super().__init__(cpu.decode(), feats.decode()) - -class LLVMDevice(HCQCompiled): - def __init__(self, device:str=""): - self.tasks:queue.Queue = queue.Queue() - CPUWorker(self).start() - super().__init__(device, CPUAllocator(self), LLVMRenderer(), HostLLVMCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue) diff --git a/tinygrad_repo/tinygrad/runtime/support/compiler_cuda.py b/tinygrad_repo/tinygrad/runtime/support/compiler_cuda.py index 1b27d026..e10249ed 100644 --- a/tinygrad_repo/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad_repo/tinygrad/runtime/support/compiler_cuda.py @@ -4,7 +4,7 @@ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv import tinygrad.runtime.autogen.nvrtc as nvrtc from tinygrad.device import Compiler, CompileError -PTX, CUDA_PATH = getenv("PTX"), getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist +CUDA_PATH = getenv("CUDA_PATH", "") def _get_bytes(arg, get_str, get_sz, check) -> bytes: sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))) diff --git a/tinygrad_repo/tinygrad/runtime/support/hcq.py b/tinygrad_repo/tinygrad/runtime/support/hcq.py index 5759081d..82823be6 100644 --- a/tinygrad_repo/tinygrad/runtime/support/hcq.py +++ b/tinygrad_repo/tinygrad/runtime/support/hcq.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import cast, Callable, Type, TypeVar, Generic, Any +from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence import contextlib, decimal, statistics, time, ctypes, array, os, struct, traceback, collections try: import fcntl # windows misses that except ImportError: fcntl = None #type:ignore[assignment] from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent -from tinygrad.renderer import Renderer -from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, 
ProfileDeviceEvent, ProfileProgramEvent -from tinygrad.uop.ops import sym_infer, sint, Variable, UOp +from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent, CompilerPairT +from tinygrad.uop.ops import sym_infer, sint, UOp from tinygrad.runtime.autogen import libc class MMIOInterface: @@ -28,10 +27,7 @@ class FileIOInterface: def __del__(self): if hasattr(self, 'fd'): os.close(self.fd) def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg) - def mmap(self, start, sz, prot, flags, offset): - x = libc.mmap(start, sz, prot, flags, self.fd, offset) - if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}") - return x + def mmap(self, start, sz, prot, flags, offset): return FileIOInterface._mmap(start, sz, prot, flags, self.fd, offset) def read(self, size=None, binary=False, offset=None): if offset is not None: self.seek(offset) with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size) @@ -41,11 +37,13 @@ class FileIOInterface: def listdir(self): return os.listdir(self.path) def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET) @staticmethod - def anon_mmap(start, sz, prot, flags, offset): - x = libc.mmap(start, sz, prot, flags, -1, offset) + def _mmap(start, sz, prot, flags, fd, offset): + x = libc.mmap(start, sz, prot, flags, fd, offset) if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}") return x @staticmethod + def anon_mmap(start, sz, prot, flags, offset): return FileIOInterface._mmap(start, sz, prot, flags, -1, offset) + @staticmethod def munmap(buf, sz): return libc.munmap(buf, sz) @staticmethod def exists(path): return os.path.exists(path) @@ -192,7 +190,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]): if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val) else: self.mv_sints.append((mv, i, self._new_sym(val), mask)) - def _apply_var_vals(self, var_vals:dict[Variable, int]): + def _apply_var_vals(self, var_vals:dict[str, int]): resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms] for off, sym_idx in self.q_sints: @@ -205,7 +203,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]): self._prev_resolved_syms = cast(list[int|None], resolved_syms) - def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None): + def submit(self, dev:HCQDeviceType, var_vals:dict[str, int]|None=None): """ Submits the command queue to a specific device for execution. 
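With var_vals now keyed by variable name (dict[str, int]) instead of Variable UOps, _apply_var_vals resolves each symbolic word and skips rewrites whose value is unchanged since the last submit. A toy sketch of that pattern, where ToyQueue is hypothetical and plain name lookup stands in for sym_infer:

class ToyQueue:
  def __init__(self, syms:list[str]):
    self.syms = syms                                # symbolic expressions, bare names here
    self.words = [0] * len(syms)                    # fake command-buffer words
    self._prev:list[int|None] = [None] * len(syms)  # previously resolved values
  def apply_var_vals(self, var_vals:dict[str, int]):
    for i, sym in enumerate(self.syms):
      if (val:=var_vals[sym]) == self._prev[i]: continue  # unchanged, skip the write
      self.words[i] = self._prev[i] = val

q = ToyQueue(["global_size_x", "n"])
q.apply_var_vals({"global_size_x": 64, "n": 3})
assert q.words == [64, 3]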
@@ -360,12 +358,12 @@ class HCQCompiled(Compiled, Generic[SignalType]): signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group cpu_devices: list[HCQCompiled] = [] - def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType], + def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:Sequence[CompilerPairT], runtime, signal_t:Type[SignalType], comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000): self.device_id:int = int(device.split(":")[1]) if ":" in device else 0 from tinygrad.runtime.graph.hcq import HCQGraph - super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph) + super().__init__(device, allocator, compilers, runtime, HCQGraph) # TODO: peer logic is determined based on device name. self.peer_group = device.split(":")[0] @@ -383,15 +381,20 @@ class HCQCompiled(Compiled, Generic[SignalType]): self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True)) self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True) + self.error_state:Exception|None = None # Exception if error is unrecoverable and sync will always fail + if self._is_cpu(): HCQCompiled.cpu_devices.append(self) def synchronize(self): + if self.error_state is not None: raise self.error_state + # If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster. if not self._is_cpu(): for dev in HCQCompiled.cpu_devices: dev.synchronize() try: self.timeline_signal.wait(self.timeline_value - 1) except RuntimeError as e: + self.error_state = e if hasattr(self, 'on_device_hang'): self.on_device_hang() else: raise e @@ -437,16 +440,21 @@ class HCQCompiled(Compiled, Generic[SignalType]): except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False return buf, realloced + def _make_no_iface_error(self, errs:str, err_short:str) -> RuntimeError: + # Keep it in a separate function to avoid creating a traceback <-> locals ref cycle + e = RuntimeError(f"No interface for {type(self).__name__[:-6]}:{self.device_id} is available") + if hasattr(e, "add_note"): e.add_note(errs + err_short) + return e + def _select_iface(self, *ifaces:Type): errs, err_short = "", "" if val:=getenv(f'{type(self).__name__[:-6].upper()}_IFACE', ""): ifaces = tuple(x for x in ifaces if x.__name__.startswith(val.upper())) for iface_t in ifaces: try: return iface_t(self, self.device_id) - except Exception as e: errs, err_short = errs + f"\n{iface_t.__name__}: {traceback.format_exc()}", err_short + f"\n{iface_t.__name__}: {e}" - raise RuntimeError(f"{errs}\nNo interface for {type(self).__name__[:-6]}:{self.device_id} is available:{err_short}\n" \ - f"\nForce an interface with {type(self).__name__[:-6].upper()}_IFACE={('|'.join(x.__name__[:-5] for x in ifaces))}.") + except Exception as e: errs, err_short = errs + f"\n{iface_t.__name__}: {traceback.format_exc()}", err_short + f"\n{iface_t.__name__}: {e}." + raise self._make_no_iface_error(errs, err_short) - def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] in ("CPU", "LLVM") + def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] == "CPU" def finalize(self): try: self.synchronize() # Try to finalize device in any case. 
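The synchronize change above makes unrecoverable failures sticky: once the timeline wait raises, error_state is set and every later synchronize re-raises the same exception instead of waiting on a dead device again. Reduced to a standalone sketch with a hypothetical ToyDevice:

class ToyDevice:
  def __init__(self): self.error_state:Exception|None = None
  def _wait_timeline(self): raise RuntimeError("device hang")  # simulate a timeout
  def synchronize(self):
    if self.error_state is not None: raise self.error_state    # sync will always fail
    try: self._wait_timeline()
    except RuntimeError as e:
      self.error_state = e
      raise

dev = ToyDevice()
for _ in range(2):  # the second call re-raises immediately, without touching the device
  try: dev.synchronize()
  except RuntimeError: pass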
diff --git a/tinygrad_repo/tinygrad/runtime/support/memory.py b/tinygrad_repo/tinygrad/runtime/support/memory.py index 96db4525..0c74ea41 100644 --- a/tinygrad_repo/tinygrad/runtime/support/memory.py +++ b/tinygrad_repo/tinygrad/runtime/support/memory.py @@ -77,11 +77,10 @@ class TLSFAllocator: if self.lv1_entries[l1] == 0: continue for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)): if len(self.storage[l1][l2]) > 0: - nsize = self.blocks[self.storage[l1][l2][0]][0] - assert nsize >= size, "block must be larger" - # Block start address. start = self.storage[l1][l2][0] + nsize = self.blocks[start][0] + assert nsize >= size, "block must be larger" # If request contains alignment, split the block into two parts. if (new_start:=round_up(start, align)) != start: diff --git a/tinygrad_repo/tinygrad/runtime/support/nv/nvdev.py b/tinygrad_repo/tinygrad/runtime/support/nv/nvdev.py index 455e4fe1..6831b5e8 100644 --- a/tinygrad_repo/tinygrad/runtime/support/nv/nvdev.py +++ b/tinygrad_repo/tinygrad/runtime/support/nv/nvdev.py @@ -118,7 +118,7 @@ class NVDev(PCIDevImplBase): self.include("src/common/inc/swref/published/turing/tu102/dev_fb.h") if self.reg("NV_PFB_PRI_MMU_WPR2_ADDR_HI").read() != 0: - if DEBUG >= 2: print(f"nv {self.devfmt}: WPR2 is up. Issuing a full reset.") + if DEBUG >= 2: print(f"nv {self.devfmt}: WPR2 is up. Issuing a full reset.", flush=True) System.pci_reset(self.devfmt) time.sleep(0.5) diff --git a/tinygrad_repo/tinygrad/runtime/support/system.py b/tinygrad_repo/tinygrad/runtime/support/system.py index 62388661..66b2f786 100644 --- a/tinygrad_repo/tinygrad/runtime/support/system.py +++ b/tinygrad_repo/tinygrad/runtime/support/system.py @@ -165,7 +165,7 @@ class PCIIfaceBase: def map(self, b:HCQBuffer): if b.owner is not None and b.owner._is_cpu(): System.lock_memory(cast(int, b.va_addr), b.size) - paddrs, snooped, uncached = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], True, False + paddrs, snooped, uncached = [(x, 0x1000) for x in System.system_paddrs(cast(int, b.va_addr), round_up(b.size, 0x1000))], True, True elif (ifa:=getattr(b.owner, "iface", None)) is not None and isinstance(ifa, PCIIfaceBase): paddrs = [(paddr if b.meta.mapping.system else (paddr + ifa.p2p_base_addr), size) for paddr,size in b.meta.mapping.paddrs] snooped, uncached = b.meta.mapping.snooped, b.meta.mapping.uncached diff --git a/tinygrad_repo/tinygrad/schedule/kernelize.py b/tinygrad_repo/tinygrad/schedule/kernelize.py index d02ec8fc..ec2f7c04 100644 --- a/tinygrad_repo/tinygrad/schedule/kernelize.py +++ b/tinygrad_repo/tinygrad/schedule/kernelize.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve -from tinygrad.uop.ops import track_rewrites, _substitute +from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo from tinygrad.uop.spec import type_verify, tensor_uop_spec from tinygrad.uop.symbolic import symbolic_simple from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP @@ -8,6 +8,7 @@ from tinygrad.dtype import ImageDType from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop +from tinygrad.codegen.opt import Opt # creation can recurse a lot import sys @@ 
-119,7 +120,8 @@ def create_kernel(x:UOp, b:UOp|None=None): if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype) kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ())) buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset)) - return buffer.assign(kernel).reshape(x.shape) + # we have to shrink the buffer back to the symbolic shape + return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape)) DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND} def append_to_kernel(x:UOp): @@ -147,6 +149,16 @@ create_kernels = PatternMatcher([ lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)), ]) +def add_stores(ctx, sink: UOp): + stores = [] + for i,x in enumerate(sink.src): + gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i) + # if this is an assign then we already have a buffer with a view that should be the target of the store + if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s)) + # otherwise we have to create the shapetracker and shrink it to the correct symbolic shape + else: stores.append( + UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s)) + return UOp.sink(*stores, arg=sink.arg) # **** fix kernel AST def unbind_view(x:UOp): @@ -154,6 +166,10 @@ def unbind_view(x:UOp): return None replace_buffers = PatternMatcher([ + # sink on contig creates a KernelInfo + (UPat(Ops.CONTIGUOUS, name="c").sink(name="s"), + lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \ + if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None), # replace ASSIGN with the target BUFFER (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]), # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?) 
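create_kernel and add_stores above share one trick for symbolic shapes: materialize each UOp dimension at its vmax, then shrink the view back to the bound extent. A worked standalone example, with a hypothetical SymDim standing in for a UOp-valued dimension:

class SymDim:
  def __init__(self, vmax:int): self.vmax = vmax  # upper bound of the symbolic dim

def concrete_shape(shape) -> tuple:
  # what the buffer is reshaped to: every symbolic dim at its maximum extent
  return tuple(d.vmax if isinstance(d, SymDim) else d for d in shape)

def shrink_arg(shape, bound:dict) -> tuple:
  # ((0, d), ...) shrink argument taking each dim back to its bound value
  return tuple((0, bound[d] if isinstance(d, SymDim) else d) for d in shape)

n = SymDim(vmax=16)                       # dimension bounded above by 16
assert concrete_shape((n, 4)) == (16, 4)  # allocate/store at the maximum extent
assert shrink_arg((n, 4), {n: 9}) == ((0, 9), (0, 4))  # view shrunk to the current binding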
@@ -163,9 +179,7 @@ replace_buffers = PatternMatcher([ # no SINK for meta ops (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x), # STORE (except for meta ops) - (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink: - UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)], - arg=sink.arg)), + (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores), # remove CONTIGUOUS/DEVICE from kernel AST (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), diff --git a/tinygrad_repo/tinygrad/schedule/rangeify.py b/tinygrad_repo/tinygrad/schedule/rangeify.py index aa3f10a4..995c5477 100644 --- a/tinygrad_repo/tinygrad/schedule/rangeify.py +++ b/tinygrad_repo/tinygrad/schedule/rangeify.py @@ -1,13 +1,16 @@ -from typing import Any +from typing import Any, cast +import functools, operator from dataclasses import dataclass, field -from tinygrad.dtype import dtypes, PtrDType, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, RANGEIFY +from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, graph_rewrite_map +from tinygrad.uop.symbolic import sym, symbolic_simple +from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.kernelize import Kernel -from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint, AxisType +from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType +# ***************** # 0. do some cleanup rewrites, mostly copied from the old stuff double_reshape = PatternMatcher([ @@ -16,31 +19,40 @@ double_reshape = PatternMatcher([ ]) earliest_rewrites = double_reshape+PatternMatcher([ - # UOp with size 0 is zero - (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None), + # non shape changing RESHAPE is NOOP + #(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None), # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE + #(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0].f(Ops.NOOP, tag=x.tag)), + + # just removing it works... (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]), + + # preserve tags? 
# reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), - # non shape changing RESHAPE is NOOP - (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None), - # RESHAPE after COPY - (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)), - # TODO: this should be BUFFER_VIEW - (UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)), - # const hacks - (UPat(Ops.CONST, name="x"), lambda x: - x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \ - len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None), + + # COPY and source size need to match + # TODO: expand after copy creates issues with tagging + (UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"), + lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None), + # assign only to buffer - (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x"))), - lambda x,target: x if target.base.op is not Ops.BUFFER else None), + (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"), + lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None), + + # handle disk + # TODO: this doesn't need to use st.views + (UPat.var("x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), + lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) \ + and x.device.startswith("DISK") else None), + # contiguous/buffer/copy/assign is already contiguous - (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]), + #(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]), ]) -# 1. add contiguous where we have to +# ***************** +# 1. 
add realize where we have to ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, @@ -50,7 +62,7 @@ def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None: for s in rb.src: - if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None + if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None @@ -58,19 +70,25 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), - # always realize ASSIGN/COPY/BUFFER_VIEW - (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), + # always realize ASSIGN/COPY/BUFFER_VIEW/CONTIGUOUS + (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize), # realize parents of COPY, MSELECT, MSTACK (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents), # realize input to assign (might be optimized out) (UPat(Ops.ASSIGN, name="a"), realize_assign), ]) -add_contiguous = PatternMatcher([ - (UPat(GroupOp.All-{Ops.CONTIGUOUS}, name="x"), lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None), -]) -remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) +class WrappedContig: + def __init__(self, x): self.x = x + def __repr__(self): return f"C({self.x})" +add_contiguous = PatternMatcher([ + (UPat(GroupOp.All, name="x"), + lambda ctx,x: x.replace(tag=WrappedContig(x.tag)).realize() if x in ctx and not isinstance(x.tag, WrappedContig) else None), +]) +remove_contig_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=x.tag.x) if isinstance(x.tag, WrappedContig) else None)]) + +# ***************** # 2. mark all children @dataclass @@ -94,10 +112,11 @@ def mark_children(ctx:ChildrenContext, x:UOp): pm_children = PatternMatcher([ (UPat(Ops.SINK, name="x"), extract_children), - (UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN}, name="x"), mark_children), + (UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN, Ops.SINK}, name="x"), mark_children), ]) -# 3. rangeify +# ***************** +# 3a. 
rangeify (movement) @dataclass class RangeifyContext: @@ -109,7 +128,7 @@ class RangeifyContext: # create ranges range_idx: int = 0 def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP): - ret = UOp.range(dtypes.int, s, self.range_idx, axistype) + ret = UOp.range(s, self.range_idx, axistype) self.range_idx += 1 return ret @@ -119,7 +138,7 @@ def map_reshape(idx:UOp, r:UOp): for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]: to_sum.append(acc*src) acc *= s - mish = sum(to_sum, start=UOp.const(dtypes.int, 0)) + mish = sum(to_sum, start=UOp.const(dtypes.index, 0)) ret:list[UOp] = [] for s in r.src[0].shape[::-1]: ret.append(mish % s) # NOTE: simplify will turn this to CONST @@ -136,9 +155,8 @@ def map_pad(idx:UOp, r:UOp): if resolve(e > 0): where = where & (ret[i] < (sh-e)) if resolve(s > 0): where = where & (ret[i] >= s) bigwhere = bigwhere & where - # this is safe but dumb - # TODO (S-Lykles): switch to mixed index/valid - ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1) + with Context(TRACK_MATCH_STATS=0): + ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym) # PAD is with 0 return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0)) @@ -148,12 +166,12 @@ def map_expand(r:UOp, idx:UOp): non_ending_ranges = [] for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape): axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE] - if resolve(x!=y, False): - ending_ranges.extend(axis_to_range) - new_rngs.append(a.const_like(0)) - else: + if resolve(x==y, False): non_ending_ranges.extend(axis_to_range) new_rngs.append(a) + else: + ending_ranges.extend(axis_to_range) + new_rngs.append(a.const_like(0)) ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges] if idx.arg is not None: ending_ranges.append(idx.arg) return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None) @@ -174,7 +192,21 @@ pm_mops = PatternMatcher([ (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad), ]) -def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp): +# ***************** +# 3b. rangeify (ops) + +# bufferization can happen in three ways +# 1. there's an explicit REALIZE in the graph +# 2. the ranges from the children don't match and we have to create a buffer (only on children) +# 3. might_end_axis triggers because we should be closing a loop to save compute + +@dataclass(frozen=True) +class BufferizeOpts: + # on AddrSpace.LOCAL, device is the id + device: str|tuple[str, ...]|int|None + addrspace: AddrSpace = AddrSpace.GLOBAL + +def map_partial_realize(ctx:RangeifyContext, x:UOp, idx:UOp): if x.arg is None: return None # map_contiguous can handle this # NOTE: all partial contiguous can safely be replaced by full contiguous. 
we should be able to match old functionality like this if not (RANGEIFY > 1): return idx.replace(src=(x.replace(arg=None),)+idx.src[1:]) @@ -186,17 +218,17 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp): ranges.append(idx.src[1+i]) continue passthrough_idx.append(idx.src[1+i]) - ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) + ranges.append(ctx.new_range(s)) new_ranges.append(ranges[-1]) - ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device) + # TODO: this should be able to be global or local + ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], + arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL)) return ret.index(*passthrough_idx) -def map_contiguous(ctx:RangeifyContext, x:UOp): +def map_realize(ctx:RangeifyContext, x:UOp): if x.arg is not None: return None - ranges = [] - for s in x.shape[len(x.src)-1:]: - ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) - return x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape) + ranges = [ctx.new_range(s) for s in x.shape] + return x.src[0].index(*ranges).bufferize(*x.src[1:], *ranges, arg=BufferizeOpts(device=x.device), tag=x.src[0].tag) def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp): rngs = list(idx.src[1:]) @@ -205,7 +237,7 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp): if i in red.arg[1]: rngs[i] = ctx.new_range(s, axistype=AxisType.REDUCE) new_ranges.append(rngs[i]) - return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0]) + return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0], tag=red.tag) def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp): if c not in ctx.seen_children: ctx.seen_children[c] = {} @@ -219,26 +251,38 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp): ctx.progress = 0 if c not in ctx.seen_child: - all_rngs = zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()]) + all_rngs = list(zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()])) out_rngs = [] end_ranges = [] idx_ranges = [] - for i,r in enumerate(all_rngs): - if all_same(r): - out_rngs.append(r[0]) + # NOTE: locals aren't working, so we only fully bufferize here (unless RANGEIFY > 1) + all_all_same = all(all_same(r) for r in all_rngs) + for i,valid_rngs in enumerate(all_rngs): + rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs]) + # we compare the ranges without their valids + if all_same(rngs) and (all_all_same or RANGEIFY > 1): + # the new valid is the OR of all the children valids + minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False)) + out_rngs.append(minimum_valid.where(rngs[0], UOp.invalid()).simplify()) else: out_rngs.append(ctx.new_range(c.shape[i])) end_ranges.append(out_rngs[-1]) idx_ranges.append(i) - ctx.seen_child[c] = (idx_ranges, end_ranges) + ctx.seen_child[c] = (out_rngs, idx_ranges, end_ranges) else: - out_rngs = list(idx.src[1:]) - idx_ranges, end_ranges = ctx.seen_child[c] + out_rngs, idx_ranges, end_ranges = ctx.seen_child[c] for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr # index based on the shared ranges ret = c.index(*out_rngs) # if all ranges aren't the same between children, we have to bufferize - if len(idx_ranges) > 0: ret = ret.bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i 
in idx_ranges]) + if len(idx_ranges) > 0: + if len(idx_ranges) == len(out_rngs): + # this is a global bufferize + ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=x.device)) + else: + assert RANGEIFY > 1, "this isn't supported with RANGEIFY=1" + ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL)) + ret = ret.index(*[idx.src[1+i] for i in idx_ranges]) return ret def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp): @@ -248,20 +292,22 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp): def might_end_axis(idx:UOp): if idx.arg is None: return None # TODO: write a proper cost function here - if all(x.op not in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.BUFFERIZE} for x in idx.toposort()): return None + if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None to_end_axis = [] for i,a in enumerate(idx.src[1:]): if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE): to_end_axis.append(i) - if to_end_axis: return idx.replace(src=(idx.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None) + if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None) return idx.replace(arg=None) +def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}") + pm_rangeify = pm_mops+PatternMatcher([ # sink contigs to kick it off - (UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x", allow_any_len=True), map_contiguous), + (UPat(Ops.REALIZE, src=(UPat(),), name="x", allow_any_len=True), map_realize), # if there's an INDEX it can support partial contig - (UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_contiguous), + (UPat(Ops.INDEX, src=(UPat(Ops.REALIZE, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_realize), # if there are new ended children, tag the SINK (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child), @@ -270,29 +316,39 @@ pm_rangeify = pm_mops+PatternMatcher([ # if we come across this, remove it. it was a CHILD unused in an INDEX (UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x), - # CONST (or DEFINE_VAR) can't have axes. remove srcs when we INDEX it - (UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(src=())), + # CONST (or DEFINE_VAR) can't have axes. remove INDEX when we get here + (UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c), # handle arg on any op with weight. old endrange stuff (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis), + # handle size 0 + (UPat(Ops.INDEX, name="x"), lambda x: x.replace(src=(x.const_like(0),)+x.src[1:]) if x.st is not None and x.size == 0 else None), + + # handle assign + (UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"), + lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))), + # move MAP through elementwise ALU / reduce. 
these are the items with cost - (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE, Ops.BIND})),), allow_any_len=True, name="x"), + (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union( + {Ops.STORE, Ops.COPY, Ops.BUFFER_VIEW, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"), lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))), (UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce), + + # assert if there's any index we didn't process + (UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index), ]) +# ***************** # 3.5 cleanups # you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left -# TODO: figure out how to reenable this def cleanup_dead_axes(b:UOp): - parents = b.src[0].toposort() new_rng = [] hit = False reshape: list[sint] = [] for s,rng in zip(b.shape, b.src[1:]): - if rng not in parents and rng.op is Ops.RANGE: + if rng not in b.src[0].sparents and rng.op is Ops.RANGE: reshape.append(1) hit = True else: @@ -303,25 +359,51 @@ def cleanup_dead_axes(b:UOp): # if a buffer is being stored just for permutes or something, remove it # we want to reexpress the indexes of idx2 in terms of the implied b1 -def remove_bufferize(b2:UOp, idx2:UOp): - # HACK - if len(b2.src) != len(idx2.src): return None - assert len(b2.src) == len(idx2.src) - assert all(x.op is Ops.RANGE for x in b2.src[1:]) - return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:]))) +def remove_bufferize(src:UOp, buf:UOp, idx:UOp): + # see if we can't do it, should this ever hit? + assert len(buf.src) == len(idx.src), "index on wrong bufferize" + assert all(x.op is Ops.RANGE for x in buf.src[1:]) + + # if it's user contiguous, we never remove it + if src.op is Ops.CONTIGUOUS: return None + + # here is where we compute the cost + # for now just no REDUCE, COPY, or ASSIGN + ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX}) + # we don't want to bufferize threefry, also causes problems because not all platforms support long + if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.BUFFER_VIEW, Ops.ASSIGN} for x in ran) and src.op is not Ops.THREEFRY: return None + + # simple, matching old behavior + #if src.op is not Ops.INDEX: return None + + # this is the ranges replaced + return src.substitute(dict(zip(buf.src[1:], idx.src[1:]))) + +def pre_bufferize(b:UOp, x:UOp, copy:UOp): + nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) + return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1])) pm_cleanups = double_reshape+pm_mops+PatternMatcher([ #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes), # remove noop buffers. 
if we look at the next index we can remove even more of these # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more - #(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), - # lambda idx,b2: idx.src[0] if idx.src[1:] == b2.src[1:] else None), - # remove reindexing - (UPat(Ops.INDEX).f(Ops.BUFFERIZE, allow_any_len=True, name="b2").f(Ops.INDEX, allow_any_len=True, name="idx2"), remove_bufferize), + (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), + lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \ + and idx.src[0].op is not Ops.BUFFER_VIEW else None), + # remove reindexing with cost function + (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize), # no buffers for const - #(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)), + (UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), + lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape).replace(tag=b.tag)), + # if any CONST with DEVICE make it here (symbolic/copy issue), remove it + #(UPat(Ops.DEVICE).f(Ops.CONST, name="c"), lambda c: c.replace(src=())), + # copy on CONST is CONST + (UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)), + (UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b") + .f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize), ]) +# ***************** # 4. put in buffers for bufferize # TODO: should BUFFERIZE look a lot more like STORE # BUFFERIZE has device in arg @@ -332,17 +414,40 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([ def bufferize_to_store(x:UOp): rngs = x.src[1:] shape = tuple([int(r.vmax+1) for r in rngs]) - sdtype = x.dtype.ptr(size=prod(shape), addrspace=AddrSpace.GLOBAL if not isinstance(x.arg, AddrSpace) else x.arg) - assert prod(shape) > 0, f"no zero sized buffers {shape}" + sym_shape = tuple([ssimplify(r.src[0]) for r in rngs]) + size = prod(shape) + assert size > 0, f"no zero sized buffers {shape}" + + sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace) if x.src[0].op is Ops.ASSIGN: - assign_target, assign_src = x.src[0].src + assign_target, assign_src, assign_mops = x.src[0].src assert assign_target.op is Ops.INDEX - return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype) + # in assign, this is the buffer size, not the bufferize size + # TODO: assign_mops here + ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype) + mops = [] + walk = assign_mops + while walk is not assign_mops.base: + mops.append((walk.op, walk.arg)) + walk = walk.src[0] + for m in mops[::-1]: ret = ret._mop(*m) + return ret.forced_reshape(shape).replace(tag=x.tag) + + # NOTE: the DEFINE_LOCAL needs to be disambiguated here if sdtype.addrspace == AddrSpace.GLOBAL: - buf = UOp.new_buffer(x.arg, prod(shape), x.dtype) - else: - # TODO: how to dedup this - buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=UOp.unique().arg) + buf = UOp.new_buffer(x.arg.device, size, x.dtype) + ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype) + ret = ret.forced_reshape(shape) + # TODO: is this right? 
what if it's offset + if shape is not sym_shape: ret = ret.shrink(tuple([(0,x) for x in sym_shape])) + return ret.replace(tag=x.tag) + + # handle locals + tag = x.arg.device + if tag is None: tag = UOp.unique().arg # TODO: hack + buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) + # store has the other dtype here + # TODO: how is this unified? return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype) pm_add_buffers = pm_mops+PatternMatcher([ @@ -353,6 +458,7 @@ pm_add_buffers = pm_mops+PatternMatcher([ lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)), ]) +# ***************** # 5. split into kernels @dataclass @@ -360,6 +466,7 @@ class LocalAddBufferContext: dg:int = 0 map:dict = field(default_factory=dict) vars:dict = field(default_factory=dict) + range:int = 0 def debuf(ctx:LocalAddBufferContext, buf:UOp): ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg) @@ -379,6 +486,12 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp): ctx.map[buf] = assign return buf +def renumber_range(ctx:LocalAddBufferContext, r:UOp): + if r.tag is not None: return None + ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=()) + ctx.range += 1 + return ret + to_define_global = PatternMatcher([ (UPat(Ops.BUFFER, name="buf"), debuf), (UPat(Ops.BIND, name="b"), unbind_kernel), @@ -386,59 +499,101 @@ to_define_global = PatternMatcher([ # HACK in case any CONSTs were replaced # this is only needed if you are using symbolic - #(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None), + (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None), + + # renumber the ranges starting with 0 so that kernel deduping works + (UPat(Ops.RANGE, name="r"), renumber_range), ]) rangeify_codegen = PatternMatcher([ + # no NOOP in the kernel graph + # TODO: this can be moved into codegen? + (UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]), + + # strip the arg from store + (UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None), + # add loads to non ptr indexes # TODO: this can be moved into codegen? 
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), # TODO: this can be moved into codegen (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), + + # TODO: hack for group for reduce + (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), + lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), ]) -def split_store(x:UOp): +def split_store(ctx:list[UOp], x:UOp): if len(x.ranges): return None - ctx = LocalAddBufferContext() - ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=ctx, name="kernel split", bottom_up=True) + if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None - # get name - rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda x: x.arg) - name = "k"+colored('_', 'BLACK').join(['']+[colored(s.src[0].render(), "WHITE" if s in ret.src[2:] else "red") for s in rng]) + # local kernel rewrite + lctx = LocalAddBufferContext() + ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True) + + # gather the metadata + metadatas = [ctx[y].metadata for x in ret.sparents if x.tag is not None for y in x.tag] # NOTE: the hack for COPY is here - ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1] - kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,())) + ret = ret.sink() if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1] + kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))) + kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) return x.as_buf().assign(kernel) split_kernels = PatternMatcher([ (UPat(Ops.STORE, name="x"), split_store), ]) -@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True) -def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: - tensor_map = graph_rewrite_map(sink, multi_pm+earliest_rewrites, name="earliest") - realize_map: dict[UOp, UOp] = {} - graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph") - tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous") - tensor_map = graph_rewrite_map(tensor_map[sink], remove_tags, input_map=tensor_map, name="cleanup") - tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="children") - tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify") - # NOTE: running symbolic can break the graph, leaving RANGE/INDEX/BUFFERIZE in the final graph - #tensor_map = graph_rewrite_map(tensor_map[sink], symbolic_simple, input_map=tensor_map, name="symbolic") - tensor_map = graph_rewrite_map(tensor_map[sink], pm_cleanups, bottom_up=True, input_map=tensor_map, name="cleanups") - if getenv("VIZ"): graph_rewrite(tensor_map[sink], 
PatternMatcher([]), name="View Rangeify Graph") +def tag_uop(ctx:list[UOp], x:UOp): + if x.tag is not None: return None + ctx.append(x) + return x.replace(tag=(len(ctx)-1,)) +add_tags = PatternMatcher([ + # don't tag BUFFERs, they are global + (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}.union(GroupOp.Movement), name="x"), tag_uop), +]) - tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, bottom_up=True, input_map=tensor_map, name="add buffers") - tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels") +@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True) +def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: + uop_list: list[UOp] = [] + tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops") + + # HACKS: handle multi with graph_rewrite_map in order to not have to add all the tag logic to multi + msink = graph_rewrite_map(tsink, multi_pm, name="multi") + tsink = msink[tsink].substitute({v:v.rtag(k.tag) for k,v in msink.items() if v.tag is None and k.tag is not None}) + + tsink = graph_rewrite(tsink, earliest_rewrites, name="earliest rewrites") + realize_map: dict[UOp, UOp] = {} + graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph") + # NOTE: we don't use contiguous here, contiguous is a user op + tsink = graph_rewrite(tsink, add_contiguous, ctx=realize_map, bottom_up=True, name="add realize") + tsink = graph_rewrite(tsink, remove_contig_tags, name="remove contiguous tags") + tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children") + + # rangeify + tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify") + # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right + tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding + tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers") + + # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph + # if it's not tagged by here, it's out + tsink = UOp.sink(*[x for x in tsink.parents if (x.op is Ops.BUFFERIZE or x.base.op in {Ops.CONST}) and x.tag is not None]) + + if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify") + + # bufferize -> store + tsink = graph_rewrite(tsink, pm_add_buffers, bottom_up=True, name="bufferize to store") + tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels") # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign kernel_assign: dict[UOp, UOp] = {} assign_rep: dict[UOp, UOp] = {} - for u in tensor_map[sink].toposort(): + for u in tsink.toposort(): if u.op is not Ops.ASSIGN: continue kernel_assign[u.buf_uop] = u for s in u.src[1].src: @@ -447,8 +602,14 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()): raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) - if assign_rep: - tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign") + if assign_rep: tsink = 
graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign") - if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph") - return tensor_map + if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph") + + becomes_map: dict[UOp, UOp] = {} + for s in tsink.src: + assert s.tag is not None + for a in s.tag: + if a is None: continue + becomes_map[uop_list[cast(int, a)]] = s.replace(tag=None) + return becomes_map diff --git a/tinygrad_repo/tinygrad/shape/shapetracker.py b/tinygrad_repo/tinygrad/shape/shapetracker.py index 3abddc69..dca69bbe 100644 --- a/tinygrad_repo/tinygrad/shape/shapetracker.py +++ b/tinygrad_repo/tinygrad/shape/shapetracker.py @@ -5,45 +5,25 @@ import functools from typing import Callable from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, unravel -from tinygrad.dtype import dtypes -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp -from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid - -# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation, -# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`. -def handle_upcast(u: UOp) -> UOp|None: - dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64 - # check for overflow, upcast this to int64 - if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int): - return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])) - # if any inputs are int64 and this *doesn't* overflow, cast back to int - if any(x.dtype == dtypes.int64 for x in u.src): - return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype) - return None -pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),]) +from tinygrad.uop.symbolic import sym +from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context @functools.cache -def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]: - idx, valid = views[-1].to_indexed_uops(_idxs) +def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> UOp: + idx = views[-1].to_valid_uop(_idxs) for view in reversed(views[0:-1]): view = view.minify() - idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid) + idx = view.to_valid_uop([sint_to_uop(i) for i in unravel(view.shape, idx)]) with Context(TRACK_MATCH_STATS=0): - # symbolic - idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src - # simplify - if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid - if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx - # symbolic again, upcast if needed - return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src + return graph_rewrite(idx, sym, name="indexing sym @ 1") @functools.cache def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]: # NOTE: if a stride is not always valid, it will be None if len(views) == 1 and views[-1].mask is None: return views[-1].strides ret: list[sint|None] = [None] * len(views[-1].shape) - idx, valid = views_to_indexed_uops(views) - for c in split_uop(idx, Ops.ADD): + idx, valid = 
(vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid() + for c in idx.split_uop(Ops.ADD): if c.op is Ops.RANGE: ret[c.arg[0]] = 1 if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg @@ -83,21 +63,21 @@ class ShapeTracker: def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape)) - def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]: - return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None) + def to_valid_uop(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> UOp: + return views_to_valid_uop(self.views, tuple(_idxs) if _idxs is not None else None) # upper bound on buffer size required to fit this shapetracker def real_size(self) -> int: if 0 in self.shape: return 0 view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v) - idx, _ = views_to_indexed_uops((view,)) + idx = views_to_valid_uop((view,)).get_idx() assert idx.vmax < 1e12, f"real_size broken for {self}" return int(idx.vmax + 1) def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views]) @property - def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()]) + def var_vals(self) -> dict[str, int]: return merge_dicts([{(vu:=v.unbind())[0].expr:vu[1]} for v in self.vars()]) def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]: unbound_views, var_vals = zip(*[v.unbind() for v in self.views]) @@ -109,11 +89,6 @@ class ShapeTracker: with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid) def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] - def axis_is_masked(self, axis:int) -> bool: - with Context(TRACK_MATCH_STATS=0): - _, valid = self.to_indexed_uops() - return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort() if x.op is Ops.RANGE] - def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: return ShapeTracker(self.views[:-2] + (new_view,)).simplify() diff --git a/tinygrad_repo/tinygrad/shape/view.py b/tinygrad_repo/tinygrad/shape/view.py index 0475fc65..37da1564 100644 --- a/tinygrad_repo/tinygrad/shape/view.py +++ b/tinygrad_repo/tinygrad/shape/view.py @@ -3,7 +3,7 @@ import functools, operator, itertools from dataclasses import dataclass from typing import cast, Sequence from tinygrad.dtype import dtypes -from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify +from tinygrad.uop.ops import resolve, UOp, Variable, sint, smax, smin, sint_to_uop, Ops, ssimplify from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape @@ -112,16 +112,17 @@ class View: mask:tuple[tuple[sint, sint], ...]|None contiguous:bool - def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]: - """(idx, valid)""" - if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)] + def to_valid_uop(self, idxs:Sequence[UOp]|None=None) -> UOp: + """valid.where(idx, INVALID)""" + if idxs is None: idxs = [UOp.range(s, i) 
for i,s in enumerate(self.shape)] iexpr = sint_to_uop(self.offset) + where = UOp.const(dtypes.bool, True) for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)): - if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st + iexpr = iexpr + idx*sint_to_uop(st) if m is not None: - if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0]) - if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1]) - return iexpr, vexpr + if resolve(m[0] != 0): where &= (idx >= sint_to_uop(m[0])) + if resolve(m[1] != sh): where &= (idx < sint_to_uop(m[1])) + return where.where(iexpr, UOp.invalid()) @functools.cache # pylint: disable=method-cache-max-size-none def size(self) -> int: @@ -204,15 +205,15 @@ class View: # Merge dimensions in vm2 if required. # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required. - idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)] - merged_size, merged_term = 1, UOp.const(dtypes.int, 0) + idxs: list[UOp] = [UOp.variable(f"idx{i}", 0, s-1, dtypes.index) for i,s in enumerate(vm1.shape)] + merged_size, merged_term = 1, UOp.const(dtypes.index, 0) extents: list[tuple[sint, UOp]] = [] for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)): merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size merged_size *= s if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False): extents.append((merged_size, merged_term)) - merged_size, merged_term = 1, UOp.const(dtypes.int, 0) + merged_size, merged_term = 1, UOp.const(dtypes.index, 0) if resolve(merged_term != 0): return None if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape: if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None @@ -311,9 +312,7 @@ class View: if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}") # check for the same size - if (self_all_int := all_int(self.shape)): - assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" - if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + if resolve(prod(self.shape) != prod(new_shape), True): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") if 0 in self.shape: return View.create(new_shape) if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None @@ -321,15 +320,6 @@ class View: # after the asserts, it's okay to check contiguous if self.contiguous: return View.create(new_shape) - # if it's not contiguous and new shape is symbolic, check if it's directly replaceable - if self_all_int and not all_int(new_shape): - if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") - for si, so in zip(self.shape, new_shape): - if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()])) - if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") - # all dimensions matched, return the new view directly - return View(new_shape, self.strides, self.offset, self.mask, self.contiguous) - r_strides, r_new_shape = [], reversed(new_shape) for merged_size, new_stride, real_size in reversed(merge_dims(self.shape, self.strides, self.mask)): # TODO: write with get_contraction diff 
--git a/tinygrad_repo/tinygrad/tensor.py b/tinygrad_repo/tinygrad/tensor.py index 10dd0467..a1b0c6ac 100644 --- a/tinygrad_repo/tinygrad/tensor.py +++ b/tinygrad_repo/tinygrad/tensor.py @@ -6,9 +6,10 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup -from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY +from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION from tinygrad.gradient import compute_gradient -from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, Variable, MathTrait, identity_element, all_metadata +from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop, \ + srender from tinygrad.uop.spec import tensor_uop_spec, type_verify from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule @@ -68,7 +69,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp: ret = UOp.new_buffer("PYTHON", prod(shape:=get_shape(x)), dtype).reshape(shape) assert dtype.fmt is not None, f"{dtype=} has None fmt" truncate_function = truncate[dtype] - data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)]) + data = struct.pack(f"{ret.size}{dtype.fmt}", *[truncate_function(dtypes.as_const(xi, dtype)) for xi in fully_flatten(x)]) # fake realize ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data))) return ret @@ -98,7 +99,8 @@ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor: # reduce such that if mask contains repeated indices the last one remains - for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) + for dim in reversed(axes): + mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) # remove extra dims from reduce for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim) # select from values for each True element in mask else select from target @@ -139,11 +141,13 @@ class Tensor(MathTrait): # create a UOp from the different types of inputs if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" - if data.op is Ops.BIND: - var, val = data.unbind() + # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of + if data.dtype==dtypes.index: data = index_to_concrete_int(data) + if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here + var, val = data.unbind() # type: ignore # give the bound constant a device const = UOp.const(var.dtype, val, device, ()) - data = data.replace(src=(var.replace(src=const.src), const)) + data = data.replace(src=(var.replace(src=const.src), const)) # type: ignore elif data is None: data = 
UOp.const(dtype or dtypes.default_float, 0, device, ()) elif isinstance(data, get_args(ConstType)): data = UOp.const(dtype or dtypes.from_py(data), data, device, ()) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype) @@ -175,7 +179,9 @@ class Tensor(MathTrait): # add to all_tensors after construction succeeds all_tensors[weakref.ref(self)] = None - def __del__(self): all_tensors.pop(weakref.ref(self), None) + def __del__(self): + try: all_tensors.pop(weakref.ref(self), None) + except Exception: pass def _apply_uop(self, fxn:Callable, *x:Tensor, extra_args=(), **kwargs) -> Tensor: new_uop: UOp = fxn(*[t.uop for t in (self,)+x], *extra_args, **kwargs) @@ -239,7 +245,7 @@ class Tensor(MathTrait): _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map") return self - def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]: + def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]: """ Creates the schedule needed to realize these Tensor(s), with Variables. @@ -364,7 +370,7 @@ class Tensor(MathTrait): """ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" import numpy as np - if self.dtype.base == dtypes.bfloat16: return self.float().numpy() + if self.dtype.base in { dtypes.bfloat16, *dtypes.fp8s }: return self.float().numpy() if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base)) return self._buffer().numpy().reshape(self.shape) @@ -442,7 +448,7 @@ class Tensor(MathTrait): if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}") # TODO: add test for multidevice tensor device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device) - return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).reshape(shape) + return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape) @staticmethod def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: @@ -628,6 +634,7 @@ class Tensor(MathTrait): """ if stop is None: stop, start = start, 0 dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) + if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}") # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) @@ -988,6 +995,8 @@ class Tensor(MathTrait): # resolve -1 if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) + if resolve(prod(self.shape) != prod(new_shape), True): + raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})") return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self def expand(self, shape, *args) -> Tensor: @@ -1060,6 +1069,7 @@ class Tensor(MathTrait): print(t.shrink((((0, 2), (0, 2)))).numpy()) ``` """ + if self.ndim != 
len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}") if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg)) @@ -1126,6 +1136,10 @@ class Tensor(MathTrait): X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d) return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))) + # convenience + def pad_to(self, shape, *args): return self.pad(tuple([(0, ns-s) for s,ns in itertools.zip_longest(self.shape, argfix(shape, *args))])) + def shrink_to(self, shape, *args): return self.shrink(tuple([(0, ns) for ns in argfix(shape, *args)])) + # ***** movement high level ops ***** def _getitem(self, indices, v: Tensor|None = None) -> Tensor: @@ -1163,6 +1177,9 @@ class Tensor(MathTrait): boundary, stride = [start, stop], step if all(isinstance(s, int) for s in (start,stop,step)): # handle int slicing + # if we're slicing a symbolic dimension into an int dimension, we can slice until the bind size + # TODO: right now this is using vmax instead of the bind size because the JIT doesn't update the bound value of the returned tensor + if isinstance(size, UOp): size = int(size.vmax) *boundary, stride = index.indices(cast(SupportsIndex, size)) if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1] @@ -1215,8 +1232,8 @@ class Tensor(MathTrait): x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) # special permute case - if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)): - x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim)) + if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))): + mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x)) # for advanced setitem, returns whole tensor with indices replaced if v is not None: @@ -1224,7 +1241,7 @@ class Tensor(MathTrait): # add back reduced dims from sum for dim in sum_axis: vb = vb.unsqueeze(dim) # run _masked_setitem on tuple of axes that is to be reduced to match self.shape - x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape)))) + x = _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape)))) return x @@ -1725,7 +1742,7 @@ class Tensor(MathTrait): ``` """ ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim) - return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret + return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16, *dtypes.fp8s) else ret def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor: """ @@ -2255,7 +2272,7 @@ class Tensor(MathTrait): xs:tuple[Tensor, ...]
= argfix(*operands) inputs_str, output = parse_formula(formula, *xs) inputs = inputs_str.split(",") - assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}" + if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}") # map the value of each letter in the formula letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items()) @@ -3099,6 +3116,7 @@ class Tensor(MathTrait): print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy()) ``` """ + if self.is_floating_point(): return ((math.pi/2)-self.cast(least_upper_dtype(self.dtype, dtypes.float32))).sin().cast(self.dtype) return ((math.pi/2)-self).sin() def tan(self) -> Tensor: @@ -3892,7 +3910,8 @@ class Tensor(MathTrait): def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor: if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}") offset = self.ndim - self._resolve_dim(dim) - 1 - return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset) + dt = dtypes.int64 if sint_to_uop(num_classes).overflows(dtypes.int32) else dtypes.int32 + return self == Tensor.arange(num_classes, dtype=dt, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset) def one_hot(self, num_classes:int=-1) -> Tensor: """ @@ -3930,7 +3949,11 @@ class Tensor(MathTrait): if enable_gqa: key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3) value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3) - qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1]) + + if FUSE_ATTENTION: q, key, value = self.contiguous(), key.contiguous(), value.contiguous() + else: q = self + + qk = q.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(q.dtype, key.dtype, dtypes.float32)) / math.sqrt(q.shape[-1]) # handle attention mask if is_causal: if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True") @@ -3938,7 +3961,8 @@ class Tensor(MathTrait): if attn_mask is not None: if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf")) qk = qk + attn_mask - return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value + attn = qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value + return attn.fuse() if FUSE_ATTENTION else attn def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor: if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}") @@ -4206,11 +4230,6 @@ class Tensor(MathTrait): # ***** cast ops ***** - def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor: - # hack for devices that don't support bfloat16 - assert self.dtype == dtypes.bfloat16 - return self.to("LLVM").cast(dtype) - def cast(self, dtype:DTypeLike) -> Tensor: """ Casts `self` to the given `dtype`. 
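The scaled_dot_product_attention hunk above gates attention fusion behind the new FUSE_ATTENTION flag: when set, q/k/v are forced contiguous and the softmax(qk) @ v chain is marked with .fuse(). A minimal sketch of exercising it, assuming FUSE_ATTENTION is read from the environment at import time like the other tinygrad flags imported from helpers (the shapes here are illustrative):

import os
os.environ["FUSE_ATTENTION"] = "1"  # assumption: picked up at import time like other tinygrad env flags
from tinygrad import Tensor

q, k, v = [Tensor.randn(1, 8, 128, 64) for _ in range(3)]   # (batch, heads, seq_len, head_dim)
out = q.scaled_dot_product_attention(k, v, is_causal=True)  # qk, softmax, and @ v scheduled as one fused unit when the flag is set
print(out.shape)  # (1, 8, 128, 64)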
diff --git a/tinygrad_repo/tinygrad/uop/__init__.py b/tinygrad_repo/tinygrad/uop/__init__.py index 2a453a7f..1cab5641 100644 --- a/tinygrad_repo/tinygrad/uop/__init__.py +++ b/tinygrad_repo/tinygrad/uop/__init__.py @@ -22,6 +22,7 @@ class Ops(FastEnum): # ops that adjust the behavior of the scheduler CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702 + REALIZE = auto() # blocks in linearizer (only used there) BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702 diff --git a/tinygrad_repo/tinygrad/uop/decompositions.py b/tinygrad_repo/tinygrad/uop/decompositions.py index dd807f90..cc3e5cf0 100644 --- a/tinygrad_repo/tinygrad/uop/decompositions.py +++ b/tinygrad_repo/tinygrad/uop/decompositions.py @@ -293,7 +293,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None: if (ret:=fast_idiv(device, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret if dont_cast: return None # promo_lattice needs to return an unsigned type if the type is unsigned - if dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, None if device=='' else device): + if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, None if device=='' else device): if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype): return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0) return None @@ -333,6 +333,8 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False): if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5)))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] + if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), + lambda x,y: (x | y).logical_not())] # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)] if Ops.SHR in ops: @@ -341,7 +343,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False): pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where( c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v if not DISABLE_FAST_IDIV: - pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))] + pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))] pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))] if Ops.NEG in ops: pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))] diff --git a/tinygrad_repo/tinygrad/uop/mathtraits.py b/tinygrad_repo/tinygrad/uop/mathtraits.py index 05bf89d8..0de976c9 100644 --- a/tinygrad_repo/tinygrad/uop/mathtraits.py +++ b/tinygrad_repo/tinygrad/uop/mathtraits.py @@ -167,3 +167,4 @@ class MathTrait: def log2(self): return self.alu(Ops.LOG2) def exp2(self): return self.alu(Ops.EXP2) def pow(self, x): return self.alu(Ops.POW, 
self.ufix(x)) + def __pow__(self, x): return self.pow(x) diff --git a/tinygrad_repo/tinygrad/uop/ops.py b/tinygrad_repo/tinygrad/uop/ops.py index 91c9e0c2..01b30424 100644 --- a/tinygrad_repo/tinygrad/uop/ops.py +++ b/tinygrad_repo/tinygrad/uop/ops.py @@ -1,19 +1,23 @@ from __future__ import annotations from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence -import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref +import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections from dataclasses import dataclass, field from enum import Enum, auto from tinygrad.uop import Ops, GroupOp from tinygrad.uop.mathtraits import MathTrait -from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType +from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA -from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey +from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer, MultiBuffer class AxisType(Enum): - GLOBAL = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 + def __repr__(self): return str(self) + GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 + THREAD = auto() + +range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3} # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) @@ -37,7 +41,7 @@ def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min) def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x) def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop -def sym_infer(uop: UOp|int, var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop +def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop # used for UOp and UPat def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str: @@ -102,6 +106,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg) def tagstr(self): return f", tag={self.tag}" if self.tag is not None else "" + def f(self, op, **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,), **kwargs) + @functools.cached_property def parents(self:UOp) -> dict[UOp, None]: ret = {s:None for s in self.src} @@ -135,6 +141,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def tuplize(self:UOp) -> tuple: return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src]) + @property + def ptrdtype(self) -> PtrDType: + if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType") + return self.dtype + 
# *** uop shape stuff *** @functools.cached_property @@ -142,18 +153,19 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER, Ops.BUFFERIZE, Ops.VECTORIZE, Ops.STORE}: return None + if self.op is Ops.BARRIER: return None if self.op in GroupOp.Block: return None from tinygrad.shape.shapetracker import ShapeTracker # VIEW and MovementOps define a new ShapeTracker from the arg if self.op is Ops.VIEW: return self.arg - if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape((prod(tuple([int(r.vmax+1) for r in self.src[1:]])),)) - #if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape(tuple([r.vmax+1 for r in self.src[1:]])) + if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape(tuple([int(r.vmax+1) for r in self.src[1:]])) # allow reshape from nothing if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg) if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg) # CONST with a DEVICE has a shape of () if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(()) if self.op is Ops.STORE and isinstance(self.dtype, PtrDType): return ShapeTracker.from_shape((self.dtype.size,)) + if self.op is Ops.STORE and self.dtype is not dtypes.void: return self.src[0].src[0].st # BufferOps and ASSIGN flow ShapeTracker from a direct edge if self.op in {Ops.STORE, Ops.ASSIGN, Ops.LOAD}: return self.src[0].st if self.op in GroupOp.Buffer: return views[0] if (views:=[x.st for x in self.src if x.op is Ops.VIEW]) else None @@ -162,7 +174,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,)) if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,)) if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: - sz = cast(PtrDType, self.dtype).size + sz = self.ptrdtype.size return ShapeTracker.from_shape((sz,)) if sz > 0 else None # CONTIGUOUS with RANGE @@ -202,17 +214,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def ranges(self) -> dict[UOp, None]: if self.op is Ops.RANGE: return {self:None} - if self.op in {Ops.BUFFERIZE, Ops.REDUCE}: - ret = self.src[0].ranges.copy() - for s in self.src[1:]: - if s in ret: del ret[s] - elif self.op in {Ops.STORE}: - ret = self.src[0].ranges.copy() - ret.update(self.src[1].ranges) - for s in self.src[2:]: + ret: dict[UOp, None] = {} + if self.op in range_start.keys(): + for s in self.src[:range_start[self.op]]: ret.update(s.ranges) + for s in UOp.sink(*self.src[range_start[self.op]:]).ranges: if s in ret: del ret[s] else: - ret = {} for s in self.src: ret.update(s.ranges) return ret @@ -251,7 +258,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7] assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}" return ret - def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs) + def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument + return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, *srcs:UOp|None, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for 
x in srcs if x is not None]), **kwargs) @@ -287,20 +295,27 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(op, out_dtype, (self,)+src, **kwargs) @staticmethod - def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None): + def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None): if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same - ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype)) - if shape is not None: - from tinygrad.shape.shapetracker import ShapeTracker - ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),)) - if device is not None: - if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),)) - else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) + ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,)) + if RANGEIFY: + # VIEW on const is no longer supported in RANGEIFY + if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) + if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape) + else: + if shape is not None: + from tinygrad.shape.shapetracker import ShapeTracker + ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),)) + if device is not None: + if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),)) + else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),)) return ret @staticmethod - def range(dtype:DType, end:sint, idx:int, axistype:AxisType=AxisType.LOOP): - return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=(idx, axistype)) + def range(end:sint, *arg): + if len(arg) == 0: raise RuntimeError("range needs an arg") + if len(arg) == 1: arg = arg+(AxisType.LOOP,) + return UOp(Ops.RANGE, dtype=dtypes.index, src=(sint_to_uop(end),), arg=arg) def r(self, op:Ops, axis:tuple[int, ...]): axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) if len(axis) == 0: return self @@ -312,14 +327,31 @@ class UOp(MathTrait, metaclass=UOpMetaClass): assert len(axis) == len(new_axis) ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis)) return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)])) + @staticmethod + def invalid(): return UOp(Ops.CONST, dtypes.index, src=(), arg=Invalid) + def get_idx(self) -> UOp: + assert self.dtype is dtypes.index, "Can only call get_idx on index dtype" + return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self + def get_valid(self) -> UOp: + assert self.dtype is dtypes.index, "Can only call get_valid on index dtype" + return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs) + def realize(self, *args, **kwargs): return UOp(Ops.REALIZE, dtype=self.dtype, 
src=(self,)+args, **kwargs) def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD) def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs) def fuse(self): return self.alu(Ops.FUSE) def allreduce(self, op, device:str|tuple[str, ...]|UOp): assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't" return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op) + def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax + + # *** ShapeTracker helpers *** + + def split_uop(self:UOp, sep:Ops): + if self.op is sep: + for s in self.src: yield from s.split_uop(sep) + else: yield self # *** from MultiLazyBuffer *** @@ -467,11 +499,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop Variable stuff *** @staticmethod - def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int) -> UOp: + def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp: assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property - def expr(self): + def expr(self) -> str: assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" return self.arg[0] def bind(self, val:int|UOp): @@ -518,7 +550,23 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure - def pop_const(self) -> tuple[UOp, int]: return (self.src[0], self.src[1].arg) if self.op is Ops.ADD and self.src[1].op is Ops.CONST else (self, 0) + def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]: + return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype)) + @staticmethod + def gcd(*uops: UOp) -> UOp: + terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in uops]) + count = functools.reduce(operator.and_, [collections.Counter(term.split_uop(Ops.MUL)) for term in terms]) + return math.prod([*count.elements(), terms[0].const_like(math.gcd(*factors))]) # put the const at the top + def divide_exact(self, v:UOp) -> UOp|None: + if self is v: return self.const_like(1) + if self.op is Ops.ADD: return None if (s0:=self.src[0].divide_exact(v)) is None or (s1:=self.src[1].divide_exact(v)) is None else s0+s1 + if v.op is Ops.CONST: return self.divides(v.arg) + if self.op is Ops.MUL: + (fac, const), (div_fac, div_const) = self.pop_const(Ops.MUL), v.pop_const(Ops.MUL) + new_count = collections.Counter(fac.split_uop(Ops.MUL)) + new_count.subtract(div_fac.split_uop(Ops.MUL)) + if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod([*new_count.elements(), self.const_like(const//div_const)]) + return None # generic None if we aren't sure @property def vmin(self) -> ConstType: return self._min_max[0] @property @@ -551,15 +599,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax) # NOTE: returned UOp is assumed to be CONST if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] - if self.op is Ops.RANGE: return 0, (self.src[0]-1).vmax + 
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) - # TODO: Ops.SPECIAL is Ops.DEFINE_VAR - if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax - if self.op is Ops.CONST: return self.arg, self.arg - if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) + if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg + if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg)) + if self.op is Ops.GEP: return self.src[0]._min_max # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified - if self.op is Ops.CAST and self.dtype in (dtypes.floats+dtypes.sints): + if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,): return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype)) return dtypes.min(self.dtype), dtypes.max(self.dtype) @@ -570,9 +617,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # TODO: sanitize varnames, or don't use naked eval while staying fast return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used - def sym_infer(self, var_vals:dict[UOp, int]): + def sym_infer(self, var_vals:dict[str, int]): fxn, varnames = self._sym_fxn - return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames}) + return fxn(**{k:v for k,v in var_vals.items() if k in varnames}) def render(self, simplify=True, pm:PatternMatcher|None=None) -> str: with Context(TRACK_MATCH_STATS=0): @@ -611,6 +658,7 @@ python_alu: dict[Ops, Callable] = { def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): if dtype.count > 1: return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)]) + if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid alu = python_alu[op](*operands) return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu @@ -681,9 +729,10 @@ class UPat(MathTrait): def var(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None): return UPat(dtype=dtype, name=name) @staticmethod @functools.cache - def cvar(name:str|None=None, dtype:DType|None=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name) + def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True): + return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name) @staticmethod - def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b) + def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b) # lil helper def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs) @@ -700,7 +749,8 @@ class UPat(MathTrait): def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs) def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs) def fuse(self): return self.alu(Ops.FUSE) - def or_broadcasted(self, **kwargs): return UPat.any(self, UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)) + def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs) + def 
or_broadcasted(self, **kwargs): return UPat.any(self, self.broadcast(**kwargs)) def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs) def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) @@ -802,7 +852,6 @@ def track_uop(u:UOp): # *** tracking pattern matcher *** -VIZ = ContextVar("VIZ", 0) TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0) match_stats:dict[UPat, list[int|float]] = dict() @@ -826,20 +875,22 @@ if getenv("CAPTURE_PROCESS_REPLAY"): def save_to_diskcache(): for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True) +def add_trace_group(kt:TracingKey) -> None: + tracked_keys.append(kt) + tracked_ctxs.append([]) + def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False): def _decorator(func): def __wrapper(*args, **kwargs): fn = key = func.__name__ - if TRACK_MATCH_STATS >= 2: - tracked_keys.append(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,), cat=fn)) - tracked_ctxs.append([]) + if TRACK_MATCH_STATS >= 2: add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,))) with cpu_profile(key, "TINY") as e: ret = func(*args, **kwargs) if TRACK_MATCH_STATS >= 2 and callable(name): name_ret = name(*args, **kwargs, ret=ret) assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}" tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret - e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys, cat=fn) + e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys) if getenv("CAPTURE_PROCESS_REPLAY") and replay: # find the unittest frame we're capturing in frm = sys._getframe(1) @@ -856,9 +907,10 @@ def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=Fal active_rewrites:list[TrackedGraphRewrite] = [] def track_matches(func): def _track_func(*args, **kwargs): - if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs): + if tracking:=(TRACK_MATCH_STATS >= 2): loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno) depth = len(active_rewrites) + if not tracked_ctxs: add_trace_group(TracingKey(f"default {func.__name__}")) tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, track_uop(args[0]), [], kwargs.get("name", None), depth, kwargs.get("bottom_up", False))) active_rewrites.append(ctx) with cpu_profile(kwargs.get("name", ""), "TINY", display=tracking): @@ -901,8 +953,8 @@ if TRACK_MATCH_STATS or PROFILE: if TRACK_MATCH_STATS >= 2: with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") - pickle.dump((tracked_keys, tracked_ctxs, uop_fields), f) - if VIZ: launch_viz(VIZ, temp("rewrites.pkl", append_user=True)) + pickle.dump([(tracked_keys, tracked_ctxs, uop_fields)], f) + if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): @@ -912,11 +964,10 @@ if TRACK_MATCH_STATS or PROFILE: print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL") 
print(f"{len(match_stats)} rules, {sum(v[0] > 0 for v in match_stats.values())} matched once") - def launch_viz(var:ContextVar, data:str): - os.environ[(env_str:=var.key)] = "0" + def launch_viz(env_str:str, data:str): + os.environ[env_str] = "0" os.environ[f"{env_str}_DATA"] = data - os.environ[f"{env_str}_VALUE"] = str(var.value) - if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")): + if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")) and not int(os.getenv("SQTT", "0")): args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else [] args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else [] os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "../", "viz", "serve.py")] + args) @@ -924,6 +975,7 @@ if TRACK_MATCH_STATS or PROFILE: # *** simple graph rewrite engine *** class RewriteNotReady(Exception): pass +class BottomUpGate(Exception): pass class RewriteContext: def __init__(self, pm, bpm, ctx=None): self.pm: PatternMatcher|None = pm @@ -951,20 +1003,23 @@ class RewriteContext: if n in self.replace: continue # skip any nodes we have seen try: if stage == 0: - # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack - if self.bpm is not None: - # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match - test_n: UOp|None = n - seen = set() - while test_n is not None: - if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite") - seen.add(test_n) - new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) - stack.append((n, 1, new_n)) - for x in reversed(new_n.src): stack.append((x, 0, x)) + try: + # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack + if self.bpm is not None: + # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match + test_n: UOp|None = n + seen = set() + while test_n is not None: + if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite") + seen.add(test_n) + new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) + stack.append((n, 1, new_n)) + for x in reversed(new_n.src): stack.append((x, 0, x)) + # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs + except BottomUpGate: self.replace[n] = new_n elif stage == 1: try: new_src = tuple([self.replace[x] for x in new_n.src]) - except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from + except KeyError: raise RewriteNotReady if new_src == new_n.src: # if top down, do the rewrite. 
if no rewrite or bottom up, we are done rewriting this node so we add it to the dict if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None: @@ -979,7 +1034,7 @@ class RewriteContext: else: # in stage 2, we link the result of new_n to the result of n try: self.replace[n] = self.replace[new_n] - except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from + except KeyError: raise RewriteNotReady except RewriteNotReady: # retry this later stack.insert(0, (n, stage, new_n)) @@ -1002,7 +1057,30 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na for k,v in input_map.items(): new_map[k] = new_map.get(v,v) return new_map -def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x +def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.index, x) if isinstance(x, int) else x.cast(dtypes.index) + +def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count) +pm_lower_index_dtype = PatternMatcher([ + # There are no Unary ops at this point in symbolic, those are introduced later + (UPat(GroupOp.Binary, dtypes.index, name="u", src=(UPat.var("x"), UPat.var("y"))), lambda u,x,y: + x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt))), + # comparison ops might now have different dtypes in their sources + (UPat(GroupOp.Comparison, name="u", src=(UPat.var("x",dtypes.ints), UPat.var("y", dtypes.ints))), lambda u,x,y: + x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)).alu(u.op, y.cast(dt)) if x.dtype!=y.dtype else None), + (UPat(Ops.WHERE, dtype=dtypes.index, src=(UPat.var("cond"), UPat.var("x"), UPat.var("y")), name="u"), lambda cond,u,x,y: + cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt))), + (UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u))), + (UPat((Ops.RANGE,), dtype=dtypes.index, src=(UPat.var("end")), name="r"), lambda ctx,r,end: + r.replace(dtype=(dt:=select_dtype(r)), src=(end.cast(dt),))), + (UPat(Ops.CAST, dtype=dtypes.index, src=(UPat.var("x", dtypes.ints),), name="u"), lambda u,x: x), + (UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace( + dtype=(dt:=least_upper_dtype(*[x.dtype for x in u.src])).vec(u.dtype.count), src=tuple(x.cast(dt) for x in u.src))), + (UPat(Ops.VECTORIZE, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=(dt:=(dtypes.long if any(v.overflows(dtypes.int) for v in u.src) + else dtypes.int)).vec(u.dtype.count),src=tuple(x.cast(dt) for x in u.src))), + (UPat((Ops.SPECIAL,Ops.DEFINE_VAR), dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int)), + (UPat((Ops.BIND), dtypes.index, name="u"), lambda u: u.replace(dtype=u.src[0].dtype)), +]) +def index_to_concrete_int(u:UOp): return graph_rewrite(u, pm_lower_index_dtype) _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) @@ -1010,12 +1088,12 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} renderer = PatternMatcher([ - (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), + (UPat((Ops.DEFINE_VAR,), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), + (UPat((Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg)), (UPat(Ops.RANGE,
name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}" if x.arg[0] >= 0 else f"ridxm{-x.arg[0]}")), (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")), (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")), - (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")), (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")), (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), @@ -1023,7 +1101,8 @@ renderer = PatternMatcher([ (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), - (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")), + (UPat(set(syms.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")), + (UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")), ]) renderer_infer = PatternMatcher([ (UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")), @@ -1031,9 +1110,45 @@ renderer_infer = PatternMatcher([ *renderer.patterns ]) +sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce", + Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"} +pm_pyrender = PatternMatcher([ + (UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")), + (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")), + (UPat(Ops.CAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.cast({x.dtype})")), + (UPat(Ops.BITCAST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.bitcast({x.dtype})")), + (UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"), + lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")), + (UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: + UOp(Ops.NOOP, arg=f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])})")), + (UPat(set(sugar.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, + arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"), + lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")), + (UPat(Ops.VALID, src=(UPat(Ops.NOOP),), name="x"), + lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, dtype=dtypes.bool)")), +]) + +def pyrender(ast:UOp) -> list[str]: + cmap = ast.get_children_map() + to_render = set() + for u in ast.toposort(): + if u.op is Ops.STORE: to_render.add(u.src[1]) + if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, 
Ops.VIEW, Ops.LOAD} or u.op in {Ops.CONST}: continue + if u.op in {Ops.SINK, Ops.VIEW}: + for s in u.src: to_render.add(s) + to_render.add(u) + ret: list[str] = [] + rep: dict[UOp, UOp] = {} + for u in ast.toposort(): + if u not in to_render: continue + ret.append(f"c{len(ret)} = {u.substitute(rep).render(simplify=False, pm=pm_pyrender+renderer)}") + rep[u] = UOp(Ops.NOOP, arg=f"c{len(ret)-1}") + return ret[0:-1] + ["ast ="+ret[-1].split("=", 1)[1]] + # *** what was symbolic.py *** sint = int|UOp Variable = UOp -ConstLike = ConstType|Variable|tuple[ConstType, ...] +ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...] diff --git a/tinygrad_repo/tinygrad/uop/spec.py b/tinygrad_repo/tinygrad/uop/spec.py index f578ab7a..48737fc9 100644 --- a/tinygrad_repo/tinygrad/uop/spec.py +++ b/tinygrad_repo/tinygrad/uop/spec.py @@ -1,15 +1,21 @@ from typing import cast, Callable -from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite -from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace -from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context +from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType +from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid +from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context, cpu_profile from tinygrad.shape.shapetracker import ShapeTracker try: import z3 + # older versions of z3 don't have some operators like & overloaded + if z3.get_version() < (4, 12, 4, 0): raise ImportError # IDIV is truncated division but z3 does euclidean division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND def z3_cdiv(a, b): return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b) z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: lambda a,b,c: z3.If(a,b,c), Ops.MAX: lambda a,b: z3.If(a<b, b, a), Ops.TRUNC: lambda a: z3.If(a >= 0, z3.ToInt(a), -z3.ToInt(-a))} def create_bounded(name:str, vmin, vmax, solver:z3.Solver) -> z3.ArithRef: s = z3.Int(name, ctx=solver.ctx) @@ -17,33 +23,35 @@ try: return s # ctx is (solver, load_number_dict) + # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching.
z3 objects from different + # contexts can have the same hash but error on comparison z3_renderer = PatternMatcher([ - # Ops.SPECIAL can have symbolic arg but it wont be in the toposort beacuse its not a src, we need to add it manually - (UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))), - (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))), - (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))), - (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))), + (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))), + (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))), + (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))), # float loads only become a variable when they get cast to int/bool (UPat(Ops.LOAD, dtypes.ints, name="x"), - lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))), - (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"), - lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))), + lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), + (UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,dtypes.index), name="x"), + lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx)))), # z3 can cast from bool to int automatically - (UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), - (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=(x.src[0].arg!=0))), + (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), + (UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], x.src[0].arg[1]!=0))), # if the source of the cast is not a noop it means that it is a float and so we create a new variable - (UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx: - UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))), + (UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: + UOp(Ops.NOOP, arg=(ctx[0], create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0])))), (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx: - UOp(Ops.NOOP, arg=z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))), - (UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"), - lambda x: UOp(Ops.NOOP, arg=z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg, x.dtype.itemsize*8) for s in x.src))))), - (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))), + UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))), + (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0], z3_alu[x.op](*(s.arg[1] for s in x.src))))), # A comparison 
between floats introduces a new bool variable (UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx: - UOp(Ops.NOOP, arg=z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))), + UOp(Ops.NOOP, arg=(ctx[0], z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx)))), ]) + def uops_to_z3(solver, *uops: UOp) -> 'list[z3.ExprRef]': + with Context(TRACK_MATCH_STATS=0): # can't pickle z3 objects + return [s.arg[1] for s in graph_rewrite(uops[0].sink(*uops[1:]), z3_renderer, ctx=(solver, {})).src] + z3_imported = True except (ImportError, AttributeError): z3_imported = False @@ -89,12 +97,15 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}), # Tensor variable bindings - (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), + (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True), # Tensor const has a device and an unmasked ShapeTracker of stride 0 # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum - (UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)), + # TODO: remove after rangeify is default + (UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"), + UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)), lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)), + (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True), # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes @@ -114,7 +125,8 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ # ***** uop type spec ***** def validate_index(idx:UOp, gate:UOp=UOp.const(dtypes.bool, True)): - if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := cast(PtrDType, idx.src[0].dtype).size) == -1: return True + # TODO: check for overflow + if IGNORE_OOB or isinstance(idx.dtype, ImageDType) or (sz := idx.src[0].ptrdtype.size) == -1: return True # We can use UOp min/max to do a faster check, but it can give false positive since its not an exact bound and doesn't consider the mask if 0<=idx.src[1].vmin and idx.src[1].vmax<sz: return True mask = idx.src[2]&gate if len(idx.src) == 3 else gate if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4'\"") solver = z3.Solver(ctx=z3.Context()) - z3_sink = graph_rewrite(idx.src[1].sink(mask), z3_renderer, ctx=(solver, {})) - z3_idx = z3_sink.src[0].arg - solver.add(z3_sink.src[1].arg) - if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat: - print(f"idx={idx.src[1].render(simplify=False)}") - print(f"mask & gate={mask.render(simplify=False)}") - print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}") - return False + z3_idx, z3_mask = uops_to_z3(solver, idx.src[1], mask) + solver.add(z3_mask) + with cpu_profile("validate index with z3", "TINY"): + if solver.check((z3_idx<0)|(sz<=z3_idx)) == z3.sat: + print(f"idx={idx.src[1].render(simplify=False)}") + print(f"mask & gate={mask.render(simplify=False)}") + print(f"# OUT OF BOUNDS ACCESS: at {solver.model()} INDEX not in 0 - {sz}\nconstraints = {solver}") + return False return True def validate_store(idx:UOp, val:UOp, gate:UOp=UOp.const(dtypes.bool, True)):
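
For intuition about what validate_index asks z3: the index expression and its gate become integer/boolean z3 terms, and the solver is asked whether any assignment inside the loop bounds can land outside the buffer. A minimal standalone sketch of the same kind of check, assuming only the z3-solver package (the names gidx0/ridx0 and the sizes are made up for illustration):

import z3

# hypothetical index gidx0*16 + ridx0 into a buffer of size 64,
# with the range bounds added as solver constraints
solver = z3.Solver()
gidx0, ridx0 = z3.Int("gidx0"), z3.Int("ridx0")
solver.add(gidx0 >= 0, gidx0 < 4, ridx0 >= 0, ridx0 < 16)
idx, size = gidx0*16 + ridx0, 64

# sat means some in-bounds assignment of the ranges still indexes out of bounds
if solver.check(z3.Or(idx < 0, idx >= size)) == z3.sat:
    print("OUT OF BOUNDS at", solver.model())
else:
    print("access is always in bounds")  # max index here is 3*16+15 = 63
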
@@ -151,15 +163,16 @@ spec = PatternMatcher([ (UPat(Ops.DEFINE_REG, src=()), lambda: True), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), - (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple)), - (UPat(Ops.SPECIAL, src=()), lambda: True), + (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) == 2 and \ + isinstance(rng.arg[0], int) and isinstance(rng.arg[1], AxisType)), + (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)), (UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)), (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base), (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True), - (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), + (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), # early LOAD has a <buf, shapetracker, store?> (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True), @@ -170,6 +183,11 @@ spec = PatternMatcher([ # **** new style load/store **** + # make sure all index dtypes have been lowered + (UPat(GroupOp.All, dtype=dtypes.index), lambda: False), + (UPat(Ops.CONST, arg=Invalid), lambda: False), + (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)), + # INDEX is used in new style load/store # INDEX takes a <buf, alu, gate?> (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True), diff --git a/tinygrad_repo/tinygrad/uop/symbolic.py b/tinygrad_repo/tinygrad/uop/symbolic.py index f20f6f6e..bd35de5b 100644 --- a/tinygrad_repo/tinygrad/uop/symbolic.py +++ b/tinygrad_repo/tinygrad/uop/symbolic.py @@ -1,10 +1,10 @@ # all of symbolic lives here now -from typing import Any, cast +from typing import cast import math, operator, struct, functools from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu -from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast -from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING +from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid +from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap from tinygrad.uop.decompositions import xpow # ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ******** @@ -19,34 +19,62 @@ def simplify_pow(x:UOp, c:UOp) -> UOp|None: def fold_bitcast(root:UOp, c:UOp) -> UOp|None: if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None if c.dtype.itemsize != root.dtype.itemsize: return None - def convert(v:Any): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0] + def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0] return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg))) -symbolic_simple = PatternMatcher([ +invalid_pat = UPat.const(dtypes.index, Invalid).named("i") +invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat) +
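
The invalid_gate pattern is the canonical gated-index shape cond.where(idx, Invalid), and the propagate_invalid rules below push Invalid outward through arithmetic so the gate survives simplification. A toy model of why that rewrite is sound (plain Python, not the tinygrad API; INVALID here is a stand-in for the Invalid const):

# (cond ? x : INVALID) + y == cond ? (x + y) : INVALID, because INVALID absorbs arithmetic
INVALID = object()

def where(cond, a, b): return a if cond else b

def add(a, b):
    if a is INVALID or b is INVALID: return INVALID  # invalid + y -> invalid
    return a + b

for cond in (True, False):
    for x in range(-3, 4):
        for y in range(-3, 4):
            assert add(where(cond, x, INVALID), y) == where(cond, add(x, y), INVALID)
print("gated-invalid propagation is sound in the toy model")
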
+propagate_invalid = PatternMatcher([ + # this needs to be before symbolic so that 0*something_that_might_be_invalid doesn't become 0 + # propagate invalid, push it past children + *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i)) + for op in GroupOp.Binary-GroupOp.Comparison), + *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison), + # invalid + y -> invalid, same for other ops + *((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison), + # i < y -> a_bool_value_that_will_never_be_used: we choose a random bool const + *((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison), + # a.where(b.where(c, d), d) -> (a & b).where(c, d) + (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), + # order of gate&!cond matters! and-clauses are only simplified left to right, and we need the gate to be used to fold cond + (UPat.var("gate").where(invalid_gate, UPat.var("y")), lambda gate,cond,x,y,i: ((gate&cond.logical_not()).logical_not()).where(gate.where(x,y), i)), + # unswap the branches for the rule above + (UPat.var("gate").where(UPat.var("y"), invalid_gate).named("where"), lambda gate,cond,x,y,i: gate.logical_not().where(cond.where(x,i), y)) +]) + +symbolic_simple = propagate_invalid + PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x - (UPat.var("x", dtype=dtypes.ints) ^ 0, lambda x: x), # x^0 -> x + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) ^ 0, lambda x: x), # x^0 -> x (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1 ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y -> x%y (rewritten with base for speed) + # 4 variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 + ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x), + ((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x), + ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), + lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None), + ((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"), + lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None), (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x",
dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()), - (UPat.var("x", dtype=dtypes.ints+(dtypes.bool,)).trunc(), lambda x: x), + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)).trunc(), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False (UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0 - (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), + (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. @@ -68,7 +96,7 @@ symbolic_simple = PatternMatcher([ (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast), # b.cast(a).cast(b) -> b if a preserves all values in b - (UPat.var('x').cast().named('a').cast().named('b'), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None), + (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None), # ** pow ** (UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow), # positive const ** x @@ -90,39 +118,8 @@ symbolic_simple = PatternMatcher([ # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** -def split_uop(x:UOp, sep:Ops): - if x.op is sep: - for s in x.src: yield from split_uop(s, sep) - else: yield x - -def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None: - # div pattern in unrolled arange - # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - seen_const, ans = [], None - for u in split_uop(divs, Ops.ADD): - if fac!=1: - if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None - u = u.src[0] - if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None - if denominator != u.src[1].arg: return None - if (s0:=u.src[0]).vmin < 0: return None - # assumed CONST is the last of an ADD - if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: - seen_const.append(s0.src[1].arg) - s0 = s0.src[0] - else: seen_const.append(0) - if ans is None: ans = s0 - if ans is not s0: return None - if ans is None: return None - # the first (denominator-len(seen_const)) terms may have been folded to 0 already - for i in range(denominator-len(seen_const)): - if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) - if sorted(seen_const)==list(range(denominator)): - return fac*ans - return None - def lt_folding(x:UOp, c:int) -> UOp|None: - p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1) + p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1) if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d) return None @@ -131,7 +128,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None: # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... 
> 0 if xi >= 0 and ai > 0 for ints. # returns x0 + x1 + ... in such case, or None if not changed changed, ret = False, [] - for u in split_uop(X, Ops.ADD): + for u in X.split_uop(Ops.ADD): # assumed the const is the last src of MUL if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: changed = True @@ -155,7 +152,7 @@ def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None: if ((c := y.arg) < 0) or x.vmin<0: return None new_xs = [] something_changed = False - for u in split_uop(x, Ops.ADD): + for u in x.split_uop(Ops.ADD): if u.op is Ops.MOD: if u.src[1].divides(c) is not None: something_changed = True @@ -167,9 +164,9 @@ def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None: def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None: # we can fold if the expression has only one non-constant term and this term can only take on two values - if ((c := y.arg) < 0) or (x.dtype.count > 1): return None + if ((c := y.arg) < 0): return None x,const = x.pop_const() - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)]) + terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)]) if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1: y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore @@ -178,9 +175,9 @@ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None: def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None: # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c - if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None + if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0): return None x,const = x.pop_const() - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)]) + terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)]) # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c rems = [min((r:=f%c), r-c, key=abs) for f in factors] if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None @@ -189,15 +186,29 @@ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None: def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None: # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd) - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)]) - if (gcd := math.gcd(y.arg, *factors)) == 1: return None - ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd)) + gcd = UOp.gcd(*x.split_uop(Ops.ADD), y).simplify() + if gcd.op is Ops.CONST and gcd.arg==1: return None + ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) return ret*gcd if d.op is Ops.MOD else ret
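
A numeric spot-check of the identity divide_by_gcd relies on, for floordiv with a non-negative numerator (a sketch independent of tinygrad; the term 8*a + 12 and the denominator 20 are arbitrary):

import math

g = math.gcd(8, 12, 20)  # gcd of the term factors and the denominator -> 4
for a in range(50):
    x = 8*a + 12
    assert x//20 == (x//g)//(20//g) == (2*a + 3)//5  # x//y -> (x//gcd)//(y//gcd)
    assert x % 20 == g*((x//g) % (20//g))            # x%y -> gcd*((x//gcd)%(y//gcd))
print("divide_by_gcd identities hold on the sampled range")
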
+def gcd_with_remainder(d: UOp, x: UOp, y: UOp): + # (gcd*x+r)//(gcd*d) -> (x+(r%d)//gcd)//d + r//(gcd*d) + # (gcd*x+r)%(gcd*d) -> gcd*(x+(r%d)//gcd)%d + r%gcd + # These only work for floordiv (and the corresponding remainder)! That's why we check the sign of x,y and new_x + if ((c := y.arg) < 0) or x.vmin<0: return None + x_no_const, const = x.pop_const() + gcd = UOp.gcd(*x_no_const.split_uop(Ops.ADD), y).simplify() + assert gcd.op is Ops.CONST + if gcd.arg==1: return None + new_x = unwrap(x_no_const.divide_exact(gcd)).simplify() + (const%c)//gcd + if new_x.vmin<0: return None + ret = new_x.alu(d.op, x.ufix(c//gcd.arg)) + return ret*gcd + const%gcd.arg if d.op is Ops.MOD else ret+const//c + def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None: # we try and nest the div and see if it allows the numerator to be simplified - if ((c := y.arg) < 0) or (x.dtype.count > 1): return None - factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)] + if ((c := y.arg) < 0): return None + factors = [u.const_factor() for u in x.pop_const()[0].split_uop(Ops.ADD)] # div is the smallest factor of the denominator (greater than 1) out of all "factors" # TODO: there are better ways to pick `div`, this sometimes adds extra divisions # TODO: add same optimization for mod @@ -205,27 +216,22 @@ def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None: if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div) return None -def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None: - # we try and take out the quotient and see if it allows the numerator to be simplified - if ((c := y.arg) < 0) or (x.dtype.count > 1): return None - x_no_const,const = x.pop_const() - terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)]) - quotients, remainders = zip(*[divmod(f, c) for f in factors]) - gcd = math.gcd(c, *remainders) # gcd without const! - if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None - - quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd) - for q,r,f,v in zip(quotients, remainders, factors, terms): - if d.op is Ops.IDIV and r!=0: - rem += f//gcd * v - else: - rem += r//gcd * v - quo += q * v - - # if numerator before/after is negative, and it has remainder, don't simplify because C divmod is different from python divmod. - if (x.vmin < 0 or rem.vmin < 0) and remainders: return None - if d.op is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd - return rem//(c//gcd)+quo +def factor_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None: + # (d*x+y)//d -> x+y//d, or (d*x+y)%d -> y%d + # for mod we go further and take the remainder of all factors to reduce their size + # These only work for floordiv (and the corresponding remainder)!
That's why we check the sign of x,y and new_x + if y.vmin<0 or x.vmin<0: return None + quo, rem = [], [] + for u in x.split_uop(Ops.ADD): + if (q:=u.divide_exact(y)) is not None: quo.append(q) + # if this is mod and y is a const, we can make the remainder factor smaller + elif d.op is Ops.MOD and y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c: + rem.append(u.divides(c)*(c%y.arg)) + quo.append(u.const_like(0)) # we append this so we can check if something changed + else: rem.append(u) + new_x = sum(rem)+x.const_like(0) + if len(quo)==0 or new_x.vmin<0: return None + return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo) def gep_through_wmma(gep:UOp, wmma:UOp): out_sz = prod(x[1] for x in wmma.arg[6][-1]) @@ -266,14 +272,16 @@ gep_pushing = PatternMatcher([ ]) commutative = PatternMatcher([ - # ** COMMUTATIVE flipping (only for ints) ** + # ** COMMUTATIVE flipping (only for index) ** # NOTE: this can break merging vector math by only flipping some of them - (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), + (UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), ]) symbolic = symbolic_simple+commutative+PatternMatcher([ # ** boolean algebra ** (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x + # TODO: make a more general or folder like simplify_valid + (UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True # ** combine terms ** (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)), @@ -287,10 +295,14 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ # a conditional with the same results either way is a noop, also fold const conditionals (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), - (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)), + (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t) + if f.arg is not Invalid else None), # alu of two where with same conds can combine, only do if true branch or false branch is const (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \ lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), + # if it's a plus we add the associative variation too + ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \ + lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
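
The associative where-combine variation added above is just distributivity of + over a shared gate; a short sanity check (plain Python, with where spelled as a conditional expression):

# (y + (t if c else f)) + (tt if c else ff) == y + ((t+tt) if c else (f+ff))
for c in (True, False):
    for (y, t, f, tt, ff) in [(5, 1, 2, 3, 4), (0, -1, 7, 2, -9)]:
        assert (y + (t if c else f)) + (tt if c else ff) == y + ((t + tt) if c else (f + ff))
print("where-combine is sound for both gate values")
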
# ALU/variable min==max -> CONST (slow!) (UPat(GroupOp.ALU|{Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding @@ -305,53 +317,65 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) # ** lt ** # c0*x<c1 for positive int c0,c1 ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False), lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None), # c0*x<c1 for negative int c0 and non-positive c1 ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False), lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None), # x//d<c ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False), lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** ((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), ((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** - # unrolled arange div folding - ((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)), - ((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)), # generic lt folding - (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), + (UPat.var("x", dtypes.index)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None), # canonicalize a simplex with positive coefficients > 0 # not x < 1 -> X > 0 - ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), + ((UPat.var("x", dtypes.index)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # div folding ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d) - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd), - (UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), - (UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), - (UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder), - (UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), - (UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None), - ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), + # a range mod its own upper bound is just the range + (UPat(Ops.RANGE, src=UPat.var("end"), name="r")%UPat.var("end"), lambda r,end: r), + (UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod), + (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), divide_by_gcd), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index,
name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), gcd_with_remainder), + (UPat(Ops.MOD, dtypes.index, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod), + (UPat((Ops.IDIV), dtypes.index, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor), + (UPat((Ops.IDIV, Ops.MOD), dtypes.index, name="d", src=(UPat.var("x"), UPat.var("y"))), factor_remainder), + (UPat.var("x", dtypes.index) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax<=0 else None), + ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), + lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None), + ((UPat.var("x", dtypes.index)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False), lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), # ** mod ** # mod folding - (UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None), - (UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None), + (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None), + (UPat.var("x", dtypes.index) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None), + # cast/long folding + # if the intermediate cast doesnt narrow we can do it in one cast + (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None), + (UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"), + lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None), + # try to do math in int instead of long + (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y: + x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None), + ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), ])+gep_pushing symbolic_flat = symbolic+PatternMatcher([ # ** combine terms (opinionated) ** (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. 
symbolic_flat = symbolic+PatternMatcher([ # ** combine terms (opinionated) ** (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), + ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) # ******** we take a small aside to "simplify_valid" to rewrite valids ******** @@ -362,9 +386,9 @@ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]: # (X < c).ne(True) -> X >= c if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg + (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin) # X < c -> X <= c-1 - if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1 + if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1 raise ValueError(f"not able to parse {valid=}") def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: @@ -372,9 +396,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # first, parse valid into {expr: (lower_bound, upper_bound)} bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None]) - for stmt in split_uop(valid, Ops.AND): + for stmt in valid.split_uop(Ops.AND): try: expr, is_upper, c = parse_valid(stmt) - except ValueError: return uop # give up if we cannot parse the valid + except ValueError: continue # skip any valid we cannot parse bounds[expr][int(is_upper)] = c # don't simplify any other gates, can lead to OOB, we substitute them back later @@ -383,6 +407,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # simplify uop given that valid is True for expr,v in bounds.items(): v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) + expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop # some expr has lower bound > upper bound -> valid is an empty set and we return None if v0 > v1: return None # whole node became a const @@ -391,9 +416,9 @@ continue # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): + if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ...
> 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]) # try checking the whole clause if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))]) @@ -417,7 +442,7 @@ def _valid_priority(v: UOp, valids:list[UOp]): def simplify_valid(valid:UOp) -> UOp|None: ret:list[UOp] = [] something_changed = False - valids = list(split_uop(valid, Ops.AND)) + valids = list(valid.split_uop(Ops.AND)) for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): # TODO: root cause this and test_simplify_valid_from_div if stmt.op is Ops.CAST: return None @@ -431,7 +456,7 @@ def reduce_mul_chain(r:UOp): if r.arg not in {Ops.ADD, Ops.MAX}: return None if r.dtype != r.src[0].dtype: return None inside, outside = [], [] - for m in split_uop(r.src[0], Ops.MUL): + for m in r.src[0].split_uop(Ops.MUL): m_parents = m.toposort() if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m) else: inside.append(m) @@ -442,6 +467,10 @@ def reduce_mul_chain(r:UOp): REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP} REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP} sym = symbolic_flat+PatternMatcher([ + # simplify valid + (UPat(Ops.AND, name="valid"), simplify_valid), + (UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if + (newx:=uop_given_valid(cond, x)) is not x else None), # LOAD/STORE -> NOOP (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c), @@ -465,21 +494,20 @@ sym = symbolic_flat+PatternMatcher([ # ** where ** # push cast to branches (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))), - # a.where(b.where(c, d), d) -> (a & b).where(c, d) - (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), # ** pow ** ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))), - # index true is index without op - (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), # ** load/store folding ** (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"), - lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])), + lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])), # fold gated LOAD/STORE - (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True - (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"), - lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # NULL pointer store does nothing. 
NULL pointer load produces 0 + (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"), + lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0 + (UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c")).or_casted(),), allow_any_len=True, name="l"), UPat.var("a")), + lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])), + (UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),), + allow_any_len=True, name="l")), lambda c,idx,l,a: l.replace(src=(l.src[0], a)+l.src[1:])), # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels (UPat(Ops.BARRIER, name="root"), lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg) diff --git a/tinygrad_repo/tinygrad/viz/README b/tinygrad_repo/tinygrad/viz/README index ce46d461..bdd038e4 100644 --- a/tinygrad_repo/tinygrad/viz/README +++ b/tinygrad_repo/tinygrad/viz/README @@ -6,19 +6,18 @@ most uses of DEBUG >= 3 tiny-tools and a viewer for: -SAVE_SCHEDULE=1 TRACK_MATCH_STATS=2 -PROFILE=1 +ProfileEvents to use: -1. Run tinygrad with VIZ=1 and/or PROFILE=1 (this saves the pkls and launches the server (new process please!)) +1. Run tinygrad with VIZ=1 (this saves the pkls and launches the server (new process please!)) 2. That's it! This should be able to: 1. See all schedules (VIZ=1) 2. See all graphs and how they were rewritten (VIZ=1) 3. See generated code (VIZ=1) -4. See profile (PROFILE=1) +4. See profile (click on 'profiler') bunch of dev rules: * everything must be responsive to keyboard smashing! 
lag should never happen diff --git a/tinygrad_repo/tinygrad/viz/index.html b/tinygrad_repo/tinygrad/viz/index.html index dd1368d9..0310c642 100644 --- a/tinygrad_repo/tinygrad/viz/index.html +++ b/tinygrad_repo/tinygrad/viz/index.html @@ -75,9 +75,14 @@ g.tag circle { fill: #FFD700; stroke: #B8860B; + } + g.port circle { + fill: #b3dcc2; + } + g.tag circle, #edge-labels circle { stroke-width: 0.8; } - g.tag text { + g.tag text, #edge-labels text { text-anchor: middle; font-size: 6px; fill: #08090e; @@ -85,11 +90,33 @@ .label :is(text, p) { font-weight: 350; } + g.node rect { + stroke-width: 1.4; + stroke: #4a4b57; + } + g.overlay rect { + fill: rgba(26, 27, 38, 0.5); + } .edgePath { stroke: #4a4b57; fill: none; stroke-width: 1.4px; } + g.node.highlight rect, .edgePath.highlight, g.port circle { + stroke: #89C9A2; + } + g.highlight.child rect, .edgePath.highlight.child { + stroke: #C888B0; + } + #edge-labels g.port.highlight { + display: block + } + #edge-labels g.port { + display: none + } + #arrowhead { + fill: #4a4b57; + } .main-container { display: flex; width: 100%; @@ -107,17 +134,6 @@ .metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * { margin-top: 12px; } - .stats-list > * + * { - margin-top: 8px; - } - .stats-list > p > * + * { - margin-top: 12px; - } - .stats-list { - width: 100%; - max-height: 240px; - overflow: auto; - } .ctx-list > ul > * + * { margin-top: 4px; } @@ -126,7 +142,7 @@ inset: 0; z-index: 1; } - .profiler { + .profiler, .disasm { flex: 1 1 auto; min-width: 0; width: 100%; @@ -220,15 +236,15 @@ z-index: 4; background-color: #1e2029; padding: 4px 8px; + max-width: 164px; border-radius: 4px; pointer-events: none; display: none; font-size: 10px; - white-space: pre; } #device-list > div { min-height: 32px; - max-width: 132px; + width: 132px; overflow-x: auto; overflow-y: hidden; white-space: nowrap; @@ -237,6 +253,9 @@ #device-list > div:hover { background-color: rgba(20, 23, 35, 0.3); } + #device-list { + height: fit-content; + } .raw-text { padding: 0 8px; width: 100%; @@ -322,6 +341,7 @@
+
@@ -331,14 +351,14 @@ - +
-
+
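
Before the index.js changes below: the profiler view now fetches /get_profile as one packed little-endian buffer and decodes it with DataView readers (u32/u64/f32). A Python sketch of the same header layout, inferred from those readers (dur, tracePeak, indexLen, layoutsLen, then indexLen bytes of JSON; everything past the header is per-device data and is omitted here):

import struct

def read_profile_header(buf: bytes):
    offset = 0
    def u32():
        nonlocal offset
        (v,) = struct.unpack_from("<I", buf, offset); offset += 4
        return v
    def u64():
        nonlocal offset
        (v,) = struct.unpack_from("<Q", buf, offset); offset += 8
        return v
    dur, trace_peak, index_len, layouts_len = u32(), u64(), u32(), u32()
    # the index blob is JSON holding {strings, dtypeSize, markers}
    index = buf[offset:offset+index_len].decode("utf-8")
    return dur, trace_peak, layouts_len, index
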
diff --git a/tinygrad_repo/tinygrad/viz/js/index.js b/tinygrad_repo/tinygrad/viz/js/index.js index a6f9eec2..ab00826c 100644 --- a/tinygrad_repo/tinygrad/viz/js/index.js +++ b/tinygrad_repo/tinygrad/viz/js/index.js @@ -4,6 +4,15 @@ const displayGraph = (cls) => { for (const e of document.getElementsByClassName("view")) e.style.display = e.classList.contains(cls) ? "flex" : "none"; } +const darkenHex = (h, p = 0) => + `#${( + c = parseInt(h.slice(1), 16), + f = 1 - p / 100, + ((c >> 16 & 255) * f | 0) << 16 | + ((c >> 8 & 255) * f | 0) << 8 | + ((c & 255) * f | 0) + ).toString(16).padStart(6, '0')}`; + const ANSI_COLORS = ["#b3b3b3", "#ff6666", "#66b366", "#ffff66", "#6666ff", "#ff66ff", "#66ffff", "#ffffff"]; const parseColors = (name, defaultColor="#ffffff") => Array.from(name.matchAll(/(?:\u001b\[(\d+)m([\s\S]*?)\u001b\[0m)|([^\u001b]+)/g), ([_, code, colored_st, st]) => ({ st: colored_st ?? st, color: code != null ? ANSI_COLORS[(parseInt(code)-30+60)%60] : defaultColor })); @@ -11,13 +20,14 @@ const parseColors = (name, defaultColor="#ffffff") => Array.from(name.matchAll(/ const rect = (s) => (typeof s === "string" ? document.querySelector(s) : s).getBoundingClientRect(); let timeout = null; -const updateProgress = ({ show=true }) => { +const updateProgress = ({ start }) => { clearTimeout(timeout); const msg = document.getElementById("progress-message"); - if (show) { + msg.style.display = "none"; + if (start) { msg.innerText = "Rendering new graph..."; timeout = setTimeout(() => { msg.style.display = "block"; }, 2000); - } else msg.style.display = "none"; + } } // ** UOp graph @@ -38,29 +48,39 @@ function addTags(root) { } let [workerUrl, worker] = [null, null]; -async function renderDag(graph, additions, recenter=false) { +async function initWorker() { + const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/js/worker.js"].map(u => fetch(u))); + workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" })); +} + +function renderDag(graph, additions, recenter) { // start calculating the new layout (non-blocking) - updateProgress({ show:true }); - if (worker == null) { - const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/js/worker.js"].map(u => fetch(u))); - workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" })); - worker = new Worker(workerUrl); - } else { - worker.terminate(); - worker = new Worker(workerUrl); - } - worker.postMessage({graph, additions, ctxs}); + updateProgress({ start:true }); + if (worker != null) worker.terminate(); + worker = new Worker(workerUrl); + worker.postMessage({graph, additions}); worker.onmessage = (e) => { displayGraph("graph"); - updateProgress({ show:false }); + updateProgress({ start:false }); const g = dagre.graphlib.json.read(e.data); // draw nodes const STROKE_WIDTH = 1.4; - const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g") - .attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null) - .on("click", (_,d) => setCtxWithHistory(d.ref)); + d3.select("#graph-svg").on("click", () => d3.selectAll(".highlight").classed("highlight", false)); + const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g").attr("class", d => d.className ?? 
"node") + .attr("transform", d => `translate(${d.x},${d.y})`).classed("clickable", d => d.ref != null).on("click", (e,d) => { + if (d.ref != null) return setCtxWithHistory(d.ref); + const parents = g.predecessors(d.id); + const children = g.successors(d.id); + if (parents == null && children == null) return; + const src = [...parents, ...children, d.id]; + nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id)); + const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; + d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath"); + d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port"); + e.stopPropagation(); + }); nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color) - .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style ?? `stroke:#4a4b57; stroke-width:${STROKE_WIDTH}px;`); + .attr("x", d => -d.width/2).attr("y", d => -d.height/2); nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => { const x = (d.width-d.padding*2)/2; const y = (d.height-d.padding*2)/2+STROKE_WIDTH; @@ -75,19 +95,19 @@ async function renderDag(graph, additions, recenter=false) { } return [ret]; }).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan") - .attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve"); + .attr("fill", d => darkenHex(d.color, 25)).text(d => d.st).attr("xml:space", "preserve"); addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? 
[d] : []).join("g").attr("class", "tag") .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag)); // draw edges - const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis); - d3.select("#edges").selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => { + const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis), edges = g.edges(); + d3.select("#edges").selectAll("path.edgePath").data(edges).join("path").attr("class", "edgePath").attr("d", (e) => { const edge = g.edge(e); const points = edge.points.slice(1, edge.points.length-1); points.unshift(intersectRect(g.node(e.v), points[0])); points.push(intersectRect(g.node(e.w), points[points.length-1])); return line(points); }).attr("marker-end", "url(#arrowhead)"); - addTags(d3.select("#edge-labels").selectAll("g").data(g.edges().filter(e => g.edge(e).label != null)).join("g").attr("transform", (e) => { + addTags(d3.select("#edge-labels").selectAll("g").data(edges).join("g").attr("transform", (e) => { // get a point near the end const [p1, p2] = g.edge(e).points.slice(-2); const dx = p2.x-p1.x; @@ -101,10 +121,9 @@ async function renderDag(graph, additions, recenter=false) { const x = p2.x - ux * offset; const y = p2.y - uy * offset; return `translate(${x}, ${y})` - }).attr("class", "tag").datum(e => g.edge(e).label)); + }).attr("class", e => g.edge(e).label.type).attr("id", e => `${e.v}-${e.w}`).datum(e => g.edge(e).label.text)); if (recenter) document.getElementById("zoom-to-fit-btn").click(); }; - } // ** profiler graph @@ -136,11 +155,11 @@ const rescaleTrack = (source, tid, k) => { return change; } -const drawLine = (ctx, x, y) => { +const drawLine = (ctx, x, y, opts) => { ctx.beginPath(); ctx.moveTo(x[0], y[0]); ctx.lineTo(x[1], y[1]); - ctx.fillStyle = ctx.strokeStyle = "#f0f0f5"; + ctx.fillStyle = ctx.strokeStyle = opts?.color || "#f0f0f5"; ctx.stroke(); } @@ -149,7 +168,7 @@ async function renderProfiler() { displayGraph("profiler"); d3.select(".metadata").html(""); // layout once! - if (data != null) return; + if (data != null) return updateProgress({ start:false }); const profiler = d3.select(".profiler").html(""); const buf = await (await fetch("/get_profile")).arrayBuffer(); const view = new DataView(buf); @@ -159,9 +178,9 @@ async function renderProfiler() { const u64 = () => { const ret = new Number(view.getBigUint64(offset, true)); offset += 8; return ret; } const f32 = () => { const ret = view.getFloat32(offset, true); offset += 4; return ret; } const optional = (i) => i === 0 ? 
null : i-1; - const dur = u32(), peak = u64(), indexLen = u32(), layoutsLen = u32(); + const dur = u32(), tracePeak = u64(), indexLen = u32(), layoutsLen = u32(); const textDecoder = new TextDecoder("utf-8"); - const { strings, dtypeSize } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen; + const { strings, dtypeSize, markers } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen; // place devices on the y axis and set vertical positions const [tickSize, padding] = [10, 8]; const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+"px"); @@ -170,26 +189,26 @@ async function renderProfiler() { canvas.addEventListener("wheel", e => (e.stopPropagation(), e.preventDefault()), { passive:false }); const ctx = canvas.getContext("2d"); const canvasTop = rect(canvas).top; - // color by key (name/category/device) + // color by key (name/device) const colorMap = new Map(); data = {tracks:new Map(), axes:{}}; - const heightScale = d3.scaleLinear().domain([0, peak]).range([4,maxheight=100]); + const heightScale = d3.scaleLinear().domain([0, tracePeak]).range([4,maxheight=100]); for (let i=0; i e.st >= levelEt); const et = e.st+Math.trunc(e.dur); @@ -197,9 +216,10 @@ async function renderProfiler() { depth = levels.length; levels.push(et); } else levels[depth] = et; - if (depth === 0) colorKey = e.cat ?? e.name; - if (!colorMap.has(colorKey)) colorMap.set(colorKey, cycleColors(colorScheme[k] ?? colorScheme.DEFAULT, colorMap.size)); - const fillColor = d3.color(colorMap.get(colorKey)).brighter(depth).toString(); + if (depth === 0) colorKey = e.name.split(" ")[0]; + if (!colorMap.has(colorKey)) colorMap.set(colorKey, d3.rgb(cycleColors(colorScheme[k.split(":")[0]] ?? colorScheme.DEFAULT, colorMap.size))); + const base = colorMap.get(colorKey), s = Math.min(Math.pow(1/0.7, depth), 240 / Math.max(base.r, base.g, base.b)); + const fillColor = d3.rgb(base.r*s, base.g*s, base.b*s).toString(); const label = parseColors(e.name).map(({ color, st }) => ({ color, st, width:ctx.measureText(st).width })); if (e.ref != null) ref = {ctx:e.ref, step:0}; else if (ref != null) { @@ -207,15 +227,14 @@ async function renderProfiler() { const stepIdx = ctxs[ref.ctx+1].steps.findIndex((s, i) => i >= start && s.name == e.name); ref = stepIdx === -1 ? null : {ctx:ref.ctx, step:stepIdx}; } - const arg = { tooltipText:formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), ...ref }; + const htmlLabel = label.map(({color, st}) => `${st}`).join(''); + const arg = { tooltipText:htmlLabel+"\n"+formatTime(e.dur)+(e.info != null ? 
"\n"+e.info : ""), ...ref }; // offset y by depth shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor }); } div.style("height", levelHeight*levels.length+padding+"px").style("pointerEvents", "none"); } else { const peak = u64(); - const height = heightScale(peak); - const yscale = d3.scaleLinear().domain([0, peak]).range([height, 0]); let x = 0, y = 0; const buf_shapes = new Map(), temp = new Map(); const timestamps = []; @@ -246,12 +265,15 @@ async function renderProfiler() { v.y.push(v.y.at(-1)); } timestamps.push(dur); - for (const [_, {dtype, sz, nbytes, y, x:steps}] of buf_shapes) { + const height = heightScale(peak); + const yscale = d3.scaleLinear().domain([0, peak]).range([height, 0]); + for (const [num, {dtype, sz, nbytes, y, x:steps}] of buf_shapes) { const x = steps.map(s => timestamps[s]); - const arg = {tooltipText:`${dtype} len:${formatUnit(sz)}\n${formatUnit(nbytes, "B")}`}; + const dur = x.at(-1)-x[0]; + const arg = {tooltipText:`${dtype} len:${formatUnit(sz)}\n${formatUnit(nbytes, "B")}\nnum:${num}\nalive for ${formatTime(dur)}`}; shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) }); } - data.tracks.set(k, { shapes, offsetY, height, peak, scaleFactor:maxheight*4/height }); + data.tracks.set(k, { shapes, visible, offsetY, height, peak, scaleFactor:maxheight*4/height }); div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => { const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id; let offset = 0; @@ -266,60 +288,50 @@ async function renderProfiler() { }); } } - updateProgress({ "show":false }); + updateProgress({ start:false }); // draw events on a timeline const dpr = window.devicePixelRatio || 1; const ellipsisWidth = ctx.measureText("...").width; - const rectLst = []; function render(transform) { zoomLevel = transform; - rectLst.length = 0; - ctx.save(); ctx.clearRect(0, 0, canvas.clientWidth, canvas.clientHeight); // rescale to match current zoom const xscale = d3.scaleLinear().domain([0, dur]).range([0, canvas.clientWidth]); - xscale.domain(xscale.range().map(zoomLevel.invertX, zoomLevel).map(xscale.invert, xscale)); - const zoomDomain = transform != null ? xscale.domain() : null; - let yscale = null; - if (data.axes.y != null) { - yscale = d3.scaleLinear().domain(data.axes.y.domain).range(data.axes.y.range); - } + const visibleX = xscale.range().map(zoomLevel.invertX, zoomLevel).map(xscale.invert, xscale); + const st = visibleX[0], et = visibleX[1]; + xscale.domain(visibleX); // draw shapes - for (const [_, { offsetY, shapes }] of data.tracks) { + for (const [_, { offsetY, shapes, visible }] of data.tracks) { + visible.length = 0; for (const e of shapes) { - const [start, end] = e.width != null ? 
[e.x, e.x+e.width] : [e.x[0], e.x[e.x.length-1]]; - if (zoomDomain != null && (start>zoomDomain[1]|| endet || e.x.at(-1)=0; i--) ctx.lineTo(x[i], offsetY+e.y1[i]); ctx.closePath(); - ctx.fill(); - // NOTE: y coordinates are in reverse order - for (let i = 0; i < x.length - 1; i++) { - let tooltipText = e.arg.tooltipText; - if (yscale != null && ((yaxisVal=yscale.invert(offsetY+e.y1[i]))>0)) { - tooltipText += `\nTotal: ${formatUnit(yaxisVal, data.axes.y.fmt)}`; - } - rectLst.push({ x0:x[i], x1:x[i+1], y0:offsetY+e.y1[i], y1:offsetY+e.y0[i], arg:{...e.arg, tooltipText} }); - } + ctx.fillStyle = e.fillColor; ctx.fill(); continue; } // contiguous rect - const x = xscale(start); - const width = xscale(end)-x; - ctx.fillRect(x, offsetY+e.y, width, e.height); - rectLst.push({ y0:offsetY+e.y, y1:offsetY+e.y+e.height, x0:x, x1:x+width, arg:e.arg }); + if (e.x>et || e.x+e.width width) { if (labelWidth !== 0) ctx.fillText("...", labelX, labelY); @@ -340,27 +352,31 @@ async function renderProfiler() { drawLine(ctx, [x, x], [0, tickSize]) // tick label ctx.textBaseline = "top"; - ctx.textAlign = "left"; ctx.fillText(formatTime(tick, dur), x+ctx.lineWidth+2, tickSize); } - if (yscale != null) { - drawLine(ctx, [0, 0], yscale.range()); + if (data.axes.y != null) { + drawLine(ctx, [0, 0], data.axes.y.range); + const yscale = d3.scaleLinear().domain(data.axes.y.domain).range(data.axes.y.range); for (const tick of yscale.ticks()) { const y = yscale(tick); drawLine(ctx, [0, tickSize], [y, y]); - ctx.textAlign = "left"; ctx.textBaseline = "middle"; ctx.fillText(formatUnit(tick, data.axes.y.fmt), tickSize+2, y); } } - ctx.restore(); + // draw markers + for (const m of markers) { + const x = xscale(m.ts); + drawLine(ctx, [x, x], [0, canvas.clientHeight], { color:m.color }); + ctx.fillText(m.name, x+2, 1); + } } function resize() { const profiler = document.querySelector(".profiler"); - // NOTE: use clientWidth to account for the scrollbar - let [width, height] = [profiler.clientWidth, profiler.scrollHeight]; - width -= rect("#device-list").width+padding; + const sideRect = rect("#device-list"); + const width = profiler.clientWidth-(sideRect.width+padding), height = Math.round(sideRect.height); + if (canvas.width === width*dpr && canvas.height === height*dpr) return; canvas.width = width*dpr; canvas.height = height*dpr; canvas.style.height = `${height}px`; @@ -369,19 +385,23 @@ async function renderProfiler() { d3.select(canvas).call(canvasZoom.transform, zoomLevel); } - canvasZoom = d3.zoom().filter(e => (!e.ctrlKey || e.type === 'wheel' || e.type === 'mousedown') && !e.button) - .scaleExtent([1, Infinity]).translateExtent([[0,0], [Infinity,0]]).on("zoom", e => render(e.transform)); + canvasZoom = d3.zoom().filter(vizZoomFilter).scaleExtent([1, Infinity]).translateExtent([[0,0], [Infinity,0]]).on("zoom", e => render(e.transform)); d3.select(canvas).call(canvasZoom); document.addEventListener("contextmenu", e => e.ctrlKey && e.preventDefault()); - resize(); - window.addEventListener("resize", resize); + new ResizeObserver(([e]) => e.contentRect.width > 0 && resize()).observe(profiler.node()); function findRectAtPosition(x, y) { + let tid = null; + for (const k of data.tracks.keys()) { + const r = rect(document.getElementById(k)); + if (y >= r.y && y <= r.y+r.height) { tid = k; break; } + } + if (tid == null) return; const { top, left, width, height } = rect(canvas); const X = ((x-left) * (canvas.width/width))/dpr; const Y = ((y-top) * (canvas.height/height))/dpr; - for (const r of rectLst) { + for (const r of 
+    for (const r of data.tracks.get(tid).visible) {
       if (Y>=r.y0 && Y<=r.y1 && X>=r.x0 && X<=r.x1) return r.arg;
     }
   }
@@ -399,7 +419,7 @@ async function renderProfiler() {
       tooltip.style.display = "block";
       tooltip.style.left = (e.pageX+10)+"px";
       tooltip.style.top = (e.pageY)+"px";
-      tooltip.innerText = foundRect.tooltipText;
+      tooltip.innerHTML = foundRect.tooltipText;
     } else tooltip.style.display = "none";
   });
   canvas.addEventListener("mouseleave", () => document.getElementById("tooltip").style.display = "none");
@@ -407,7 +427,8 @@ async function renderProfiler() {

 // ** zoom and recentering

-const svgZoom = d3.zoom().on("zoom", (e) => d3.select("#render").attr("transform", e.transform));
+const vizZoomFilter = e => (!e.ctrlKey || e.type === 'wheel' || e.type === 'mousedown') && !e.button && e.type !== 'dblclick';
+const svgZoom = d3.zoom().filter(vizZoomFilter).on("zoom", (e) => d3.select("#render").attr("transform", e.transform));
 d3.select("#graph-svg").call(svgZoom);

 // zoom to fit into view
@@ -511,7 +532,6 @@ function setState(ns) {

 // set a new context and keep the old one in browser history
 function setCtxWithHistory(newCtx, step=0) {
-  if (newCtx == null) return;
   // NOTE: browser does a structured clone, passing a mutable object is safe.
   history.replaceState(state, "");
   history.pushState(state, "");
@@ -563,13 +583,14 @@ async function main() {
     else if (e.readyState === EventSource.OPEN) activeSrc = e;
   }
   if (ctx.name === "Profiler") return renderProfiler();
+  if (workerUrl == null) await initWorker();
   if (ckey in cache) { ret = cache[ckey]; }
   // ** Disassembly view
   if (ckey.startsWith("/disasm")) {
     if (!(ckey in cache)) cache[ckey] = ret = await (await fetch(ckey)).json();
-    displayGraph("profiler");
+    displayGraph("disasm");
     const root = document.createElement("div");
     root.className = "raw-text";
     const metadata = document.querySelector(".metadata");
@@ -611,7 +632,7 @@ async function main() {
         appendTd(tr, s.value);
       }
     } else root.appendChild(codeBlock(ret.src, "x86asm"));
-    return document.querySelector(".profiler").replaceChildren(root);
+    return document.querySelector(".disasm").replaceChildren(root);
   }
   // ** UOp view (default)
   // if we don't have a complete cache yet we start streaming rewrites in this step
@@ -632,33 +653,11 @@ async function main() {
     };
   }
   if (ret.length === 0) return;
-  renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
+  renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes ?? [], currentRewrite === 0);
   // ** right sidebar code blocks
   const metadata = document.querySelector(".metadata");
[ctx.fmt, "cpp"] : [ret[currentRewrite].uop, "python"]; metadata.replaceChildren(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }), codeBlock(code, lang, { wrap:false })); - if (ctx.runtime_stats != null) { - const div = metadata.appendChild(document.createElement("div")); - div.className = "stats-list"; - for (const [i, s] of ctx.runtime_stats.entries()) { - const p = div.appendChild(document.createElement("p")); - if (ctx.runtime_stats.length > 1) p.innerText = `Run ${i+1}/${ctx.runtime_stats.length}`; - const table = div.appendChild(document.createElement("table")); - const tbody = table.appendChild(document.createElement("tbody")); - for (const { name, value, unit, subunits } of s.data) { - const mainRow = appendRow(tbody, name, value, unit, "main-row"); - if (!subunits?.length) continue; - const subunitRow = tbody.appendChild(document.createElement("tr")); - subunitRow.style.display = "none"; - mainRow.onclick = () => subunitRow.style.display = subunitRow.style.display === "none" ? "table-row" : "none"; - mainRow.style.cursor = "pointer"; - const td = subunitRow.appendChild(document.createElement("td")); - td.colSpan = 2; - const table = td.appendChild(document.createElement("table")); - for (const u of subunits) appendRow(table, u.name, u.value, unit, "sub-row"); - } - } - } // ** rewrite steps if (step.match_count >= 1) { const rewriteList = metadata.appendChild(document.createElement("div")); @@ -723,7 +722,7 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid // **** keyboard shortcuts -document.addEventListener("keydown", async function(event) { +document.addEventListener("keydown", (event) => { const { currentCtx, currentStep, currentRewrite, expandSteps } = state; // up and down change the step or context from the list const changeStep = expandSteps && ctxs[currentCtx].steps?.length; diff --git a/tinygrad_repo/tinygrad/viz/js/worker.js b/tinygrad_repo/tinygrad/viz/js/worker.js index 3f78445f..9fbcaa06 100644 --- a/tinygrad_repo/tinygrad/viz/js/worker.js +++ b/tinygrad_repo/tinygrad/viz/js/worker.js @@ -5,10 +5,10 @@ const ctx = canvas.getContext("2d"); ctx.font = `${LINE_HEIGHT}px sans-serif`; onmessage = (e) => { - const { graph, additions, ctxs } = e.data; + const { graph, additions } = e.data; const g = new dagre.graphlib.Graph({ compound: true }); g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); - if (additions.length !== 0) g.setNode("addition", {label:"", style:"fill: rgba(26, 27, 38, 0.5);", padding:0}); + if (additions.length !== 0) g.setNode("addition", {label:"", className:"overlay", padding:0}); for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) { // adjust node dims by label size (excluding escape codes) + add padding let [width, height] = [0, 0]; @@ -16,11 +16,11 @@ onmessage = (e) => { width = Math.max(width, ctx.measureText(line).width); height += LINE_HEIGHT; } - g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, padding:NODE_PADDING, label, ref, ...rest}); + g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, padding:NODE_PADDING, label, ref, id:k, ...rest}); // add edges const edgeCounts = {} - for (const s of src) edgeCounts[s] = (edgeCounts[s] || 0)+1; - for (const s of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? edgeCounts[s] : null }); + for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1; + for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? 
{type:"tag", text:edgeCounts[s]} : {type:"port", text:port}}); if (additions.includes(parseInt(k))) g.setParent(k, "addition"); } dagre.layout(g); diff --git a/tinygrad_repo/tinygrad/viz/serve.py b/tinygrad_repo/tinygrad/viz/serve.py index 48e6efb1..e58839bb 100755 --- a/tinygrad_repo/tinygrad/viz/serve.py +++ b/tinygrad_repo/tinygrad/viz/serve.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct -import subprocess, ctypes +import subprocess, ctypes, pathlib from contextlib import redirect_stdout from decimal import Decimal from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, TypedDict, Generator -from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent +from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes +from tinygrad.codegen.opt import axis_colors uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", @@ -18,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", - Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", + Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.REALIZE: "#C1C14D", Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e"} # VIZ API @@ -26,19 +27,20 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", # ** Metadata for a track_rewrites scope ref_map:dict[Any, int] = {} -def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]: +traces:dict[int, tuple] = {} +def get_metadata(trace_bufs:list[tuple]) -> list[dict]: ret = [] - for i,(k,v) in enumerate(zip(keys, contexts)): - steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc), - "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)] - ret.append(r:={"name":k.display_name, "steps":steps}) - # use the first key to get runtime profiling data about this context - if getenv("PROFILE_VALUE") >= 2 and k.keys: r["runtime_stats"] = get_runtime_stats(k.keys[0]) - # program spec metadata - if isinstance(k.ret, ProgramSpec): - steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"}) - r["fmt"] = k.ret.src - for key in k.keys: ref_map[key] = i + for keys,contexts,uop_fields in trace_bufs: + for k,v in zip(keys, contexts): + 
diff --git a/tinygrad_repo/tinygrad/viz/serve.py b/tinygrad_repo/tinygrad/viz/serve.py
index 48e6efb1..e58839bb 100755
--- a/tinygrad_repo/tinygrad/viz/serve.py
+++ b/tinygrad_repo/tinygrad/viz/serve.py
@@ -1,16 +1,17 @@
 #!/usr/bin/env python3
 import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
-import subprocess, ctypes
+import subprocess, ctypes, pathlib
 from contextlib import redirect_stdout
 from decimal import Decimal
 from http.server import BaseHTTPRequestHandler
 from urllib.parse import parse_qs, urlparse
 from typing import Any, TypedDict, Generator
-from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent
+from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
 from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, srender, sint, sym_infer
 from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
 from tinygrad.renderer import ProgramSpec
 from tinygrad.dtype import dtypes
+from tinygrad.codegen.opt import axis_colors

 uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
   Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
@@ -18,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
   Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
   **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
   Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
-  Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
+  Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.REALIZE: "#C1C14D",
   Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e"}

 # VIZ API
@@ -26,19 +27,20 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",

 # ** Metadata for a track_rewrites scope

 ref_map:dict[Any, int] = {}
-def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]:
+traces:dict[int, tuple] = {}
+def get_metadata(trace_bufs:list[tuple]) -> list[dict]:
   ret = []
-  for i,(k,v) in enumerate(zip(keys, contexts)):
-    steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
-      "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
-    ret.append(r:={"name":k.display_name, "steps":steps})
-    # use the first key to get runtime profiling data about this context
-    if getenv("PROFILE_VALUE") >= 2 and k.keys: r["runtime_stats"] = get_runtime_stats(k.keys[0])
-    # program spec metadata
-    if isinstance(k.ret, ProgramSpec):
-      steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"})
-      r["fmt"] = k.ret.src
-    for key in k.keys: ref_map[key] = i
+  for keys,contexts,uop_fields in trace_bufs:
+    for k,v in zip(keys, contexts):
+      traces[i:=len(traces)] = (k, v, uop_fields)
+      steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc),
+        "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)]
+      ret.append(r:={"name":k.display_name, "steps":steps})
+      # program spec metadata
+      if isinstance(k.ret, ProgramSpec):
+        steps.append({"name":"View Disassembly", "query":f"/disasm?ctx={i}"})
+        r["fmt"] = k.ret.src
+      for key in k.keys: ref_map[key] = i
   return ret

 # ** Complete rewrite details for a graph_rewrite call
@@ -69,37 +71,38 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if u.op is Ops.VIEW:
       argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
         (f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
+    if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.arg)
     label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
     for idx,x in enumerate(u.src):
       if x in excluded:
-        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype) else f"{x.arg}"
+        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
         label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
     try:
       if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: label += f"\n{shape_to_str(u.shape)}"
       elif len(rngs:=u.ranges):
-        label += f"\n{str(sorted([x.arg[0] for x in rngs]))}"
+        label += f"\n({','.join([colored(str(x.arg[0]), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
     except Exception: label += "\n<ISSUE GETTING SHAPE>"
     if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
     # NOTE: kernel already has metadata in arg
     if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+repr(u.metadata)
-    graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
-      "ref":ref, "tag":u.tag}
+    graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
+      "ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
   return graph

 @functools.cache
-def _reconstruct(a:int):
-  op, dtype, src, arg, *rest = contexts[2][a]
-  arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
-  return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
+def _reconstruct(a:int, i:int):
+  op, dtype, src, arg, *rest = traces[i][2][a]
+  arg = type(arg)(_reconstruct(arg.ast, i), arg.metadata) if op is Ops.KERNEL else arg
+  return UOp(op, dtype, tuple(_reconstruct(s, i) for s in src), arg, *rest)

-def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
-  yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
+def get_details(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
+  yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink, i)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None}
   replaces: dict[UOp, UOp] = {}
   for u0_num,u1_num,upat_loc in tqdm(ctx.matches):
-    replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
+    replaces[u0:=_reconstruct(u0_num, i)] = u1 = _reconstruct(u1_num, i)
     try: new_sink = next_sink.substitute(replaces)
     except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
     yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
@@ -135,41 +138,39 @@ def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decim
 def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
   events:list[bytes] = []
   exec_points:dict[str, dict] = {}
-  category_enum:dict[str, int] = {}
   for st,et,dur,e in dev_events:
     if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.key] = e.arg
     if dur == 0: continue
-    name, cat, info = e.name, None, None
+    name, info = e.name, None
     if (ref:=ref_map.get(name)) is not None:
       name = ctxs[ref]["name"]
-      if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
+      if isinstance(p:=traces[ref][0].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
         info = f"{sym_infer(p.estimates.ops, ei['var_vals'])/(t:=dur*1e3):.2f} GFLOPS {sym_infer(p.estimates.mem, ei['var_vals'])/t:4.1f}"+ \
           f"|{sym_infer(p.estimates.lds,ei['var_vals'])/t:.1f} GB/s\n{ei['metadata']}"
     elif isinstance(e.name, TracingKey):
-      name, cat = e.name.display_name, e.name.cat
+      name = e.name.display_name
       ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
-    events.append(struct.pack("<IfIII", st-start_ts, dur, scache.setdefault(str(name), len(scache)),
-      category_enum.setdefault(cat or "", len(category_enum)), scache.setdefault(info or "", len(scache))))
+    events.append(struct.pack("<IfII", st-start_ts, dur, scache.setdefault(str(name), len(scache)), scache.setdefault(info or "", len(scache))))
   return b"".join(events) if len(events) else None

 def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:dict[str, int], dtype_size:dict[str, int],
                scache:dict[str, int]) -> bytes|None:
   peak, mem = 0, 0
   temp:dict[int, int] = {}
-  bufs:list[bytes] = []
-  for st,_,_,e in events:
+  events:list[bytes] = []
+  for st,_,_,e in dev_events:
     if not isinstance(e, ProfilePointEvent): continue
     if e.name == "alloc":
-      bufs.append(struct.pack("<IIQQ", st-start_ts, scache.setdefault(str(e.arg["dtype"]), len(scache)), e.arg["sz"], e.arg["nbytes"]))
+      events.append(struct.pack("<IIQQ", st-start_ts, scache.setdefault(str(e.arg["dtype"]), len(scache)), e.arg["sz"], e.arg["nbytes"]))
       temp[e.key] = e.arg["nbytes"]
      mem += e.arg["nbytes"]
       if mem > peak: peak = mem
     if e.name == "free":
-      bufs.append(struct.pack("<II", st-start_ts, e.key))
+      events.append(struct.pack("<II", st-start_ts, e.key))
       mem -= temp.pop(e.key)
-  if len(bufs) == 0: return None
-  return struct.pack("<Q", peak)+b"".join(bufs)
+  if len(events) == 0: return None
+  return struct.pack("<Q", peak)+b"".join(events)

 def get_profile(profile:list[ProfileEvent]) -> bytes|None:
   # start by getting the time diffs
@@ -177,12 +178,14 @@
   if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
   # map events per device
   dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
+  markers:list[ProfilePointEvent] = []
   start_ts:int|None = None
   end_ts:int|None = None
   for ts,en,e in flatten_events(profile):
     dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
     if start_ts is None or st < start_ts: start_ts = st
     if end_ts is None or et > end_ts: end_ts = et
+    if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e)
   if start_ts is None: return None
   # return layout of per device events
   layout:dict[str, bytes|None] = {}
@@ -194,16 +197,9 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
     layout[k] = timeline_layout(v, start_ts, scache)
     layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
   ret = [b"".join([struct.pack("<I", len(k))+k.encode()+v for k,v in layout.items() if v is not None])]
   return b"".join(ret)

-def get_runtime_stats(key) -> list[dict]:
-  ret:list[dict] = []
-  for e in profile:
-    if isinstance(e, ProfileRangeEvent) and e.en is not None and e.name == key:
-      ret.append({"device":e.device, "data":[{"name":"Duration", "value":float(e.en-e.st), "unit":"us"}]})
-  return ret
-
 # ** Assembly analyzers

 def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
@@ -227,11 +223,11 @@
   return {"rows":rows, "cols":["Opcode", "Latency", {"title":"HW Resources", "labels":resource_labels}], "summary":summary}
 def get_disassembly(ctx:list[str]):
-  if not isinstance(prg:=contexts[0][int(ctx[0])].ret, ProgramSpec): return
+  if not isinstance(prg:=traces[int(ctx[0])][0].ret, ProgramSpec): return
   lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
   with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
   disasm_str = buf.getvalue()
-  from tinygrad.runtime.ops_llvm import llvm, LLVMCompiler
+  from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
   if isinstance(compiler, LLVMCompiler):
     mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode()
     mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()
@@ -255,9 +251,9 @@ class Handler(BaseHTTPRequestHandler):
       except FileNotFoundError: status_code = 404
     elif (query:=parse_qs(url.query)):
       if url.path == "/disasm": ret, content_type = get_disassembly(**query), "application/json"
-      else: return self.stream_json(get_details(contexts[1][int(query["ctx"][0])][int(query["idx"][0])]))
+      else: return self.stream_json(get_details(traces[i:=int(query["ctx"][0])][1][int(query["idx"][0])], i))
     elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json"
-    elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/octet-stream"
+    elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
     else: status_code = 404

     # send response
@@ -290,17 +286,17 @@ def reloader():
     os.execv(sys.executable, [sys.executable] + sys.argv)
     time.sleep(0.1)

-def load_pickle(path:str):
-  if path is None or not os.path.exists(path): return None
-  with open(path, "rb") as f: return pickle.load(f)
+def load_pickle(path:pathlib.Path|None) -> list:
+  if path is None or not path.exists(): return []
+  with path.open("rb") as f: return pickle.load(f)

 # NOTE: using HTTPServer forces a potentially slow socket.getfqdn
 class TCPServerWithReuse(socketserver.TCPServer):
   allow_reuse_address = True

 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
-  parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
-  parser.add_argument('--profile', type=str, help='Path profile', default=None)
+  parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
+  parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
   args = parser.parse_args()

   with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -311,19 +307,16 @@
   st = time.perf_counter()
   print("*** viz is starting")

-  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
+  ctxs = get_metadata(load_pickle(args.kernels))

-  # NOTE: this context is a tuple of list[keys] and list[values]
-  ctxs = get_metadata(*contexts[:2]) if contexts is not None else []
-
-  profile_ret = get_profile(profile) if profile is not None else None
+  profile_ret = get_profile(load_pickle(args.profile))

   server = TCPServerWithReuse(('', PORT), Handler)
   reloader_thread = threading.Thread(target=reloader)
   reloader_thread.start()

   print(f"*** started viz on {HOST}:{PORT}")
   print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
-  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
+  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
   try: server.serve_forever()
   except KeyboardInterrupt: print("*** viz is shutting down...")
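For reference, the reworked handler above serves three read-only endpoints: /ctxs (JSON metadata built by get_metadata), /disasm?ctx=N (JSON disassembly tables from get_disassembly), and /get_profile (the struct-packed binary layout, application/octet-stream). A hypothetical client-side sketch of consuming them; only the paths and content types are taken from the patch, the wrapper names are illustrative:

    // Illustrative fetch wrappers for the three endpoints Handler exposes above.
    async function fetchCtxs() {
      return (await fetch("/ctxs")).json();              // metadata from get_metadata
    }
    async function fetchDisasm(i) {
      return (await fetch(`/disasm?ctx=${i}`)).json();   // tables from get_disassembly
    }
    async function fetchProfile() {
      // /get_profile streams struct-packed bytes, so read it raw
      return new DataView(await (await fetch("/get_profile")).arrayBuffer());
    }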