@@ -106,7 +106,7 @@ class ModelState:
|
||||
|
||||
# policy inputs
|
||||
self.numpy_inputs = {
|
||||
'desire': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
|
||||
'desire_pulse': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
|
||||
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32),
|
||||
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
|
||||
}
|
||||
@@ -131,13 +131,13 @@ class ModelState:
|
||||
def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray],
|
||||
inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None:
|
||||
# Model decides when action is completed, so desire input is just a pulse triggered on rising edge
|
||||
inputs['desire'][0] = 0
|
||||
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0)
|
||||
self.prev_desire[:] = inputs['desire']
|
||||
inputs['desire_pulse'][0] = 0
|
||||
new_desire = np.where(inputs['desire_pulse'] - self.prev_desire > .99, inputs['desire_pulse'], 0)
|
||||
self.prev_desire[:] = inputs['desire_pulse']
|
||||
|
||||
self.full_desire[0,:-1] = self.full_desire[0,1:]
|
||||
self.full_desire[0,-1] = new_desire
|
||||
self.numpy_inputs['desire'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
|
||||
self.numpy_inputs['desire_pulse'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
|
||||
|
||||
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
|
||||
imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.vision_input_names}
|
||||
@@ -313,7 +313,7 @@ def main(demo=False):
|
||||
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names}
|
||||
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names}
|
||||
inputs:dict[str, np.ndarray] = {
|
||||
'desire': vec_desire,
|
||||
'desire_pulse': vec_desire,
|
||||
'traffic_convention': traffic_convention,
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user