Filet o Fish model (#34798)
* 690b01c3 seems ok * correct temporal * push * inplace * bs * what thw * is this wrong * frames are skipped * new models * simplify decimation * clean up * clean up modelframe * need attr * lint * 0 * use all samples * this should break - Revert "use all samples" This reverts commit 6c0d7943ac5fbb7ae60af1a1201e2423e4c3c105. * add lc probs * Revert "this should break - Revert "use all samples"" This reverts commit ca38c54333555266a0d2c885c28af28941841431. * Reapply "this should break - Revert "use all samples"" This reverts commit a3f0246f209f85f06b9090d9492bfba32ed8cfed. * Revert "Reapply "this should break - Revert "use all samples""" This reverts commit 7fd3d2a191b688e5ef7b4dcc8f5379e900af10f8. * new fish * e07ce1de-bdea-463e-b5bc-a38ce8d43f4f/700 --------- Co-authored-by: Comma Device <device@comma.ai>
This commit is contained in:
@@ -14,8 +14,14 @@ class ModelConstants:
|
||||
|
||||
# model inputs constants
|
||||
MODEL_FREQ = 20
|
||||
HISTORY_FREQ = 5
|
||||
HISTORY_LEN_SECONDS = 5
|
||||
TEMPORAL_SKIP = MODEL_FREQ // HISTORY_FREQ
|
||||
FULL_HISTORY_BUFFER_LEN = MODEL_FREQ * HISTORY_LEN_SECONDS
|
||||
INPUT_HISTORY_BUFFER_LEN = HISTORY_FREQ * HISTORY_LEN_SECONDS
|
||||
|
||||
FEATURE_LEN = 512
|
||||
FULL_HISTORY_BUFFER_LEN = 100
|
||||
|
||||
DESIRE_LEN = 8
|
||||
TRAFFIC_CONVENTION_LEN = 2
|
||||
LAT_PLANNER_STATE_LEN = 4
|
||||
|
||||
@@ -56,16 +56,24 @@ class ModelState:
|
||||
prev_desire: np.ndarray # for tracking the rising edge of the pulse
|
||||
|
||||
def __init__(self, context: CLContext):
|
||||
self.frames = {'input_imgs': DrivingModelFrame(context), 'big_input_imgs': DrivingModelFrame(context)}
|
||||
self.frames = {
|
||||
'input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP),
|
||||
'big_input_imgs': DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP)
|
||||
}
|
||||
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
|
||||
|
||||
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
|
||||
self.full_desire = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32)
|
||||
self.full_prev_desired_curv = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
|
||||
self.temporal_idxs = slice(-1-(ModelConstants.TEMPORAL_SKIP*(ModelConstants.INPUT_HISTORY_BUFFER_LEN-1)), None, ModelConstants.TEMPORAL_SKIP)
|
||||
|
||||
# policy inputs
|
||||
self.numpy_inputs = {
|
||||
'desire': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
|
||||
'desire': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
|
||||
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32),
|
||||
'lateral_control_params': np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32),
|
||||
'prev_desired_curv': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32),
|
||||
'features_buffer': np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
|
||||
'prev_desired_curv': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32),
|
||||
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
|
||||
}
|
||||
|
||||
with open(VISION_METADATA_PATH, 'rb') as f:
|
||||
@@ -104,8 +112,9 @@ class ModelState:
|
||||
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0)
|
||||
self.prev_desire[:] = inputs['desire']
|
||||
|
||||
self.numpy_inputs['desire'][0,:-1] = self.numpy_inputs['desire'][0,1:]
|
||||
self.numpy_inputs['desire'][0,-1] = new_desire
|
||||
self.full_desire[0,:-1] = self.full_desire[0,1:]
|
||||
self.full_desire[0,-1] = new_desire
|
||||
self.numpy_inputs['desire'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
|
||||
|
||||
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
|
||||
self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params']
|
||||
@@ -128,15 +137,17 @@ class ModelState:
|
||||
self.vision_output = self.vision_run(**self.vision_inputs).numpy().flatten()
|
||||
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(self.vision_output, self.vision_output_slices))
|
||||
|
||||
self.numpy_inputs['features_buffer'][0,:-1] = self.numpy_inputs['features_buffer'][0,1:]
|
||||
self.numpy_inputs['features_buffer'][0,-1] = vision_outputs_dict['hidden_state'][0, :]
|
||||
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:]
|
||||
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :]
|
||||
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs]
|
||||
|
||||
self.policy_output = self.policy_run(**self.policy_inputs).numpy().flatten()
|
||||
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices))
|
||||
|
||||
# TODO model only uses last value now
|
||||
self.numpy_inputs['prev_desired_curv'][0,:-1] = self.numpy_inputs['prev_desired_curv'][0,1:]
|
||||
self.numpy_inputs['prev_desired_curv'][0,-1,:] = policy_outputs_dict['desired_curvature'][0, :]
|
||||
self.full_prev_desired_curv[0,:-1] = self.full_prev_desired_curv[0,1:]
|
||||
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :]
|
||||
self.numpy_inputs['prev_desired_curv'][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
|
||||
combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict}
|
||||
if SEND_RAW_PRED:
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
|
||||
#include "common/clutil.h"
|
||||
|
||||
DrivingModelFrame::DrivingModelFrame(cl_device_id device_id, cl_context context) : ModelFrame(device_id, context) {
|
||||
DrivingModelFrame::DrivingModelFrame(cl_device_id device_id, cl_context context, int _temporal_skip) : ModelFrame(device_id, context) {
|
||||
input_frames = std::make_unique<uint8_t[]>(buf_size);
|
||||
temporal_skip = _temporal_skip;
|
||||
input_frames_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, buf_size, NULL, &err));
|
||||
img_buffer_20hz_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, 2*frame_size_bytes, NULL, &err));
|
||||
region.origin = 1 * frame_size_bytes;
|
||||
img_buffer_20hz_cl = CL_CHECK_ERR(clCreateBuffer(context, CL_MEM_READ_WRITE, (temporal_skip+1)*frame_size_bytes, NULL, &err));
|
||||
region.origin = temporal_skip * frame_size_bytes;
|
||||
region.size = frame_size_bytes;
|
||||
last_img_cl = CL_CHECK_ERR(clCreateSubBuffer(img_buffer_20hz_cl, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err));
|
||||
|
||||
@@ -20,7 +21,7 @@ DrivingModelFrame::DrivingModelFrame(cl_device_id device_id, cl_context context)
|
||||
cl_mem* DrivingModelFrame::prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection) {
|
||||
run_transform(yuv_cl, MODEL_WIDTH, MODEL_HEIGHT, frame_width, frame_height, frame_stride, frame_uv_offset, projection);
|
||||
|
||||
for (int i = 0; i < 1; i++) {
|
||||
for (int i = 0; i < temporal_skip; i++) {
|
||||
CL_CHECK(clEnqueueCopyBuffer(q, img_buffer_20hz_cl, img_buffer_20hz_cl, (i+1)*frame_size_bytes, i*frame_size_bytes, frame_size_bytes, 0, nullptr, nullptr));
|
||||
}
|
||||
loadyuv_queue(&loadyuv, q, y_cl, u_cl, v_cl, last_img_cl);
|
||||
|
||||
@@ -64,20 +64,21 @@ protected:
|
||||
|
||||
class DrivingModelFrame : public ModelFrame {
|
||||
public:
|
||||
DrivingModelFrame(cl_device_id device_id, cl_context context);
|
||||
DrivingModelFrame(cl_device_id device_id, cl_context context, int _temporal_skip);
|
||||
~DrivingModelFrame();
|
||||
cl_mem* prepare(cl_mem yuv_cl, int frame_width, int frame_height, int frame_stride, int frame_uv_offset, const mat3& projection);
|
||||
|
||||
const int MODEL_WIDTH = 512;
|
||||
const int MODEL_HEIGHT = 256;
|
||||
const int MODEL_FRAME_SIZE = MODEL_WIDTH * MODEL_HEIGHT * 3 / 2;
|
||||
const int buf_size = MODEL_FRAME_SIZE * 2;
|
||||
const int buf_size = MODEL_FRAME_SIZE * 2; // 2 frames are temporal_skip frames apart
|
||||
const size_t frame_size_bytes = MODEL_FRAME_SIZE * sizeof(uint8_t);
|
||||
|
||||
private:
|
||||
LoadYUVState loadyuv;
|
||||
cl_mem img_buffer_20hz_cl, last_img_cl, input_frames_cl;
|
||||
cl_buffer_region region;
|
||||
int temporal_skip;
|
||||
};
|
||||
|
||||
class MonitoringModelFrame : public ModelFrame {
|
||||
|
||||
@@ -20,7 +20,7 @@ cdef extern from "selfdrive/modeld/models/commonmodel.h":
|
||||
|
||||
cppclass DrivingModelFrame:
|
||||
int buf_size
|
||||
DrivingModelFrame(cl_device_id, cl_context)
|
||||
DrivingModelFrame(cl_device_id, cl_context, int)
|
||||
|
||||
cppclass MonitoringModelFrame:
|
||||
int buf_size
|
||||
|
||||
@@ -59,8 +59,8 @@ cdef class ModelFrame:
|
||||
cdef class DrivingModelFrame(ModelFrame):
|
||||
cdef cppDrivingModelFrame * _frame
|
||||
|
||||
def __cinit__(self, CLContext context):
|
||||
self._frame = new cppDrivingModelFrame(context.device_id, context.context)
|
||||
def __cinit__(self, CLContext context, int temporal_skip):
|
||||
self._frame = new cppDrivingModelFrame(context.device_id, context.context, temporal_skip)
|
||||
self.frame = <cppModelFrame*>(self._frame)
|
||||
self.buf_size = self._frame.buf_size
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user