mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-04-06 23:53:58 +08:00
sunnypilot models: support for on-policy models
This commit is contained in:
@@ -154,6 +154,7 @@ struct ModelManagerSP @0xaedffd8f31e7b55d {
|
|||||||
vision @2;
|
vision @2;
|
||||||
policy @3;
|
policy @3;
|
||||||
offPolicy @4;
|
offPolicy @4;
|
||||||
|
onPolicy @5;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
#define DEFAULT_MODEL "CD210 (Default)"
|
#define DEFAULT_MODEL "OP Model (Default)"
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ def generate_metadata(model_path: Path, output_dir: Path, short_name: str):
|
|||||||
metadata_file = metadata_file.rename(output_path / f"{base}_{short_name.lower()}_metadata.pkl")
|
metadata_file = metadata_file.rename(output_path / f"{base}_{short_name.lower()}_metadata.pkl")
|
||||||
|
|
||||||
# Build the metadata structure
|
# Build the metadata structure
|
||||||
model_type = "offPolicy" if "off_policy" in base else base.split("_")[-1]
|
model_type = "offPolicy" if "off_policy" in base else "onPolicy" if "on_policy" in base else base.split("_")[-1]
|
||||||
|
|
||||||
model_metadata = {
|
model_metadata = {
|
||||||
"type": model_type,
|
"type": model_type,
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ class ModelsLayout(Widget):
|
|||||||
self.supercombo_label = progress_item(tr("Driving Model"))
|
self.supercombo_label = progress_item(tr("Driving Model"))
|
||||||
self.vision_label = progress_item(tr("Vision Model"))
|
self.vision_label = progress_item(tr("Vision Model"))
|
||||||
self.policy_label = progress_item(tr("Policy Model"))
|
self.policy_label = progress_item(tr("Policy Model"))
|
||||||
|
self.off_policy_label = progress_item(tr("Off-Policy Model"))
|
||||||
|
self.on_policy_label = progress_item(tr("On-Policy Model"))
|
||||||
|
|
||||||
self.refresh_item = button_item(tr("Refresh Model List"), tr("REFRESH"), "",
|
self.refresh_item = button_item(tr("Refresh Model List"), tr("REFRESH"), "",
|
||||||
lambda: (ui_state.params.put("ModelManager_LastSyncTime", 0),
|
lambda: (ui_state.params.put("ModelManager_LastSyncTime", 0),
|
||||||
@@ -91,7 +93,7 @@ class ModelsLayout(Widget):
|
|||||||
self.lagd_toggle = toggle_item_sp(tr("Live Learning Steer Delay"), "", param="LagdToggle")
|
self.lagd_toggle = toggle_item_sp(tr("Live Learning Steer Delay"), "", param="LagdToggle")
|
||||||
|
|
||||||
self.items = [self.current_model_item, self.cancel_download_item, self.supercombo_label, self.vision_label,
|
self.items = [self.current_model_item, self.cancel_download_item, self.supercombo_label, self.vision_label,
|
||||||
self.policy_label, self.refresh_item, self.clear_cache_item, self.lane_turn_desire_toggle,
|
self.policy_label, self.off_policy_label, self.on_policy_label, self.refresh_item, self.clear_cache_item, self.lane_turn_desire_toggle,
|
||||||
self.lane_turn_value_control, self.lagd_toggle, self.delay_control]
|
self.lane_turn_value_control, self.lagd_toggle, self.delay_control]
|
||||||
|
|
||||||
def _update_lagd_description(self, lagd_toggle: bool):
|
def _update_lagd_description(self, lagd_toggle: bool):
|
||||||
@@ -129,7 +131,9 @@ class ModelsLayout(Widget):
|
|||||||
def _handle_bundle_download_progress(self):
|
def _handle_bundle_download_progress(self):
|
||||||
labels = {custom.ModelManagerSP.Model.Type.supercombo: self.supercombo_label,
|
labels = {custom.ModelManagerSP.Model.Type.supercombo: self.supercombo_label,
|
||||||
custom.ModelManagerSP.Model.Type.vision: self.vision_label,
|
custom.ModelManagerSP.Model.Type.vision: self.vision_label,
|
||||||
custom.ModelManagerSP.Model.Type.policy: self.policy_label}
|
custom.ModelManagerSP.Model.Type.policy: self.policy_label,
|
||||||
|
custom.ModelManagerSP.Model.Type.offPolicy: self.off_policy_label,
|
||||||
|
custom.ModelManagerSP.Model.Type.onPolicy: self.on_policy_label}
|
||||||
for label in labels.values():
|
for label in labels.values():
|
||||||
label.set_visible(False)
|
label.set_visible(False)
|
||||||
self.cancel_download_item.set_visible(False)
|
self.cancel_download_item.set_visible(False)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def generate_metadata_pkl(model_path, output_path):
|
|||||||
|
|
||||||
def install_models(model_dir):
|
def install_models(model_dir):
|
||||||
model_dir = Path(model_dir)
|
model_dir = Path(model_dir)
|
||||||
models = ["driving_off_policy", "driving_policy", "driving_vision"]
|
models = ["driving_off_policy", "driving_on_policy", "driving_vision"]
|
||||||
found_models = []
|
found_models = []
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
|||||||
@@ -8,14 +8,16 @@ from openpilot.sunnypilot import get_file_hash
|
|||||||
DEFAULT_MODEL_NAME_PATH = os.path.join(BASEDIR, "common", "model.h")
|
DEFAULT_MODEL_NAME_PATH = os.path.join(BASEDIR, "common", "model.h")
|
||||||
MODEL_HASH_PATH = os.path.join(BASEDIR, "sunnypilot", "models", "tests", "model_hash")
|
MODEL_HASH_PATH = os.path.join(BASEDIR, "sunnypilot", "models", "tests", "model_hash")
|
||||||
VISION_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_vision.onnx")
|
VISION_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_vision.onnx")
|
||||||
POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_policy.onnx")
|
OFF_POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_off_policy.onnx")
|
||||||
|
ON_POLICY_ONNX_PATH = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "driving_on_policy.onnx")
|
||||||
|
|
||||||
|
|
||||||
def update_model_hash():
|
def update_model_hash():
|
||||||
vision_hash = get_file_hash(VISION_ONNX_PATH)
|
vision_hash = get_file_hash(VISION_ONNX_PATH)
|
||||||
policy_hash = get_file_hash(POLICY_ONNX_PATH)
|
off_policy_hash = get_file_hash(OFF_POLICY_ONNX_PATH)
|
||||||
|
on_policy_hash = get_file_hash(ON_POLICY_ONNX_PATH)
|
||||||
|
|
||||||
combined_hash = hashlib.sha256((vision_hash + policy_hash).encode()).hexdigest()
|
combined_hash = hashlib.sha256((vision_hash + off_policy_hash + on_policy_hash).encode()).hexdigest()
|
||||||
|
|
||||||
with open(MODEL_HASH_PATH, "w") as f:
|
with open(MODEL_HASH_PATH, "w") as f:
|
||||||
f.write(combined_hash)
|
f.write(combined_hash)
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class ModelCache:
|
|||||||
|
|
||||||
class ModelFetcher:
|
class ModelFetcher:
|
||||||
"""Handles fetching and caching of model data from remote source"""
|
"""Handles fetching and caching of model data from remote source"""
|
||||||
MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v15.json"
|
MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v16.json"
|
||||||
|
|
||||||
def __init__(self, params: Params):
|
def __init__(self, params: Params):
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ def get_model_runner() -> ModelRunner:
|
|||||||
bundle = get_active_bundle()
|
bundle = get_active_bundle()
|
||||||
if bundle and bundle.models:
|
if bundle and bundle.models:
|
||||||
model_types = {m.type.raw for m in bundle.models}
|
model_types = {m.type.raw for m in bundle.models}
|
||||||
# Check if the bundle uses separate vision and policy models
|
# Check if the bundle uses separate vision and policy models (legacy or new split format)
|
||||||
if ModelType.vision in model_types or ModelType.policy in model_types:
|
split_types = {ModelType.vision, ModelType.policy, ModelType.offPolicy, ModelType.onPolicy}
|
||||||
|
if model_types & split_types:
|
||||||
return TinygradSplitRunner()
|
return TinygradSplitRunner()
|
||||||
# Otherwise, assume a single model (likely supercombo)
|
# Otherwise, assume a single model (likely supercombo)
|
||||||
if bundle.models:
|
if bundle.models:
|
||||||
|
|||||||
@@ -29,6 +29,22 @@ class OffPolicyTinygrad(ModularRunner, ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class OnPolicyTinygrad(ModularRunner, ABC):
|
||||||
|
"""
|
||||||
|
A TinygradRunner specialized for on-policy models.
|
||||||
|
|
||||||
|
Uses a SplitParser to handle outputs specific to the on-policy part of a split model setup.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self._on_policy_parser = SplitParser()
|
||||||
|
self.parser_method_dict[ModelType.onPolicy] = self._parse_on_policy_outputs
|
||||||
|
|
||||||
|
def _parse_on_policy_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||||
|
"""Parses on-policy model outputs using SplitParser."""
|
||||||
|
result: NumpyDict = self._on_policy_parser.parse_policy_outputs(self._slice_outputs(model_outputs))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class PolicyTinygrad(ModularRunner, ABC):
|
class PolicyTinygrad(ModularRunner, ABC):
|
||||||
"""
|
"""
|
||||||
A TinygradRunner specialized for policy-only models.
|
A TinygradRunner specialized for policy-only models.
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ import pickle
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from openpilot.sunnypilot.models.runners.constants import NumpyDict, ModelType, ShapeDict, CUSTOM_MODEL_PATH, SliceDict
|
from openpilot.sunnypilot.models.runners.constants import NumpyDict, ModelType, ShapeDict, CUSTOM_MODEL_PATH, SliceDict
|
||||||
from openpilot.sunnypilot.models.runners.model_runner import ModelRunner
|
from openpilot.sunnypilot.models.runners.model_runner import ModelRunner
|
||||||
from openpilot.sunnypilot.models.runners.tinygrad.model_types import PolicyTinygrad, VisionTinygrad, SupercomboTinygrad, OffPolicyTinygrad
|
from openpilot.sunnypilot.models.runners.tinygrad.model_types import PolicyTinygrad, VisionTinygrad, SupercomboTinygrad, OffPolicyTinygrad, OnPolicyTinygrad
|
||||||
from openpilot.sunnypilot.models.split_model_constants import SplitModelConstants
|
from openpilot.sunnypilot.models.split_model_constants import SplitModelConstants
|
||||||
from openpilot.sunnypilot.modeld_v2.constants import ModelConstants
|
from openpilot.sunnypilot.modeld_v2.constants import ModelConstants
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTinygrad, OffPolicyTinygrad):
|
class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTinygrad, OffPolicyTinygrad, OnPolicyTinygrad):
|
||||||
"""
|
"""
|
||||||
A ModelRunner implementation for executing Tinygrad models.
|
A ModelRunner implementation for executing Tinygrad models.
|
||||||
|
|
||||||
@@ -26,6 +26,7 @@ class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTiny
|
|||||||
PolicyTinygrad.__init__(self)
|
PolicyTinygrad.__init__(self)
|
||||||
VisionTinygrad.__init__(self)
|
VisionTinygrad.__init__(self)
|
||||||
OffPolicyTinygrad.__init__(self)
|
OffPolicyTinygrad.__init__(self)
|
||||||
|
OnPolicyTinygrad.__init__(self)
|
||||||
self._constants = ModelConstants
|
self._constants = ModelConstants
|
||||||
self._model_data = self.models.get(model_type)
|
self._model_data = self.models.get(model_type)
|
||||||
if not self._model_data or not self._model_data.model:
|
if not self._model_data or not self._model_data.model:
|
||||||
@@ -98,20 +99,30 @@ class TinygradSplitRunner(ModelRunner):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.is_20hz_3d = True
|
self.is_20hz_3d = True
|
||||||
self.vision_runner = TinygradRunner(ModelType.vision)
|
self.vision_runner = TinygradRunner(ModelType.vision)
|
||||||
self.policy_runner = TinygradRunner(ModelType.policy)
|
self.policy_runner = TinygradRunner(ModelType.policy) if self.models.get(ModelType.policy) else None
|
||||||
self.off_policy_runner = TinygradRunner(ModelType.offPolicy) if self.models.get(ModelType.offPolicy) else None
|
self.off_policy_runner = TinygradRunner(ModelType.offPolicy) if self.models.get(ModelType.offPolicy) else None
|
||||||
|
self.on_policy_runner = TinygradRunner(ModelType.onPolicy) if self.models.get(ModelType.onPolicy) else None
|
||||||
self._constants = SplitModelConstants
|
self._constants = SplitModelConstants
|
||||||
|
|
||||||
def _run_model(self) -> NumpyDict:
|
def _run_model(self) -> NumpyDict:
|
||||||
"""Runs both vision and policy models and merges their parsed outputs."""
|
"""Runs both vision and policy models and merges their parsed outputs."""
|
||||||
policy_output = self.policy_runner.run_model()
|
|
||||||
vision_output = self.vision_runner.run_model()
|
vision_output = self.vision_runner.run_model()
|
||||||
outputs = {**policy_output, **vision_output}
|
outputs = {**vision_output}
|
||||||
|
|
||||||
|
if self.policy_runner:
|
||||||
|
policy_output = self.policy_runner.run_model()
|
||||||
|
outputs.update(policy_output)
|
||||||
|
|
||||||
if self.off_policy_runner:
|
if self.off_policy_runner:
|
||||||
off_policy_output = self.off_policy_runner.run_model()
|
off_policy_output = self.off_policy_runner.run_model()
|
||||||
|
if self.on_policy_runner:
|
||||||
|
off_policy_output.pop('plan', None)
|
||||||
outputs.update(off_policy_output)
|
outputs.update(off_policy_output)
|
||||||
|
|
||||||
|
if self.on_policy_runner:
|
||||||
|
on_policy_output = self.on_policy_runner.run_model()
|
||||||
|
outputs.update(on_policy_output)
|
||||||
|
|
||||||
if 'planplus' in outputs and 'plan' in outputs:
|
if 'planplus' in outputs and 'plan' in outputs:
|
||||||
outputs['plan'] = outputs['plan'] + outputs['planplus']
|
outputs['plan'] = outputs['plan'] + outputs['planplus']
|
||||||
|
|
||||||
@@ -125,31 +136,44 @@ class TinygradSplitRunner(ModelRunner):
|
|||||||
@property
|
@property
|
||||||
def input_shapes(self) -> ShapeDict:
|
def input_shapes(self) -> ShapeDict:
|
||||||
"""Returns the combined input shapes from both vision and policy models."""
|
"""Returns the combined input shapes from both vision and policy models."""
|
||||||
shapes = {**self.policy_runner.input_shapes, **self.vision_runner.input_shapes}
|
shapes = {**self.vision_runner.input_shapes}
|
||||||
|
if self.policy_runner:
|
||||||
|
shapes.update(self.policy_runner.input_shapes)
|
||||||
if self.off_policy_runner:
|
if self.off_policy_runner:
|
||||||
shapes.update(self.off_policy_runner.input_shapes)
|
shapes.update(self.off_policy_runner.input_shapes)
|
||||||
|
if self.on_policy_runner:
|
||||||
|
shapes.update(self.on_policy_runner.input_shapes)
|
||||||
return shapes
|
return shapes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_slices(self) -> SliceDict:
|
def output_slices(self) -> SliceDict:
|
||||||
"""Returns the combined output slices from both vision and policy models."""
|
"""Returns the combined output slices from both vision and policy models."""
|
||||||
slices = {**self.policy_runner.output_slices, **self.vision_runner.output_slices}
|
slices = {**self.vision_runner.output_slices}
|
||||||
|
if self.policy_runner:
|
||||||
|
slices.update(self.policy_runner.output_slices)
|
||||||
if self.off_policy_runner:
|
if self.off_policy_runner:
|
||||||
slices.update(self.off_policy_runner.output_slices)
|
slices.update(self.off_policy_runner.output_slices)
|
||||||
|
if self.on_policy_runner:
|
||||||
|
slices.update(self.on_policy_runner.output_slices)
|
||||||
return slices
|
return slices
|
||||||
|
|
||||||
def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict:
|
def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict:
|
||||||
"""Prepares inputs for both vision and policy models."""
|
"""Prepares inputs for both vision and policy models."""
|
||||||
# Policy inputs only depend on numpy_inputs
|
if self.policy_runner:
|
||||||
self.policy_runner.prepare_policy_inputs(numpy_inputs)
|
self.policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||||
|
|
||||||
for key in self.vision_input_names:
|
for key in self.vision_input_names:
|
||||||
if key in self.inputs:
|
if key in self.inputs:
|
||||||
self.vision_runner.inputs[key] = self.inputs[key].cast(self.vision_runner.input_to_dtype[key])
|
self.vision_runner.inputs[key] = self.inputs[key].cast(self.vision_runner.input_to_dtype[key])
|
||||||
|
|
||||||
inputs = {**self.policy_runner.inputs, **self.vision_runner.inputs}
|
inputs = {**self.vision_runner.inputs}
|
||||||
|
if self.policy_runner:
|
||||||
|
inputs.update(self.policy_runner.inputs)
|
||||||
|
|
||||||
if self.off_policy_runner:
|
if self.off_policy_runner:
|
||||||
self.off_policy_runner.prepare_policy_inputs(numpy_inputs)
|
self.off_policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||||
inputs.update(self.off_policy_runner.inputs)
|
inputs.update(self.off_policy_runner.inputs)
|
||||||
|
if self.on_policy_runner:
|
||||||
|
self.on_policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||||
|
inputs.update(self.on_policy_runner.inputs)
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
32f57bdc91f910df1f48ddae7c59aaf6e751f9df6756da481a210577dbce8bcf
|
adfcb5ccac9cfaf291af6091d12e71be3f543c7694fc29d80caa561dc32194d7
|
||||||
|
|||||||
@@ -6,16 +6,17 @@ See the LICENSE.md file in the root directory for more details.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from openpilot.sunnypilot import get_file_hash
|
from openpilot.sunnypilot import get_file_hash
|
||||||
from openpilot.sunnypilot.models.default_model import MODEL_HASH_PATH, VISION_ONNX_PATH, POLICY_ONNX_PATH
|
from openpilot.sunnypilot.models.default_model import MODEL_HASH_PATH, VISION_ONNX_PATH, OFF_POLICY_ONNX_PATH, ON_POLICY_ONNX_PATH
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
class TestDefaultModel:
|
class TestDefaultModel:
|
||||||
def test_compare_onnx_hashes(self):
|
def test_compare_onnx_hashes(self):
|
||||||
vision_hash = get_file_hash(VISION_ONNX_PATH)
|
vision_hash = get_file_hash(VISION_ONNX_PATH)
|
||||||
policy_hash = get_file_hash(POLICY_ONNX_PATH)
|
off_policy_hash = get_file_hash(OFF_POLICY_ONNX_PATH)
|
||||||
|
on_policy_hash = get_file_hash(ON_POLICY_ONNX_PATH)
|
||||||
|
|
||||||
combined_hash = hashlib.sha256((vision_hash + policy_hash).encode()).hexdigest()
|
combined_hash = hashlib.sha256((vision_hash + off_policy_hash + on_policy_hash).encode()).hexdigest()
|
||||||
|
|
||||||
with open(MODEL_HASH_PATH) as f:
|
with open(MODEL_HASH_PATH) as f:
|
||||||
current_hash = f.read().strip()
|
current_hash = f.read().strip()
|
||||||
|
|||||||
Reference in New Issue
Block a user