Clean up model_publish args to simplify cython bindings (#29203)

* Clean up model_publish args to simplify cython bindings

* pass by reference

* Move FCW and model confidence queues into PublishState
This commit is contained in:
Mitchell Goff 2023-08-01 19:23:18 -07:00 committed by GitHub
parent feaad4ce42
commit 663fc0d8fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 30 deletions

View File

@ -63,6 +63,7 @@ void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcCl
SubMaster sm({"lateralPlan", "roadCameraState", "liveCalibration", "driverMonitoringState", "navModel"});
Params params;
PublishState ps = {};
// setup filter to track dropped frames
FirstOrderFilter frame_dropped_filter(0., 10., 1. / MODEL_FREQ);
@ -179,7 +180,7 @@ void run_model(ModelState &model, VisionIpcClient &vipc_client_main, VisionIpcCl
float model_execution_time = (mt2 - mt1) / 1000.0;
if (model_output != nullptr) {
model_publish(&model, pm, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio, *model_output, meta_main.timestamp_eof, timestamp_llk, model_execution_time,
model_publish(pm, meta_main.frame_id, meta_extra.frame_id, frame_id, frame_drop_ratio, *model_output, model, ps, meta_main.timestamp_eof, timestamp_llk, model_execution_time,
nav_enabled, live_calib_seen);
posenet_publish(pm, meta_main.frame_id, vipc_dropped_frames, *model_output, meta_main.timestamp_eof, live_calib_seen);
}

View File

@ -13,12 +13,6 @@
#include "common/timing.h"
#include "common/swaglog.h"
constexpr float FCW_THRESHOLD_5MS2_HIGH = 0.15;
constexpr float FCW_THRESHOLD_5MS2_LOW = 0.05;
constexpr float FCW_THRESHOLD_3MS2 = 0.7;
std::array<float, 5> prev_brake_5ms2_probs = {0,0,0,0,0};
std::array<float, 3> prev_brake_3ms2_probs = {0,0,0};
// #define DUMP_YUV
@ -152,7 +146,7 @@ void fill_lead(cereal::ModelDataV2::LeadDataV3::Builder lead, const ModelOutputL
lead.setAStd(to_kj_array_ptr(lead_a_std));
}
void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMeta &meta_data) {
void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMeta &meta_data, PublishState &ps) {
std::array<float, DESIRE_LEN> desire_state_softmax;
softmax(meta_data.desire_state_prob.array.data(), desire_state_softmax.data(), DESIRE_LEN);
@ -174,18 +168,18 @@ void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMet
//gas_pressed_sigmoid[i] = sigmoid(meta_data.disengage_prob[i].gas_pressed);
}
std::memmove(prev_brake_5ms2_probs.data(), &prev_brake_5ms2_probs[1], 4*sizeof(float));
std::memmove(prev_brake_3ms2_probs.data(), &prev_brake_3ms2_probs[1], 2*sizeof(float));
prev_brake_5ms2_probs[4] = brake_5ms2_sigmoid[0];
prev_brake_3ms2_probs[2] = brake_3ms2_sigmoid[0];
std::memmove(ps.prev_brake_5ms2_probs.data(), &ps.prev_brake_5ms2_probs[1], 4*sizeof(float));
std::memmove(ps.prev_brake_3ms2_probs.data(), &ps.prev_brake_3ms2_probs[1], 2*sizeof(float));
ps.prev_brake_5ms2_probs[4] = brake_5ms2_sigmoid[0];
ps.prev_brake_3ms2_probs[2] = brake_3ms2_sigmoid[0];
bool above_fcw_threshold = true;
for (int i=0; i<prev_brake_5ms2_probs.size(); i++) {
for (int i=0; i<ps.prev_brake_5ms2_probs.size(); i++) {
float threshold = i < 2 ? FCW_THRESHOLD_5MS2_LOW : FCW_THRESHOLD_5MS2_HIGH;
above_fcw_threshold = above_fcw_threshold && prev_brake_5ms2_probs[i] > threshold;
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_5ms2_probs[i] > threshold;
}
for (int i=0; i<prev_brake_3ms2_probs.size(); i++) {
above_fcw_threshold = above_fcw_threshold && prev_brake_3ms2_probs[i] > FCW_THRESHOLD_3MS2;
for (int i=0; i<ps.prev_brake_3ms2_probs.size(); i++) {
above_fcw_threshold = above_fcw_threshold && ps.prev_brake_3ms2_probs[i] > FCW_THRESHOLD_3MS2;
}
auto disengage = meta.initDisengagePredictions();
@ -203,7 +197,7 @@ void fill_meta(cereal::ModelDataV2::MetaData::Builder meta, const ModelOutputMet
meta.setHardBrakePredicted(above_fcw_threshold);
}
void fill_confidence(ModelState* s, cereal::ModelDataV2::Builder &framed) {
void fill_confidence(cereal::ModelDataV2::Builder &framed, PublishState &ps) {
if (framed.getFrameId() % (2*MODEL_FREQ) == 0) {
// update every 2s to match predictions interval
auto dbps = framed.getMeta().getDisengagePredictions().getBrakeDisengageProbs();
@ -223,13 +217,13 @@ void fill_confidence(ModelState* s, cereal::ModelDataV2::Builder &framed) {
}
// rolling buf for 2, 4, 6, 8, 10s
std::memmove(&s->disengage_buffer[0], &s->disengage_buffer[DISENGAGE_LEN], sizeof(float) * DISENGAGE_LEN * (DISENGAGE_LEN-1));
std::memcpy(&s->disengage_buffer[DISENGAGE_LEN * (DISENGAGE_LEN-1)], &dp_ind[0], sizeof(float) * DISENGAGE_LEN);
std::memmove(&ps.disengage_buffer[0], &ps.disengage_buffer[DISENGAGE_LEN], sizeof(float) * DISENGAGE_LEN * (DISENGAGE_LEN-1));
std::memcpy(&ps.disengage_buffer[DISENGAGE_LEN * (DISENGAGE_LEN-1)], &dp_ind[0], sizeof(float) * DISENGAGE_LEN);
}
float score = 0;
for (int i = 0; i < DISENGAGE_LEN; i++) {
score += s->disengage_buffer[i*DISENGAGE_LEN+DISENGAGE_LEN-1-i] / DISENGAGE_LEN;
score += ps.disengage_buffer[i*DISENGAGE_LEN+DISENGAGE_LEN-1-i] / DISENGAGE_LEN;
}
if (score < RYG_GREEN) {
@ -355,7 +349,7 @@ void fill_road_edges(cereal::ModelDataV2::Builder &framed, const std::array<floa
});
}
void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const ModelOutput &net_outputs) {
void fill_model(cereal::ModelDataV2::Builder &framed, const ModelOutput &net_outputs, PublishState &ps) {
const auto &best_plan = net_outputs.plans.get_best_prediction();
std::array<float, TRAJECTORY_SIZE> plan_t;
std::fill_n(plan_t.data(), plan_t.size(), NAN);
@ -383,10 +377,10 @@ void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const Model
fill_road_edges(framed, plan_t, net_outputs.road_edges);
// meta
fill_meta(framed.initMeta(), net_outputs.meta);
fill_meta(framed.initMeta(), net_outputs.meta, ps);
// confidence
fill_confidence(s, framed);
fill_confidence(framed, ps);
// leads
auto leads = framed.initLeadsV3(LEAD_MHP_SELECTION);
@ -407,8 +401,8 @@ void fill_model(ModelState* s, cereal::ModelDataV2::Builder &framed, const Model
temporal_pose.setRotStd({exp(r_std.x), exp(r_std.y), exp(r_std.z)});
}
void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, uint64_t timestamp_eof, uint64_t timestamp_llk,
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk,
float model_execution_time, const bool nav_enabled, const bool valid) {
const uint32_t frame_age = (frame_id > vipc_frame_id) ? (frame_id - vipc_frame_id) : 0;
MessageBuilder msg;
@ -422,9 +416,9 @@ void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_
framed.setModelExecutionTime(model_execution_time);
framed.setNavEnabled(nav_enabled);
if (send_raw_pred) {
framed.setRawPredictions((kj::ArrayPtr<const float>(s->output.data(), s->output.size())).asBytes());
framed.setRawPredictions((kj::ArrayPtr<const float>(s.output.data(), s.output.size())).asBytes());
}
fill_model(s, framed, net_outputs);
fill_model(framed, net_outputs, ps);
pm.send("modelV2", msg);
}

View File

@ -39,6 +39,10 @@ constexpr int LEAD_MHP_SELECTION = 3;
// Padding to get output shape as multiple of 4
constexpr int PAD_SIZE = 2;
constexpr float FCW_THRESHOLD_5MS2_HIGH = 0.15;
constexpr float FCW_THRESHOLD_5MS2_LOW = 0.05;
constexpr float FCW_THRESHOLD_3MS2 = 0.7;
struct ModelOutputXYZ {
float x;
float y;
@ -262,7 +266,6 @@ struct ModelState {
ModelFrame *frame = nullptr;
ModelFrame *wide_frame = nullptr;
std::array<float, HISTORY_BUFFER_LEN * FEATURE_LEN> feature_buffer = {};
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {};
std::array<float, NET_OUTPUT_SIZE> output = {};
std::unique_ptr<RunModel> m;
#ifdef DESIRE
@ -280,12 +283,18 @@ struct ModelState {
#endif
};
struct PublishState {
std::array<float, DISENGAGE_LEN * DISENGAGE_LEN> disengage_buffer = {};
std::array<float, 5> prev_brake_5ms2_probs = {};
std::array<float, 3> prev_brake_3ms2_probs = {};
};
void model_init(ModelState* s, cl_device_id device_id, cl_context context);
ModelOutput *model_eval_frame(ModelState* s, VisionBuf* buf, VisionBuf* buf_wide,
const mat3 &transform, const mat3 &transform_wide, float *desire_in, bool is_rhd, float *driving_style, float *nav_features, bool prepare_only);
void model_free(ModelState* s);
void model_publish(ModelState* s, PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, uint64_t timestamp_eof, uint64_t timestamp_llk,
void model_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_frame_id_extra, uint32_t frame_id, float frame_drop,
const ModelOutput &net_outputs, ModelState &s, PublishState &ps, uint64_t timestamp_eof, uint64_t timestamp_llk,
float model_execution_time, const bool nav_enabled, const bool valid);
void posenet_publish(PubMaster &pm, uint32_t vipc_frame_id, uint32_t vipc_dropped_frames,
const ModelOutput &net_outputs, uint64_t timestamp_eof, const bool valid);