diff --git a/selfdrive/modeld/SConscript b/selfdrive/modeld/SConscript index 40b30b8c1bc4d0..80b764b27c20ce 100644 --- a/selfdrive/modeld/SConscript +++ b/selfdrive/modeld/SConscript @@ -56,7 +56,7 @@ compiled_flags_node = lenv.Command( mac_brew_string = f'HOME={os.path.expanduser("~")}' if arch == 'Darwin' else '' # Get model metadata -for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']: +for model_name in ['driving_vision', 'driving_off_policy', 'driving_on_policy', 'dmonitoring_model']: fn = File(f"models/{model_name}").abspath script_files = [File(Dir("#selfdrive/modeld").File("get_model_metadata.py").abspath)] cmd = f'{tg_flags} {mac_brew_string} python3 {Dir("#selfdrive/modeld").abspath}/get_model_metadata.py {fn}.onnx' @@ -65,8 +65,8 @@ for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']: modeld_dir = Dir("#selfdrive/modeld").abspath compile_modeld_script = [File(f"{modeld_dir}/compile_modeld.py")] compile_dm_warp_script = [File(f"{modeld_dir}/compile_dm_warp.py")] -driving_onnx_deps = [File(f"models/{m}.onnx").abspath for m in ['driving_vision', 'driving_policy']] -driving_metadata_deps = [File(f"models/{m}_metadata.pkl").abspath for m in ['driving_vision', 'driving_policy']] +driving_onnx_deps = [File(f"models/{m}.onnx").abspath for m in ['driving_vision', 'driving_off_policy', 'driving_on_policy']] +driving_metadata_deps = [File(f"models/{m}_metadata.pkl").abspath for m in ['driving_vision', 'driving_off_policy', 'driving_on_policy']] model_w, model_h = MEDMODEL_INPUT_SIZE frame_skip = ModelConstants.MODEL_RUN_FREQ // ModelConstants.MODEL_CONTEXT_FREQ @@ -75,7 +75,8 @@ for cfg in MODELD_CONFIGS: f'--model-size {model_w}x{model_h} ' f'--nv12 {",".join(str(x) for x in cfg.nv12)} ' f'--vision-onnx {File("models/driving_vision.onnx").abspath} ' - f'--policy-onnx {File("models/driving_policy.onnx").abspath} ' + f'--off-policy-onnx {File("models/driving_off_policy.onnx").abspath} ' + f'--on-policy-onnx {File("models/driving_on_policy.onnx").abspath} ' f'--output {cfg.pkl_path} --frame-skip {frame_skip}' + (' --prepare-only' if cfg.prepare_only else '')) node = lenv.Command(cfg.pkl_path, tinygrad_files + compile_modeld_script + driving_onnx_deps + driving_metadata_deps + [chunker_file, compiled_flags_node], cmd) diff --git a/selfdrive/modeld/compile_modeld.py b/selfdrive/modeld/compile_modeld.py index 61de986d564f41..d7e5efaac1b070 100755 --- a/selfdrive/modeld/compile_modeld.py +++ b/selfdrive/modeld/compile_modeld.py @@ -94,6 +94,10 @@ def make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip): 'tfm': np.zeros((3, 3), dtype=np.float32), 'big_tfm': np.zeros((3, 3), dtype=np.float32), } + if 'action_t' in policy_input_shapes: + npy['action_t'] = np.zeros(policy_input_shapes['action_t'], dtype=np.float32) + if 'prev_action' in policy_input_shapes: + npy['prev_action'] = np.zeros(policy_input_shapes['prev_action'][2], dtype=np.float32) input_queues = { 'img_q': Tensor.zeros(img_buf_shape, dtype='uint8').contiguous().realize(), 'big_img_q': Tensor.zeros(img_buf_shape, dtype='uint8').contiguous().realize(), @@ -101,6 +105,9 @@ def make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip): 'desire_q': Tensor.zeros(frame_skip * dp[1], dp[0], dp[2]).contiguous().realize(), **{k: Tensor(v, device='NPY').realize() for k, v in npy.items()}, } + if 'prev_action' in policy_input_shapes: + pa = policy_input_shapes['prev_action'] # (1, 25, 2) + input_queues['prev_action_q'] = Tensor.zeros(frame_skip * (pa[1] - 1) + 1, pa[0], pa[2]).contiguous().realize() return input_queues, npy @@ -117,18 +124,24 @@ def sample_desire(buf, frame_skip): return buf.reshape(-1, frame_skip, *buf.shape[1:]).max(1).flatten(0, 1).unsqueeze(0) -def make_run_policy(vision_runner, policy_runner, nv12: NV12Frame, model_w, model_h, +def make_run_policy(vision_runner, off_policy_runner, on_policy_runner, nv12: NV12Frame, model_w, model_h, vision_features_slice, frame_skip, prepare_only=False): frame_prepare = make_frame_prepare(nv12, model_w, model_h) sample_skip_fn = partial(sample_skip, frame_skip=frame_skip) sample_desire_fn = partial(sample_desire, frame_skip=frame_skip) - def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, tfm, big_tfm, frame, big_frame): + def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, action_t, tfm, big_tfm, frame, big_frame, + prev_action_q=None, prev_action=None): tfm = tfm.to(Device.DEFAULT) big_tfm = big_tfm.to(Device.DEFAULT) desire = desire.to(Device.DEFAULT) traffic_convention = traffic_convention.to(Device.DEFAULT) - Tensor.realize(tfm, big_tfm, desire, traffic_convention) + action_t = action_t.to(Device.DEFAULT) + to_realize = [tfm, big_tfm, desire, traffic_convention, action_t] + if prev_action is not None: + prev_action = prev_action.to(Device.DEFAULT) + to_realize.append(prev_action) + Tensor.realize(*to_realize) img = shift_and_sample(img_q, frame_prepare(frame, tfm).unsqueeze(0), sample_skip_fn) big_img = shift_and_sample(big_img_q, frame_prepare(big_frame, big_tfm).unsqueeze(0), sample_skip_fn) @@ -142,30 +155,42 @@ def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, t feat_buf = shift_and_sample(feat_q, new_feat, sample_skip_fn) desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn) - inputs = {'features_buffer': feat_buf, 'desire_pulse': desire_buf, 'traffic_convention': traffic_convention} - policy_out = next(iter(policy_runner(inputs).values())).cast('float32') - - return vision_out, policy_out + inputs = { + 'features_buffer': feat_buf, + 'desire_pulse': desire_buf, + 'traffic_convention': traffic_convention, + 'action_t': action_t, + } + if prev_action_q is not None and prev_action is not None: + inputs['prev_action'] = shift_and_sample(prev_action_q, prev_action.reshape(1, 1, -1), sample_skip_fn) + on_policy_out = next(iter(on_policy_runner(inputs).values())).cast('float32') + off_policy_out = next(iter(off_policy_runner(inputs).values())).cast('float32') + + return vision_out, on_policy_out, off_policy_out return run_policy def compile_modeld(nv12: NV12Frame, model_w, model_h, prepare_only, frame_skip, - vision_onnx, policy_onnx, pkl_path): + vision_onnx, off_policy_onnx, on_policy_onnx, pkl_path): from get_model_metadata import metadata_path_for print(f"Compiling combined policy JIT for {nv12.width}x{nv12.height} (prepare_only={prepare_only})...") vision_runner = OnnxRunner(vision_onnx) - policy_runner = OnnxRunner(policy_onnx) + off_policy_runner = OnnxRunner(off_policy_onnx) + on_policy_runner = OnnxRunner(on_policy_onnx) with open(metadata_path_for(vision_onnx), 'rb') as f: vision_metadata = pickle.load(f) vision_features_slice = vision_metadata['output_slices']['hidden_state'] vision_input_shapes = vision_metadata['input_shapes'] - with open(metadata_path_for(policy_onnx), 'rb') as f: + with open(metadata_path_for(on_policy_onnx), 'rb') as f: policy_input_shapes = pickle.load(f)['input_shapes'] + with open(metadata_path_for(off_policy_onnx), 'rb') as f: + off_policy_input_shapes = pickle.load(f)['input_shapes'] + assert policy_input_shapes == off_policy_input_shapes - _run = make_run_policy(vision_runner, policy_runner, nv12, model_w, model_h, + _run = make_run_policy(vision_runner, off_policy_runner, on_policy_runner, nv12, model_w, model_h, vision_features_slice, frame_skip, prepare_only) run_policy_jit = TinyJit(_run, prune=True) @@ -235,7 +260,8 @@ def _parse_nv12(s): p.add_argument('--nv12', type=_parse_nv12, required=True, help=f'NV12 frame layout: {",".join(NV12Frame._fields)}') p.add_argument('--vision-onnx', required=True) - p.add_argument('--policy-onnx', required=True) + p.add_argument('--off-policy-onnx', required=True) + p.add_argument('--on-policy-onnx', required=True) p.add_argument('--output', required=True) p.add_argument('--prepare-only', action='store_true') p.add_argument('--frame-skip', type=int, required=True) @@ -243,4 +269,4 @@ def _parse_nv12(s): model_w, model_h = args.model_size compile_modeld(args.nv12, model_w, model_h, args.prepare_only, args.frame_skip, - args.vision_onnx, args.policy_onnx, args.output) + args.vision_onnx, args.off_policy_onnx, args.on_policy_onnx, args.output) diff --git a/selfdrive/modeld/constants.py b/selfdrive/modeld/constants.py index ff7e1d86006e83..0fb09262d0192e 100644 --- a/selfdrive/modeld/constants.py +++ b/selfdrive/modeld/constants.py @@ -38,6 +38,7 @@ class ModelConstants: LANE_LINES_WIDTH = 2 ROAD_EDGES_WIDTH = 2 PLAN_WIDTH = 15 + ACTION_WIDTH = 2 DESIRE_PRED_WIDTH = 8 LAT_PLANNER_SOLUTION_WIDTH = 4 DESIRED_CURV_WIDTH = 1 diff --git a/selfdrive/modeld/fill_model_msg.py b/selfdrive/modeld/fill_model_msg.py index 82c4c92b1d53c7..92a2dfa58d7f3a 100644 --- a/selfdrive/modeld/fill_model_msg.py +++ b/selfdrive/modeld/fill_model_msg.py @@ -125,7 +125,10 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D # meta meta = modelV2.meta - meta.desireState = net_output_data['desire_state'][0].reshape(-1).tolist() + if 'desire_state' in net_output_data: + meta.desireState = net_output_data['desire_state'][0].reshape(-1).tolist() + else: + meta.desireState = [0.0] * ModelConstants.DESIRE_PRED_WIDTH meta.desirePrediction = net_output_data['desire_pred'][0].reshape(-1).tolist() meta.engagedProb = net_output_data['meta'][0,Meta.ENGAGED].item() meta.init('disengagePredictions') diff --git a/selfdrive/modeld/modeld.py b/selfdrive/modeld/modeld.py index 73ed19ec943790..7041d860e382e6 100755 --- a/selfdrive/modeld/modeld.py +++ b/selfdrive/modeld/modeld.py @@ -36,7 +36,8 @@ SEND_RAW_PRED = os.getenv('SEND_RAW_PRED') VISION_METADATA_PATH = MODELS_DIR / 'driving_vision_metadata.pkl' -POLICY_METADATA_PATH = MODELS_DIR / 'driving_policy_metadata.pkl' +OFF_POLICY_METADATA_PATH = MODELS_DIR / 'driving_off_policy_metadata.pkl' +ON_POLICY_METADATA_PATH = MODELS_DIR / 'driving_on_policy_metadata.pkl' LAT_SMOOTH_SECONDS = 0.0 LONG_SMOOTH_SECONDS = 0.3 @@ -44,20 +45,12 @@ -def get_action_from_model(model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action, - lat_action_t: float, long_action_t: float, v_ego: float) -> log.ModelDataV2.Action: - plan = model_output['plan'][0] - desired_accel, should_stop = get_accel_from_plan(plan[:,Plan.VELOCITY][:,0], - plan[:,Plan.ACCELERATION][:,0], - ModelConstants.T_IDXS, - action_t=long_action_t) - desired_accel = smooth_value(desired_accel, prev_action.desiredAcceleration, LONG_SMOOTH_SECONDS) +def get_action_from_model(model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action, v_ego: float) -> log.ModelDataV2.Action: + desired_curv_unscaled, desired_accel = model_output['action'][0] + desired_curvature = desired_curv_unscaled / 100 + should_stop = (v_ego < 0.3 and desired_accel < 0.1) - desired_curvature = get_curvature_from_plan(plan[:,Plan.T_FROM_CURRENT_EULER][:,2], - plan[:,Plan.ORIENTATION_RATE][:,2], - ModelConstants.T_IDXS, - v_ego, - lat_action_t) + desired_accel = smooth_value(desired_accel, prev_action.desiredAcceleration, LONG_SMOOTH_SECONDS) if v_ego > MIN_LAT_CONTROL_SPEED: desired_curvature = smooth_value(desired_curvature, prev_action.desiredCurvature, LAT_SMOOTH_SECONDS) else: @@ -87,7 +80,11 @@ def __init__(self, cam_w: int, cam_h: int): self.vision_input_names = list(self.vision_input_shapes.keys()) self.vision_output_slices = vision_metadata['output_slices'] - with open(POLICY_METADATA_PATH, 'rb') as f: + with open(OFF_POLICY_METADATA_PATH, 'rb') as f: + off_policy_metadata = pickle.load(f) + self.off_policy_output_slices = off_policy_metadata['output_slices'] + + with open(ON_POLICY_METADATA_PATH, 'rb') as f: policy_metadata = pickle.load(f) self.policy_input_shapes = policy_metadata['input_shapes'] self.policy_output_slices = policy_metadata['output_slices'] @@ -127,6 +124,10 @@ def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray], self.npy['desire'][:] = np.where(inputs['desire_pulse'] - self.prev_desire > .99, inputs['desire_pulse'], 0) self.prev_desire[:] = inputs['desire_pulse'] self.npy['traffic_convention'][:] = inputs['traffic_convention'] + if 'action_t' in self.npy: + self.npy['action_t'][:] = inputs['action_t'] + if 'prev_action' in self.npy: + self.npy['prev_action'][:] = inputs['prev_action'] self.npy['tfm'][:,:] = transforms['img'][:,:] self.npy['big_tfm'][:,:] = transforms['big_img'][:,:] @@ -134,18 +135,20 @@ def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray], self.warp_enqueue(**self.input_queues, frame=self.full_frames['img'], big_frame=self.full_frames['big_img']) return None - vision_output, policy_output = self.run_policy( + vision_output, policy_output, off_policy_output = self.run_policy( **self.input_queues, frame=self.full_frames['img'], big_frame=self.full_frames['big_img'] ) vision_output = vision_output.numpy().flatten() + off_policy_output = off_policy_output.numpy().flatten() policy_output = policy_output.numpy().flatten() vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(vision_output, self.vision_output_slices)) + off_policy_outputs_dict = self.parser.parse_off_policy_outputs(self.slice_outputs(off_policy_output, self.off_policy_output_slices)) policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(policy_output, self.policy_output_slices)) - combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict} + combined_outputs_dict = {**vision_outputs_dict, **off_policy_outputs_dict, **policy_outputs_dict} if SEND_RAW_PRED: - combined_outputs_dict['raw_pred'] = np.concatenate([vision_output.copy(), policy_output.copy()]) + combined_outputs_dict['raw_pred'] = np.concatenate([vision_output.copy(), policy_output.copy(), off_policy_output.copy()]) return combined_outputs_dict @@ -287,9 +290,15 @@ def main(demo=False): bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names} transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names} + frame_delay = DT_MDL # compensate for time passed since the frame was captured: current_time - timestamp_eof is 50ms on average + action_delay = DT_MDL / 2 # middle of the interval between model output (current state) and next frame (expected state) + lat_action_t = lat_delay + frame_delay + action_delay + long_action_t = long_delay + frame_delay + action_delay inputs:dict[str, np.ndarray] = { 'desire_pulse': vec_desire, 'traffic_convention': traffic_convention, + 'action_t': np.array([lat_action_t, long_action_t], dtype=np.float32), + 'prev_action': np.array([prev_action.desiredCurvature * max(1.0, v_ego)**2, prev_action.desiredAcceleration], dtype=np.float32), } mt1 = time.perf_counter() @@ -302,9 +311,7 @@ def main(demo=False): drivingdata_send = messaging.new_message('drivingModelData') posenet_send = messaging.new_message('cameraOdometry') - frame_delay = DT_MDL # compensate for time passed since the frame was captured: current_time - timestamp_eof is 50ms on average - action_delay = DT_MDL / 2 # middle of the interval between model output (current state) and next frame (expected state) - action = get_action_from_model(model_output, prev_action, lat_delay + frame_delay + action_delay, long_delay + frame_delay + action_delay, v_ego) + action = get_action_from_model(model_output, prev_action, v_ego) prev_action = action fill_model_msg(drivingdata_send, modelv2_send, model_output, action, publish_state, meta_main.frame_id, meta_extra.frame_id, frame_id, diff --git a/selfdrive/modeld/models/big_driving_policy.onnx b/selfdrive/modeld/models/big_driving_policy.onnx deleted file mode 120000 index e1b653a14a03d6..00000000000000 --- a/selfdrive/modeld/models/big_driving_policy.onnx +++ /dev/null @@ -1 +0,0 @@ -driving_policy.onnx \ No newline at end of file diff --git a/selfdrive/modeld/models/big_driving_vision.onnx b/selfdrive/modeld/models/big_driving_vision.onnx deleted file mode 120000 index 28ee71dd746e63..00000000000000 --- a/selfdrive/modeld/models/big_driving_vision.onnx +++ /dev/null @@ -1 +0,0 @@ -driving_vision.onnx \ No newline at end of file diff --git a/selfdrive/modeld/models/driving_off_policy.onnx b/selfdrive/modeld/models/driving_off_policy.onnx new file mode 100644 index 00000000000000..51b659a08ea6cd --- /dev/null +++ b/selfdrive/modeld/models/driving_off_policy.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dff39ef4705f7dc092ea6a867d030c1092d40fa124f15c36ed01ee62877a1f21 +size 18178659 diff --git a/selfdrive/modeld/models/driving_on_policy.onnx b/selfdrive/modeld/models/driving_on_policy.onnx new file mode 100644 index 00000000000000..50ffcfec4cae96 --- /dev/null +++ b/selfdrive/modeld/models/driving_on_policy.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e7e31a20e93c6e0fd37d29c339e6668a915e3550d971a17e6560b9227b7ffcd +size 16745129 diff --git a/selfdrive/modeld/models/driving_policy.onnx b/selfdrive/modeld/models/driving_policy.onnx deleted file mode 100644 index 611ae9fe85f837..00000000000000 --- a/selfdrive/modeld/models/driving_policy.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:78477124cbf3ffe30fa951ebada8410b43c4242c6054584d656f1d329b067e15 -size 14060847 diff --git a/selfdrive/modeld/models/driving_vision.onnx b/selfdrive/modeld/models/driving_vision.onnx index 6c9fc4c84d3632..33601fc1ff2121 100644 --- a/selfdrive/modeld/models/driving_vision.onnx +++ b/selfdrive/modeld/models/driving_vision.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ee29ee5bce84d1ce23e9ff381280de9b4e4d96d2934cd751740354884e112c66 -size 46877473 +oid sha256:50070f898ad223475fced42d12a8e67e57e7fac7b21ddc26b05e64f1ae536880 +size 23210375 diff --git a/selfdrive/modeld/parse_model_outputs.py b/selfdrive/modeld/parse_model_outputs.py index a0b45d2a981685..1a7d699e687ca7 100644 --- a/selfdrive/modeld/parse_model_outputs.py +++ b/selfdrive/modeld/parse_model_outputs.py @@ -96,29 +96,32 @@ def parse_vision_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndar self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) self.parse_mdn('road_transform', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) + self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) + self.parse_binary_crossentropy('meta', outs) + return outs + + def parse_off_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH) + plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0) + self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH)) self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) self.parse_mdn('road_edges', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_ROAD_EDGES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) self.parse_binary_crossentropy('lane_lines_prob', outs) - self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) - self.parse_binary_crossentropy('meta', outs) self.parse_binary_crossentropy('lead_prob', outs) lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) lead_out_shape = (ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) if lead_mhp else \ (ModelConstants.LEAD_MHP_SELECTION, ModelConstants.LEAD_TRAJ_LEN, ModelConstants.LEAD_WIDTH) self.parse_mdn('lead', outs, in_N=lead_in_N, out_N=lead_out_N, out_shape=lead_out_shape) + self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) return outs def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: - plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH) - plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0) - self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH)) - if 'planplus' in outs: - self.parse_mdn('planplus', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH)) - self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,)) + self.parse_mdn('action', outs, in_N=0, out_N=0, out_shape=(ModelConstants.ACTION_WIDTH,)) return outs def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: outs = self.parse_vision_outputs(outs) + outs = self.parse_off_policy_outputs(outs) outs = self.parse_policy_outputs(outs) return outs