mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-02-18 13:04:01 +08:00
SCC-V: Use p97 for predicted lateral accel (#1635)
* SCC-V: Use p97 for predicted lateral accel * SCC-V: add param tests for p97 max_pred_lat_acc * fix * typo --------- Co-authored-by: Jason Wen <haibin.wen3@gmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@ This file is part of sunnypilot and is licensed under the MIT License.
|
||||
See the LICENSE.md file in the root directory for more details.
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import cereal.messaging as messaging
|
||||
from cereal import custom, log
|
||||
@@ -12,11 +13,52 @@ from openpilot.common.params import Params
|
||||
from openpilot.common.realtime import DT_MDL
|
||||
from openpilot.selfdrive.car.cruise import V_CRUISE_UNSET
|
||||
from openpilot.selfdrive.modeld.constants import ModelConstants
|
||||
from openpilot.sunnypilot.selfdrive.controls.lib.smart_cruise_control.vision_controller import SmartCruiseControlVision
|
||||
from openpilot.sunnypilot.selfdrive.controls.lib.smart_cruise_control import MIN_V
|
||||
from openpilot.sunnypilot.selfdrive.controls.lib.smart_cruise_control.vision_controller import SmartCruiseControlVision, _ENTERING_PRED_LAT_ACC_TH
|
||||
|
||||
VisionState = custom.LongitudinalPlanSP.SmartCruiseControl.VisionState
|
||||
|
||||
|
||||
def _th_above_f32(th: float) -> float:
|
||||
"""
|
||||
Return the next representable float32 *above* `th`.
|
||||
This avoids flaky comparisons around thresholds due to float32 rounding.
|
||||
"""
|
||||
th32 = np.float32(th)
|
||||
above32 = np.nextafter(th32, np.float32(np.inf), dtype=np.float32)
|
||||
return float(above32)
|
||||
|
||||
|
||||
def _build_single_spike_filtered(n: int, base: float = 1.0) -> np.ndarray:
|
||||
"""
|
||||
Create an array where max() is >= threshold but p97 is < threshold.
|
||||
This demonstrates the behavior difference vs np.amax().
|
||||
|
||||
Note: We intentionally construct using float32-representable values to match
|
||||
the data path through cereal/capnp.
|
||||
"""
|
||||
th = float(_ENTERING_PRED_LAT_ACC_TH)
|
||||
th32 = float(np.float32(th))
|
||||
|
||||
# numpy percentile default is linear interpolation: idx=(n-1)*p/100
|
||||
idx = (n - 1) * 0.97
|
||||
w = float(idx - np.floor(idx))
|
||||
|
||||
base32 = float(np.float32(base))
|
||||
|
||||
# Choose spike so that p97 = base + w*(spike-base) < th
|
||||
# -> spike < base + (th-base)/w. Use a margin (0.9) and ensure spike >= th.
|
||||
if w == 0.0:
|
||||
spike = th32 + 1.0
|
||||
else:
|
||||
spike = base32 + (th32 - base32) / w * 0.9
|
||||
spike = max(spike, th32 + 0.01)
|
||||
|
||||
arr = np.full(n, base32, dtype=np.float32)
|
||||
arr[-1] = np.float32(spike)
|
||||
return arr
|
||||
|
||||
|
||||
def generate_modelV2():
|
||||
model = messaging.new_message('modelV2')
|
||||
position = log.XYZTData.new_message()
|
||||
@@ -101,4 +143,72 @@ class TestSmartCruiseControlVision:
|
||||
self.scc_v.update(self.sm, True, False, 0., 0., 0.)
|
||||
assert self.scc_v.state == VisionState.enabled
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case, should_enter",
|
||||
[
|
||||
("p97_just_above_threshold", True),
|
||||
("single_spike_filtered", False),
|
||||
("persistent_high_values", True),
|
||||
],
|
||||
ids=[
|
||||
"p97>threshold_enters",
|
||||
"single_spike_max_large_but_p97_below_threshold",
|
||||
"high_values_persist_trigger_entering",
|
||||
],
|
||||
)
|
||||
def test_max_pred_lat_acc_uses_p97_and_threshold(self, case, should_enter):
|
||||
n = len(ModelConstants.T_IDXS)
|
||||
th = float(_ENTERING_PRED_LAT_ACC_TH)
|
||||
|
||||
if case == "p97_just_above_threshold":
|
||||
# Use the next representable float32 above threshold to avoid float32 rounding flakiness.
|
||||
val = _th_above_f32(th)
|
||||
pred_lat_accels = np.full(n, np.float32(val), dtype=np.float32)
|
||||
|
||||
elif case == "single_spike_filtered":
|
||||
pred_lat_accels = _build_single_spike_filtered(n, base=1.0)
|
||||
|
||||
elif case == "persistent_high_values":
|
||||
# Make enough "high" samples so p97 is driven by the persistent trend, not a single outlier.
|
||||
high_count = max(2, int(np.ceil(n * 0.03)) + 1)
|
||||
pred_lat_accels = np.full(n, np.float32(1.0), dtype=np.float32)
|
||||
pred_lat_accels[-high_count:] = np.float32(2.0)
|
||||
pred_lat_accels[-1] = np.float32(8.0) # keep one big outlier too
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown case: {case}")
|
||||
|
||||
# Override model predictions so:
|
||||
# predicted_lat_accels = abs(orientationRate.z) * velocity.x == pred_lat_accels
|
||||
mdl = generate_modelV2()
|
||||
mdl.modelV2.velocity.x = [1.0 for _ in range(n)]
|
||||
mdl.modelV2.orientationRate.z = [float(x) for x in pred_lat_accels]
|
||||
self.sm["modelV2"] = mdl.modelV2
|
||||
|
||||
v_ego = float(MIN_V + 5.0)
|
||||
|
||||
# 1st update: disabled -> enabled
|
||||
self.scc_v.update(self.sm, True, False, v_ego, 0.0, 0.0)
|
||||
# 2nd update: evaluate entering condition from enabled state
|
||||
self.scc_v.update(self.sm, True, False, v_ego, 0.0, 0.0)
|
||||
|
||||
# Controller does percentile on numpy float64 arrays (values already quantized by capnp),
|
||||
# so compute expected in float64 to match behavior and avoid interpolation/rounding deltas.
|
||||
expected_p97 = float(np.percentile(pred_lat_accels.astype(np.float64), 97))
|
||||
|
||||
# allow tiny numeric differences due to float conversions/interpolation
|
||||
assert np.isclose(self.scc_v.max_pred_lat_acc, expected_p97, rtol=1e-6, atol=1e-5)
|
||||
|
||||
if should_enter:
|
||||
# We assert entering primarily by state (this is the actual intended behavior).
|
||||
assert self.scc_v.state == VisionState.entering
|
||||
# Optional sanity: should be >= threshold with some margin (since we used nextafter above threshold).
|
||||
assert self.scc_v.max_pred_lat_acc > th
|
||||
|
||||
else:
|
||||
# Difference vs np.amax(): max can be above threshold, but p97 stays below it.
|
||||
assert float(np.max(pred_lat_accels)) >= th
|
||||
assert self.scc_v.max_pred_lat_acc < th
|
||||
assert self.scc_v.state == VisionState.enabled
|
||||
|
||||
# TODO-SP: mock modelV2 data to test other states
|
||||
|
||||
@@ -90,7 +90,7 @@ class SmartCruiseControlVision:
|
||||
|
||||
# get the maximum lat accel from the model
|
||||
predicted_lat_accels = rate_plan * vel_plan
|
||||
self.max_pred_lat_acc = np.amax(predicted_lat_accels)
|
||||
self.max_pred_lat_acc = np.percentile(predicted_lat_accels, 97)
|
||||
|
||||
# get the maximum curve based on the current velocity
|
||||
v_ego = max(self.v_ego, 0.1) # ensure a value greater than 0 for calculations
|
||||
|
||||
Reference in New Issue
Block a user