mirror of https://github.com/commaai/rednose.git
83 lines
2.3 KiB
Python
83 lines
2.3 KiB
Python
import os
|
|
import numpy as np
|
|
import unittest
|
|
|
|
from kinematic_kf import KinematicKalman, ObservationKind, States # pylint: disable=import-error
|
|
|
|
# Absolute path of the 'generated' directory that sits next to this test
# module; it is passed to KinematicKalman when the filter is constructed.
GENERATED_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), 'generated')
)
|
|
|
|
class TestKinematic(unittest.TestCase):
    """Regression test for the ``KinematicKalman`` filter.

    A scalar position is simulated by integrating the velocity
    ``sin(5 * t)`` with a fixed time step. At every step a noisy position
    measurement is handed to the filter, and the filter's final position /
    velocity estimates and their standard deviations are compared against
    pinned reference values.
    """

    @staticmethod
    def _scalar(value):
        """Return the single element of *value* as a Python ``float``.

        Accepts a plain scalar as well as a size-1 array of any rank.
        Calling ``float()`` directly on an array with ``ndim > 0`` is
        deprecated since NumPy 1.25, so the element is extracted through
        ``.item()`` first.

        NOTE(review): whether indexing with ``States.*`` yields a scalar or
        a size-1 array depends on how ``States`` is defined, which is not
        visible from this file; this helper works for both.
        """
        return float(np.asarray(value).item())

    @staticmethod
    def _plot(ts, xs, xs_meas, vs, xs_kf, vs_kf, xs_kf_std, vs_kf_std):
        """Show simulation, measurements and filter output in two subplots.

        The top subplot holds position, the bottom one velocity; the shaded
        band around each filter curve is +/- one standard deviation.
        ``plt.show()`` is called at the end.
        """
        # Imported lazily so matplotlib is only needed when plotting is
        # explicitly requested.
        import matplotlib.pyplot as plt  # pylint: disable=import-error

        plt.figure()

        plt.subplot(2, 1, 1)
        plt.plot(ts, xs, 'k', label='Simulation')
        plt.plot(ts, xs_meas, 'k.', label='Measurements')
        plt.plot(ts, xs_kf, label='KF')
        ax = plt.gca()
        ax.fill_between(ts, xs_kf - xs_kf_std, xs_kf + xs_kf_std, alpha=.2, color='C0')

        plt.xlabel("Time [s]")
        plt.ylabel("Position [m]")
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(ts, vs, 'k', label='Simulation')
        plt.plot(ts, vs_kf, label='KF')

        ax = plt.gca()
        ax.fill_between(ts, vs_kf - vs_kf_std, vs_kf + vs_kf_std, alpha=.2, color='C0')

        plt.xlabel("Time [s]")
        plt.ylabel("Velocity [m/s]")
        plt.legend()

        plt.show()

    def test_kinematic_kf(self):
        """Run the filter on simulated data and check its final estimates.

        Only position observations are fed to the filter. Set the ``PLOT``
        environment variable (any value) to additionally display the
        simulated and estimated trajectories with matplotlib.
        """
        # Fixed seed: the pinned expected values below depend on the exact
        # noise sequence drawn from the global NumPy RNG.
        np.random.seed(0)

        kf = KinematicKalman(GENERATED_DIR)

        # Simple simulation: time stamps 0 <= t < 5 with step dt, and the
        # true velocity at each time stamp.
        dt = 0.01
        ts = np.arange(0, 5, step=dt)
        vs = np.sin(ts * 5)

        x = 0.0           # true position, integrated from the velocity
        xs = []           # true position per step
        xs_meas = []      # noisy position measurement per step
        xs_kf = []        # filter position estimate per step
        vs_kf = []        # filter velocity estimate per step
        xs_kf_std = []    # sqrt of the position variance reported by the filter
        vs_kf_std = []    # sqrt of the velocity variance reported by the filter

        for t, v in zip(ts, vs):
            xs.append(x)

            # Update kf with a position measurement drawn from a normal
            # distribution centred on the true position (scale 0.1).
            meas = np.random.normal(x, 0.1)
            xs_meas.append(meas)
            kf.predict_and_observe(t, ObservationKind.POSITION, [meas])

            # Retrieve kf values
            state = kf.x
            xs_kf.append(self._scalar(state[States.POSITION]))
            vs_kf.append(self._scalar(state[States.VELOCITY]))

            # Index the covariance entry first and take the square root
            # afterwards: only the two diagonal terms are needed, and a
            # square root over the whole matrix would also touch the
            # off-diagonal covariances, which may be negative and would then
            # yield NaN together with a RuntimeWarning.
            cov = np.asarray(kf.P)
            xs_kf_std.append(self._scalar(np.sqrt(cov[States.POSITION, States.POSITION])))
            vs_kf_std.append(self._scalar(np.sqrt(cov[States.VELOCITY, States.VELOCITY])))

            # Update simulation: advance the true position by one step using
            # the current velocity.
            x += v * dt

        xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std = [
            np.asarray(a) for a in (xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std)
        ]

        # Pinned reference values for the final step (presumably recorded
        # from a known-good run with seed 0 -- confirm before changing).
        self.assertAlmostEqual(xs_kf[-1], -0.010866289677966417)
        self.assertAlmostEqual(xs_kf_std[-1], 0.04477103863330089)
        self.assertAlmostEqual(vs_kf[-1], -0.8553720537261753)
        self.assertAlmostEqual(vs_kf_std[-1], 0.6695762270974388)

        if "PLOT" in os.environ:
            self._plot(ts, xs, xs_meas, vs, xs_kf, vs_kf, xs_kf_std, vs_kf_std)