mirror of https://github.com/commaai/openpilot.git
Clean up model_publish args to simplify cython bindings (#29203)
* Clean up model_publish args to simplify cython bindings * pass by reference * Move FCW and model confidence queues into PublishState
This commit is contained in:
parent
feaad4ce42
commit
663fc0d8fe
|
@ -63,6 +63,7 @@ void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcCl
|
|||
SubMaster sm({"lateralPlan", "roadCameraState", "liveCalibration", "driverMonitoringState", "navModel"});
|
||||
|
||||
Params params;
|
||||
PublishState ps = {};
|
||||
|
||||
// setup filter to track dropped frames
|
||||
FirstOrderFilter frame_dropped_filter(0., 10., 1. / MODEL_FREQ);
|
||||
|
@ -179,7 +180,7 @@ void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcCl
|
|||
float model_execution_time = (mt2 - mt1) / 1000.0;
|
||||
|
||||
if (model_output != nullptr) {
|
||||
model_publish(&model, pm, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio, *model_output, meta_main.timestamp_eof, timestamp_llk, model_execution_time,
|
||||
model_publish(pm, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio, *model_output, model, ps, meta_main.timestamp_eof, timestamp_llk, model_execution_time,
|
||||
nav_enabled, live_calib_seen);
|
||||
posenet_publish(pm, meta_main.frame_id, vipc_dropped_frames, *model_output, meta_main.timestamp_eof, live_calib_seen);
|
||||
}
|
||||
|
|
|
@ -13,12 +13,6 @@
|
|||
#include "common/timing.h"
|
||||
#include "common/swaglog.h"
|
||||
|
||||
constexpr float FCW_THRESHOLD_5MS2_HIGH = 0.15;
|
||||
constexpr float FCW_THRESHOLD_5MS2_LOW = 0.05;
|
||||
constexpr float FCW_THRESHOLD_3MS2 = 0.7;
|
||||
|
||||
std::array<float, 5> prev_brake_5ms2_probs = {0,0,0,0,0};
|
||||
std::array<float, 3> prev_brake_3ms2_probs = {0,0,0};
|
||||
|
||||
// #define DUMP_YUV
|
||||
|
||||
|
@ -152,7 +146,7 @@ void fill_lead(cereal::ModelDataV2::LeadDataV3::Builder lead, const ModelOutputL
|
|||
lead.setAStd(to_kj_array_ptr(lead_a_std));
|
||||
}
|
||||
|
||||
void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMeta &meta_data) {
|
||||
void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMeta &meta_data, PublishState &ps) {
|
||||
std::array<float, DESIRE_LEN> desire_state_softmax;
|
||||
softmax(meta_data.desire_state_prob.array.data(), desire_state_softmax.data(), DESIRE_LEN);
|
||||
|
||||
|
@ -174,18 +168,18 @@ void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMet
|
|||
//gas_pressed_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].gas_pressed);
|
||||
}
|
||||
|
||||
std::memmove(prev_brake_5ms2_probs.data(), &prev_brake_5ms2_probs[1], 4*sizeof(float));
|
||||
std::memmove(prev_brake_3ms2_probs.data(), &prev_brake_3ms2_probs[1], 2*sizeof(float));
|
||||
prev_brake_5ms2_probs[4] = brake_5ms2_sigmoid[0];
|
||||
prev_brake_3ms2_probs[2] = brake_3ms2_sigmoid[0];
|
||||
std::memmove(ps.prev_brake_5ms2_probs.data(), &ps.prev_brake_5ms2_probs[1], 4*sizeof(float));
|
||||
std::memmove(ps.prev_brake_3ms2_probs.data(), &ps.prev_brake_3ms2_probs[1], 2*sizeof(float));
|
||||
ps.prev_brake_5ms2_probs[4] = brake_5ms2_sigmoid[0];
|
||||
ps.prev_brake_3ms2_probs[2] = brake_3ms2_sigmoid[0];
|
||||
|
||||
bool above_fcw_threshold = true;
|
||||
for (int i=0; i<prev_brake_5ms2_probs.size(); i++) {
|
||||
for (int i=0; i<ps.prev_brake_5ms2_probs.size(); i++) {
|
||||
float threshold = i < 2 ? FCW_THRESHOLD_5MS2_LOW : FCW_THRESHOLD_5MS2_HIGH;
|
||||
above_fcw_threshold = above_fcw_threshold && prev_brake_5ms2_probs[i] > threshold;
|
||||
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_5ms2_probs[i] > threshold;
|
||||
}
|
||||
for (int i=0; i<prev_brake_3ms2_probs.size(); i++) {
|
||||
above_fcw_threshold = above_fcw_threshold && prev_brake_3ms2_probs[i] > FCW_THRESHOLD_3MS2;
|
||||
for (int i=0; i<ps.prev_brake_3ms2_probs.size(); i++) {
|
||||
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_3ms2_probs[i] > FCW_THRESHOLD_3MS2;
|
||||
}
|
||||
|
||||
auto disengage = meta.initDisengagePredictions();
|
||||
|
@ -203,7 +197,7 @@ void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMet
|
|||
meta.setHardBrakePredicted(above_fcw_threshold);
|
||||
}
|
||||
|
||||
void fill_confidence(ModelState* s, cereal::ModelDataV2::Builder &framed) {
|
||||
void fill_confidence(cereal::ModelDataV2::Builder &framed, PublishState &ps) {
|
||||
if (framed.getFrameId() % (2*MODEL_FREQ) == 0) {
|
||||
// update every 2s to match predictions interval
|
||||
auto dbps = framed.getMeta().getDisengagePredictions().getBrakeDisengageProbs();
|
||||
|
@ -223,13 +217,13 @@ void fill_confidence(ModelState* s, cereal::ModelDataV2::Builder &framed) {
|
|||
}
|
||||
|
||||
// rolling buf for 2, 4, 6, 8, 10s
|
||||
std::memmove(&s->disengage_buffer[0], &s->disengage_buffer[DISENGAGE_LEN], sizeof(float) * DISENGAGE_LEN * (DISENGAGE_LEN-1));
|
||||
std::memcpy(&s->disengage_buffer[DISENGAGE_LEN * (DISENGAGE_LEN-1)], &dp_ind[0], sizeof(float) * DISENGAGE_LEN);
|
||||
std::memmove(&ps.disengage_buffer[0], &ps.disengage_buffer[DISENGAGE_LEN], sizeof(float) * DISENGAGE_LEN * (DISENGAGE_LEN-1));
|
||||
std::memcpy(&ps.disengage_buffer[DISENGAGE_LEN * (DISENGAGE_LEN-1)], &dp_ind[0], sizeof(float) * DISENGAGE_LEN);
|
||||
}
|
||||
|
||||
float score = 0;
|
||||
for (int i = 0; i < DISENGAGE_LEN; i++) {
|
||||
score += s->disengage_buffer[i*DISENGAGE_LEN+DISENGAGE_LEN-1-i] / DISENGAGE_LEN;
|
||||
score += ps.disengage_buffer[i*DISENGAGE_LEN+DISENGAGE_LEN-1-i] / DISENGAGE_LEN;
|
||||
}
|
||||
|
||||
if (score < RYG_GREEN) {
|
||||
|
@ -355,7 +349,7 @@ void fill_road_edges(cereal::ModelDataV2::Builder &framed, const std::array<floa
|
|||
});
|
||||
}
|
||||
|
||||
void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const ModelOutput &net_outputs) {
|
||||
void fill_model(cereal::ModelDataV2::Builder &framed, const ModelOutput &net_outputs, PublishState &ps) {
|
||||
const auto &best_plan = net_outputs.plans.get_best_prediction();
|
||||
std::array<float, TRAJECTORY_SIZE> plan_t;
|
||||
std::fill_n(plan_t.data(), plan_t.size(), NAN);
|
||||
|
@ -383,10 +377,10 @@ void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const Model
|
|||
fill_road_edges(framed, plan_t, net_outputs.road_edges);
|
||||
|
||||
// meta
|
||||
fill_meta(framed.initMeta(), net_outputs.meta);
|
||||
fill_meta(framed.initMeta(), net_outputs.meta, ps);
|
||||
|
||||
// confidence
|
||||
fill_confidence(s, framed);
|
||||
fill_confidence(framed, ps);
|
||||
|
||||
// leads
|
||||
auto leads = framed.initLeadsV3(LEAD_MHP_SELECTION);
|
||||
|
@ -407,8 +401,8 @@ void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const Model
|
|||
temporal_pose.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)});
|
||||
}
|
||||
|
||||
void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
|
||||
const ModelOutput &net_outputs, uint64_t timestamp_eof, uint64_t timestamp_llk,
|
||||
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
|
||||
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk,
|
||||
float model_execution_time, const bool nav_enabled, const bool valid) {
|
||||
const uint32_t frame_age = (frame_id > vipc_frame_id) ? (frame_id - vipc_frame_id) : 0;
|
||||
MessageBuilder msg;
|
||||
|
@ -422,9 +416,9 @@ void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_
|
|||
framed.setModelExecutionTime(model_execution_time);
|
||||
framed.setNavEnabled(nav_enabled);
|
||||
if (send_raw_pred) {
|
||||
framed.setRawPredictions((kj::ArrayPtr<const float>(s->output.data(), s->output.size())).asBytes());
|
||||
framed.setRawPredictions((kj::ArrayPtr<const float>(s.output.data(), s.output.size())).asBytes());
|
||||
}
|
||||
fill_model(s, framed, net_outputs);
|
||||
fill_model(framed, net_outputs, ps);
|
||||
pm.send("modelV2", msg);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,10 @@ constexpr int LEAD_MHP_SELECTION = 3;
|
|||
// Padding to get output shape as multiple of 4
|
||||
constexpr int PAD_SIZE = 2;
|
||||
|
||||
constexpr float FCW_THRESHOLD_5MS2_HIGH = 0.15;
|
||||
constexpr float FCW_THRESHOLD_5MS2_LOW = 0.05;
|
||||
constexpr float FCW_THRESHOLD_3MS2 = 0.7;
|
||||
|
||||
struct ModelOutputXYZ {
|
||||
float x;
|
||||
float y;
|
||||
|
@ -262,7 +266,6 @@ struct ModelState {
|
|||
ModelFrame *frame = nullptr;
|
||||
ModelFrame *wide_frame = nullptr;
|
||||
std::array<float, HISTORY_BUFFER_LEN * FEATURE_LEN> feature_buffer = {};
|
||||
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {};
|
||||
std::array<float, NET_OUTPUT_SIZE> output = {};
|
||||
std::unique_ptr<RunModel> m;
|
||||
#ifdef DESIRE
|
||||
|
@ -280,12 +283,18 @@ struct ModelState {
|
|||
#endif
|
||||
};
|
||||
|
||||
struct PublishState {
|
||||
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {};
|
||||
std::array<float, 5> prev_brake_5ms2_probs = {};
|
||||
std::array<float, 3> prev_brake_3ms2_probs = {};
|
||||
};
|
||||
|
||||
void model_init(ModelState* s, cl_device_id device_id, cl_context context);
|
||||
ModelOutput *model_eval_frame(ModelState* s, VisionBuf* buf, VisionBuf* buf_wide,
|
||||
const mat3 &transform, const mat3 &transform_wide, float *desire_in, bool is_rhd, float *driving_style, float *nav_features, bool prepare_only);
|
||||
void model_free(ModelState* s);
|
||||
void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
|
||||
const ModelOutput &net_outputs, uint64_t timestamp_eof, uint64_t timestamp_llk,
|
||||
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
|
||||
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk,
|
||||
float model_execution_time, const bool nav_enabled, const bool valid);
|
||||
void posenet_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames,
|
||||
const ModelOutput &net_outputs, uint64_t timestamp_eof, const bool valid);
|
||||
|
|
Loading…
Reference in New Issue