mirror of
https://github.com/dragonpilot/dragonpilot.git
synced 2026-02-24 06:03:54 +08:00
* 1456d261-d232-4654-8885-4d9fde883894/440 6b7d7cec-ead8-40f3-86cc-86d52c9b03fe/300 * compute only 9 tokens: 1456d261-d232-4654-8885-4d9fde883894/440 6b7d7cec-ead8-40f3-86cc-86d52c9b03fe/300 * tinygrad: cleanup gather * 1456d261-d232-4654-8885-4d9fde883894/440 6b7d7cec-ead8-40f3-86cc-86d52c9b03fe/700 * empty commit for tests * bump tinygrad * dont use tinygrad matmul for now * bump tinygrad * 1456d261-d232-4654-8885-4d9fde883894/440 e63ab895-2222-4abd-a9a5-af86bb70e260/700 * float16 1456d261-d232-4654-8885-4d9fde883894/440 e63ab895-2222-4abd-a9a5-af86bb70e260/700 * increase steer rate cost * Revert "increase steer rate cost" This reverts commit 74ce9ab9be7ef17ecfec931f96851b12f37f2336. * fork tinygrad * empty commit for tests * basics * Kinda works * new lat * new tuning * Move LATMPCN so scons compiles * Update long weights * Add tinygrad optim * Update model ref * update weights * Update ref * Try * Error message for field ignore * update model regf * ref commit * Fix onnx test Co-authored-by: Yassine Yousfi <yyousfi1@binghamton.edu>
89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
import unittest
|
|
import numpy as np
|
|
from selfdrive.controls.lib.lateral_mpc_lib.lat_mpc import LateralMpc
|
|
from selfdrive.controls.lib.drive_helpers import CAR_ROTATION_RADIUS
|
|
from selfdrive.controls.lib.lateral_mpc_lib.lat_mpc import N as LAT_MPC_N
|
|
|
|
|
|
def run_mpc(lat_mpc=None, v_ref=30., x_init=0., y_init=0., psi_init=0., curvature_init=0.,
            lane_width=3.6, poly_shift=0.):
  """Solve the lateral MPC toward a constant-offset y reference and return its state solution.

  Args:
    lat_mpc: solver instance to (re)use; a new LateralMpc is built when None.
      Its weights are reset by every call, even for a reused instance.
    v_ref: first entry of the parameter vector handed to run()
      (presumably the vehicle speed — confirm in LateralMpc).
    x_init, y_init, psi_init, curvature_init: initial state vector, in this order.
    lane_width: accepted for call compatibility; not used.
    poly_shift: value assigned to every point of the y reference.

  Returns:
    lat_mpc.x_sol as it stands after the final run().
  """
  mpc = lat_mpc if lat_mpc is not None else LateralMpc()
  # NOTE(review): the meaning/order of these four weights is defined by
  # LateralMpc.set_weights — confirm there before changing them.
  mpc.set_weights(1., 1., 0.0, 1.)

  n_pts = LAT_MPC_N + 1
  y_pts = poly_shift * np.ones(n_pts)  # constant lateral reference
  heading_pts = np.zeros(n_pts)        # zero heading reference
  curv_rate_pts = np.zeros(n_pts)      # zero curvature-rate reference

  state0 = np.array([x_init, y_init, psi_init, curvature_init])
  params = np.array([v_ref, CAR_ROTATION_RADIUS])

  # Repeat the solve with identical inputs; presumably each run() refines the
  # previous solution, and 10 passes suffice for these cases to converge.
  for _ in range(10):
    mpc.run(state0, params, y_pts, heading_pts, curv_rate_pts)
  return mpc.x_sol
|
|
|
|
|
|
class TestLateralMpc(unittest.TestCase):
  """Behavioral checks for LateralMpc, driven through run_mpc().

  Each run_mpc() result is indexed as sol[step, state]. Judging by the initial
  state built in run_mpc, the state columns are presumably (x, y, psi, curvature).
  """

  def _assert_null(self, sol, curvature=1e-6):
    """Assert that trajectory sol[0] keeps columns 1..3 at ~0 over the whole horizon.

    Args:
      sol: array-like indexed as sol[trajectory, step, state]; only trajectory 0 is checked.
      curvature: absolute tolerance applied to every comparison.
    """
    sol = np.asarray(sol)
    # Iterate over the horizon (axis 1). len(sol) is the number of stacked
    # trajectories (1 here), which would check only the first step.
    for i in range(sol.shape[1]):
      for j in (1, 2, 3):
        self.assertAlmostEqual(sol[0, i, j], 0., delta=curvature)

  def _assert_symmetry(self, sol, curvature=1e-6):
    """Assert that trajectories sol[0] and sol[1] mirror each other at every horizon step.

    Columns 1..3 must be negatives of each other; column 0 must be equal.

    Args:
      sol: array-like indexed as sol[trajectory, step, state] with two trajectories.
      curvature: absolute tolerance applied to every comparison.
    """
    sol = np.asarray(sol)
    # Iterate over the horizon (axis 1), not over the two stacked trajectories.
    for i in range(sol.shape[1]):
      for j in (1, 2, 3):
        self.assertAlmostEqual(sol[0, i, j], -sol[1, i, j], delta=curvature)
      self.assertAlmostEqual(sol[0, i, 0], sol[1, i, 0], delta=curvature)

  def test_straight(self):
    """All-zero initial state and zero reference keep columns 1..3 at zero."""
    sol = run_mpc()
    self._assert_null(np.array([sol]))

  def test_y_symmetry(self):
    """Opposite initial y offsets yield mirrored solutions."""
    sol = [run_mpc(y_init=y_init) for y_init in (-0.5, 0.5)]
    self._assert_symmetry(np.array(sol))

  def test_poly_symmetry(self):
    """Opposite constant y references yield mirrored solutions."""
    sol = [run_mpc(poly_shift=poly_shift) for poly_shift in (-1., 1.)]
    self._assert_symmetry(np.array(sol))

  def test_curvature_symmetry(self):
    """Opposite initial curvatures yield mirrored solutions."""
    sol = [run_mpc(curvature_init=curvature_init) for curvature_init in (-0.1, 0.1)]
    self._assert_symmetry(np.array(sol))

  def test_psi_symmetry(self):
    """Opposite initial headings yield mirrored solutions."""
    sol = [run_mpc(psi_init=psi_init) for psi_init in (-0.1, 0.1)]
    self._assert_symmetry(np.array(sol))

  def test_no_overshoot(self):
    """Starting from y_init with a zero y reference, |y| never exceeds y_init."""
    y_init = 1.
    sol = run_mpc(y_init=y_init)
    for y in sol[:, 1]:
      self.assertGreaterEqual(y_init, abs(y))

  def test_switch_convergence(self):
    """Reusing one solver for +30 and then -30 poly_shift gives mirrored psi (column 2) profiles."""
    lat_mpc = LateralMpc()
    sol = run_mpc(lat_mpc=lat_mpc, poly_shift=30.0, v_ref=7.0)
    right_psi_deg = np.degrees(sol[:, 2])
    sol = run_mpc(lat_mpc=lat_mpc, poly_shift=-30.0, v_ref=7.0)
    left_psi_deg = np.degrees(sol[:, 2])
    np.testing.assert_almost_equal(right_psi_deg, -left_psi_deg, decimal=3)
|
|
|
|
|
|
if __name__ == "__main__":
  # Executed directly rather than imported: load and run this module's TestCases.
  unittest.main(module="__main__")
|