From b5d2026fc750a2c474dfb6ade5e93828abe5a6c4 Mon Sep 17 00:00:00 2001 From: Jason Wen Date: Thu, 2 Apr 2026 21:14:39 -0400 Subject: [PATCH] sunnypilot models: support for on-policy models --- cereal/custom.capnp | 1 + common/model.h | 2 +- release/ci/model_generator.py | 2 +- .../ui/sunnypilot/layouts/settings/models.py | 8 +++- sunnypilot/modeld_v2/install_models_pc.py | 2 +- sunnypilot/models/default_model.py | 8 ++-- sunnypilot/models/fetcher.py | 2 +- sunnypilot/models/runners/helpers.py | 5 ++- .../models/runners/tinygrad/model_types.py | 16 +++++++ .../runners/tinygrad/tinygrad_runner.py | 44 ++++++++++++++----- sunnypilot/models/tests/model_hash | 2 +- sunnypilot/models/tests/test_default_model.py | 7 +-- 12 files changed, 74 insertions(+), 25 deletions(-) diff --git a/cereal/custom.capnp b/cereal/custom.capnp index 53986262ec..fe3ed9196f 100644 --- a/cereal/custom.capnp +++ b/cereal/custom.capnp @@ -154,6 +154,7 @@ struct ModelManagerSP @0xaedffd8f31e7b55d { vision @2; policy @3; offPolicy @4; + onPolicy @5; } } diff --git a/common/model.h b/common/model.h index 444727be93..b482499c74 100644 --- a/common/model.h +++ b/common/model.h @@ -1 +1 @@ -#define DEFAULT_MODEL "CD210 (Default)" +#define DEFAULT_MODEL "OP Model (Default)" diff --git a/release/ci/model_generator.py b/release/ci/model_generator.py index da6b933030..feeb80095b 100755 --- a/release/ci/model_generator.py +++ b/release/ci/model_generator.py @@ -104,7 +104,7 @@ def generate_metadata(model_path: Path, output_dir: Path, short_name: str): metadata_file = metadata_file.rename(output_path / f"{base}_{short_name.lower()}_metadata.pkl") # Build the metadata structure - model_type = "offPolicy" if "off_policy" in base else base.split("_")[-1] + model_type = "offPolicy" if "off_policy" in base else "onPolicy" if "on_policy" in base else base.split("_")[-1] model_metadata = { "type": model_type, diff --git a/selfdrive/ui/sunnypilot/layouts/settings/models.py b/selfdrive/ui/sunnypilot/layouts/settings/models.py index bf6af56941..b34604af0c 100644 --- a/selfdrive/ui/sunnypilot/layouts/settings/models.py +++ b/selfdrive/ui/sunnypilot/layouts/settings/models.py @@ -58,6 +58,8 @@ class ModelsLayout(Widget): self.supercombo_label = progress_item(tr("Driving Model")) self.vision_label = progress_item(tr("Vision Model")) self.policy_label = progress_item(tr("Policy Model")) + self.off_policy_label = progress_item(tr("Off-Policy Model")) + self.on_policy_label = progress_item(tr("On-Policy Model")) self.refresh_item = button_item(tr("Refresh Model List"), tr("REFRESH"), "", lambda: (ui_state.params.put("ModelManager_LastSyncTime", 0), @@ -91,7 +93,7 @@ class ModelsLayout(Widget): self.lagd_toggle = toggle_item_sp(tr("Live Learning Steer Delay"), "", param="LagdToggle") self.items = [self.current_model_item, self.cancel_download_item, self.supercombo_label, self.vision_label, - self.policy_label, self.refresh_item, self.clear_cache_item, self.lane_turn_desire_toggle, + self.policy_label, self.off_policy_label, self.on_policy_label, self.refresh_item, self.clear_cache_item, self.lane_turn_desire_toggle, self.lane_turn_value_control, self.lagd_toggle, self.delay_control] def _update_lagd_description(self, lagd_toggle: bool): @@ -129,7 +131,9 @@ class ModelsLayout(Widget): def _handle_bundle_download_progress(self): labels = {custom.ModelManagerSP.Model.Type.supercombo: self.supercombo_label, custom.ModelManagerSP.Model.Type.vision: self.vision_label, - custom.ModelManagerSP.Model.Type.policy: self.policy_label} + custom.ModelManagerSP.Model.Type.policy: self.policy_label, + custom.ModelManagerSP.Model.Type.offPolicy: self.off_policy_label, + custom.ModelManagerSP.Model.Type.onPolicy: self.on_policy_label} for label in labels.values(): label.set_visible(False) self.cancel_download_item.set_visible(False) diff --git a/sunnypilot/modeld_v2/install_models_pc.py b/sunnypilot/modeld_v2/install_models_pc.py index d203de3487..1bba001abd 100755 --- a/sunnypilot/modeld_v2/install_models_pc.py +++ b/sunnypilot/modeld_v2/install_models_pc.py @@ -30,7 +30,7 @@ def generate_metadata_pkl(model_path, output_path): def install_models(model_dir): model_dir = Path(model_dir) - models = ["driving_off_policy", "driving_policy", "driving_vision"] + models = ["driving_off_policy", "driving_on_policy", "driving_vision"] found_models = [] for model in models: diff --git a/sunnypilot/models/default_model.py b/sunnypilot/models/default_model.py index 0260a3c3bc..d540efbffd 100755 --- a/sunnypilot/models/default_model.py +++ b/sunnypilot/models/default_model.py @@ -8,14 +8,16 @@ from openpilot.sunnypilot import get_file_hash DEFAULT_MODEL_NAME_PATH = os.path.join(BASEDIR, "common", "model.h") MODEL_HASH_PATH = os.path.join(BASEDIR, "sunnypilot", "models", "tests", "model_hash") VISION_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_vision.onnx") -POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_policy.onnx") +OFF_POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_off_policy.onnx") +ON_POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_on_policy.onnx") def update_model_hash(): vision_hash = get_file_hash(VISION_ONNX_PATH) - policy_hash = get_file_hash(POLICY_ONNX_PATH) + off_policy_hash = get_file_hash(OFF_POLICY_ONNX_PATH) + on_policy_hash = get_file_hash(ON_POLICY_ONNX_PATH) - combined_hash = hashlib.sha256((vision_hash + policy_hash).encode()).hexdigest() + combined_hash = hashlib.sha256((vision_hash + off_policy_hash + on_policy_hash).encode()).hexdigest() with open(MODEL_HASH_PATH, "w") as f: f.write(combined_hash) diff --git a/sunnypilot/models/fetcher.py b/sunnypilot/models/fetcher.py index 452c59e06b..0b6853da8a 100644 --- a/sunnypilot/models/fetcher.py +++ b/sunnypilot/models/fetcher.py @@ -116,7 +116,7 @@ class ModelCache: class ModelFetcher: """Handles fetching and caching of model data from remote source""" - MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v15.json" + MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v16.json" def __init__(self, params: Params): self.params = params diff --git a/sunnypilot/models/runners/helpers.py b/sunnypilot/models/runners/helpers.py index 8f9d8fc2f5..b34a62132b 100644 --- a/sunnypilot/models/runners/helpers.py +++ b/sunnypilot/models/runners/helpers.py @@ -16,8 +16,9 @@ def get_model_runner() -> ModelRunner: bundle = get_active_bundle() if bundle and bundle.models: model_types = {m.type.raw for m in bundle.models} - # Check if the bundle uses separate vision and policy models - if ModelType.vision in model_types or ModelType.policy in model_types: + # Check if the bundle uses separate vision and policy models (legacy or new split format) + split_types = {ModelType.vision, ModelType.policy, ModelType.offPolicy, ModelType.onPolicy} + if model_types & split_types: return TinygradSplitRunner() # Otherwise, assume a single model (likely supercombo) if bundle.models: diff --git a/sunnypilot/models/runners/tinygrad/model_types.py b/sunnypilot/models/runners/tinygrad/model_types.py index 11f0965828..015adc035f 100644 --- a/sunnypilot/models/runners/tinygrad/model_types.py +++ b/sunnypilot/models/runners/tinygrad/model_types.py @@ -29,6 +29,22 @@ class OffPolicyTinygrad(ModularRunner, ABC): return result +class OnPolicyTinygrad(ModularRunner, ABC): + """ + A TinygradRunner specialized for on-policy models. + + Uses a SplitParser to handle outputs specific to the on-policy part of a split model setup. + """ + def __init__(self): + self._on_policy_parser = SplitParser() + self.parser_method_dict[ModelType.onPolicy] = self._parse_on_policy_outputs + + def _parse_on_policy_outputs(self, model_outputs: np.ndarray) -> NumpyDict: + """Parses on-policy model outputs using SplitParser.""" + result: NumpyDict = self._on_policy_parser.parse_policy_outputs(self._slice_outputs(model_outputs)) + return result + + class PolicyTinygrad(ModularRunner, ABC): """ A TinygradRunner specialized for policy-only models. diff --git a/sunnypilot/models/runners/tinygrad/tinygrad_runner.py b/sunnypilot/models/runners/tinygrad/tinygrad_runner.py index 9033c892eb..4e17bd5ead 100644 --- a/sunnypilot/models/runners/tinygrad/tinygrad_runner.py +++ b/sunnypilot/models/runners/tinygrad/tinygrad_runner.py @@ -3,14 +3,14 @@ import pickle import numpy as np from openpilot.sunnypilot.models.runners.constants import NumpyDict, ModelType, ShapeDict, CUSTOM_MODEL_PATH, SliceDict from openpilot.sunnypilot.models.runners.model_runner import ModelRunner -from openpilot.sunnypilot.models.runners.tinygrad.model_types import PolicyTinygrad, VisionTinygrad, SupercomboTinygrad, OffPolicyTinygrad +from openpilot.sunnypilot.models.runners.tinygrad.model_types import PolicyTinygrad, VisionTinygrad, SupercomboTinygrad, OffPolicyTinygrad, OnPolicyTinygrad from openpilot.sunnypilot.models.split_model_constants import SplitModelConstants from openpilot.sunnypilot.modeld_v2.constants import ModelConstants from tinygrad.tensor import Tensor -class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTinygrad, OffPolicyTinygrad): +class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTinygrad, OffPolicyTinygrad, OnPolicyTinygrad): """ A ModelRunner implementation for executing Tinygrad models. @@ -26,6 +26,7 @@ class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTiny PolicyTinygrad.__init__(self) VisionTinygrad.__init__(self) OffPolicyTinygrad.__init__(self) + OnPolicyTinygrad.__init__(self) self._constants = ModelConstants self._model_data = self.models.get(model_type) if not self._model_data or not self._model_data.model: @@ -98,20 +99,30 @@ class TinygradSplitRunner(ModelRunner): super().__init__() self.is_20hz_3d = True self.vision_runner = TinygradRunner(ModelType.vision) - self.policy_runner = TinygradRunner(ModelType.policy) + self.policy_runner = TinygradRunner(ModelType.policy) if self.models.get(ModelType.policy) else None self.off_policy_runner = TinygradRunner(ModelType.offPolicy) if self.models.get(ModelType.offPolicy) else None + self.on_policy_runner = TinygradRunner(ModelType.onPolicy) if self.models.get(ModelType.onPolicy) else None self._constants = SplitModelConstants def _run_model(self) -> NumpyDict: """Runs both vision and policy models and merges their parsed outputs.""" - policy_output = self.policy_runner.run_model() vision_output = self.vision_runner.run_model() - outputs = {**policy_output, **vision_output} + outputs = {**vision_output} + + if self.policy_runner: + policy_output = self.policy_runner.run_model() + outputs.update(policy_output) if self.off_policy_runner: off_policy_output = self.off_policy_runner.run_model() + if self.on_policy_runner: + off_policy_output.pop('plan', None) outputs.update(off_policy_output) + if self.on_policy_runner: + on_policy_output = self.on_policy_runner.run_model() + outputs.update(on_policy_output) + if 'planplus' in outputs and 'plan' in outputs: outputs['plan'] = outputs['plan'] + outputs['planplus'] @@ -125,31 +136,44 @@ class TinygradSplitRunner(ModelRunner): @property def input_shapes(self) -> ShapeDict: """Returns the combined input shapes from both vision and policy models.""" - shapes = {**self.policy_runner.input_shapes, **self.vision_runner.input_shapes} + shapes = {**self.vision_runner.input_shapes} + if self.policy_runner: + shapes.update(self.policy_runner.input_shapes) if self.off_policy_runner: shapes.update(self.off_policy_runner.input_shapes) + if self.on_policy_runner: + shapes.update(self.on_policy_runner.input_shapes) return shapes @property def output_slices(self) -> SliceDict: """Returns the combined output slices from both vision and policy models.""" - slices = {**self.policy_runner.output_slices, **self.vision_runner.output_slices} + slices = {**self.vision_runner.output_slices} + if self.policy_runner: + slices.update(self.policy_runner.output_slices) if self.off_policy_runner: slices.update(self.off_policy_runner.output_slices) + if self.on_policy_runner: + slices.update(self.on_policy_runner.output_slices) return slices def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict: """Prepares inputs for both vision and policy models.""" - # Policy inputs only depend on numpy_inputs - self.policy_runner.prepare_policy_inputs(numpy_inputs) + if self.policy_runner: + self.policy_runner.prepare_policy_inputs(numpy_inputs) for key in self.vision_input_names: if key in self.inputs: self.vision_runner.inputs[key] = self.inputs[key].cast(self.vision_runner.input_to_dtype[key]) - inputs = {**self.policy_runner.inputs, **self.vision_runner.inputs} + inputs = {**self.vision_runner.inputs} + if self.policy_runner: + inputs.update(self.policy_runner.inputs) if self.off_policy_runner: self.off_policy_runner.prepare_policy_inputs(numpy_inputs) inputs.update(self.off_policy_runner.inputs) + if self.on_policy_runner: + self.on_policy_runner.prepare_policy_inputs(numpy_inputs) + inputs.update(self.on_policy_runner.inputs) return inputs diff --git a/sunnypilot/models/tests/model_hash b/sunnypilot/models/tests/model_hash index f363f8309a..9a9ea5a968 100644 --- a/sunnypilot/models/tests/model_hash +++ b/sunnypilot/models/tests/model_hash @@ -1 +1 @@ -32f57bdc91f910df1f48ddae7c59aaf6e751f9df6756da481a210577dbce8bcf \ No newline at end of file +adfcb5ccac9cfaf291af6091d12e71be3f543c7694fc29d80caa561dc32194d7 diff --git a/sunnypilot/models/tests/test_default_model.py b/sunnypilot/models/tests/test_default_model.py index 7c2fde70a8..abe685c36a 100644 --- a/sunnypilot/models/tests/test_default_model.py +++ b/sunnypilot/models/tests/test_default_model.py @@ -6,16 +6,17 @@ See the LICENSE.md file in the root directory for more details. """ from openpilot.sunnypilot import get_file_hash -from openpilot.sunnypilot.models.default_model import MODEL_HASH_PATH, VISION_ONNX_PATH, POLICY_ONNX_PATH +from openpilot.sunnypilot.models.default_model import MODEL_HASH_PATH, VISION_ONNX_PATH, OFF_POLICY_ONNX_PATH, ON_POLICY_ONNX_PATH import hashlib class TestDefaultModel: def test_compare_onnx_hashes(self): vision_hash = get_file_hash(VISION_ONNX_PATH) - policy_hash = get_file_hash(POLICY_ONNX_PATH) + off_policy_hash = get_file_hash(OFF_POLICY_ONNX_PATH) + on_policy_hash = get_file_hash(ON_POLICY_ONNX_PATH) - combined_hash = hashlib.sha256((vision_hash + policy_hash).encode()).hexdigest() + combined_hash = hashlib.sha256((vision_hash + off_policy_hash + on_policy_hash).encode()).hexdigest() with open(MODEL_HASH_PATH) as f: current_hash = f.read().strip()