Maybe something is wrong in model.py #3

@K-King6


```python
# capturers: hc_t is computed first, hc_c is conditioned on hc_t, hc_l on hc_c
cur_t_rnn, hc_t = self.capturer_t(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:])
if self.cat_contained:
    cur_c_rnn, hc_c = self.capturer_c(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t)
    cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_c)

# 4) tower, t,c,l
# CMTL
hc_t, hc_c, hc_l = hc_t.squeeze(), hc_c.squeeze(), hc_l.squeeze()

# tower: c_pred is computed first; c_trans feeds t_pred, then t_trans feeds l_pred
c_pred = self.fc_c(hc_c)
c_trans = self.label_trans_c(c_pred.clone())
t_pred = self.fc_t(torch.cat((hc_t, c_trans), dim=-1))
t_trans = self.label_trans_t(t_pred.clone())
l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1))
```

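In the capturers, you first compute hc_t and pass it to capturer_c to get hc_c (and then pass hc_c to capturer_l to get hc_l), so the order there is t → c → l. In the tower, however, you first compute c_pred and pass its transform c_trans into fc_t to get t_pred (and then t_trans into fc_l), so the order there is c → t → l. These two orders seem inconsistent, and this may make your results worse.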
(Screenshot attached: Snipaste_2024-06-20_20-27-54)
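To make the mismatch explicit, here is a small, hypothetical sketch (not code from the repository) that records, for each stage, which task's output each task consumes, transcribed by hand from the snippet, and then recovers the computation order. `CAPTURER_FEEDS`, `TOWER_FEEDS`, `CONSISTENT_TOWER_FEEDS` and `chain_order` are illustrative names only; `CONSISTENT_TOWER_FEEDS` is an assumed t → c → l reordering of the tower, not the authors' design.

```python
from typing import Dict, List, Optional

# Hypothetical, hand-transcribed dependency tables (not from the repository).
# Each entry maps a task to the task whose output it consumes (None = no upstream task).
CAPTURER_FEEDS: Dict[str, Optional[str]] = {
    "t": None,  # capturer_t(...) -> hc_t
    "c": "t",   # capturer_c(..., hc_t) -> hc_c
    "l": "c",   # capturer_l(..., hc_c) -> hc_l
}
TOWER_FEEDS: Dict[str, Optional[str]] = {
    "c": None,  # c_pred = fc_c(hc_c)
    "t": "c",   # t_pred = fc_t(cat(hc_t, c_trans))
    "l": "t",   # l_pred = fc_l(cat(hc_l, t_trans))
}
# Assumed alternative: a tower that follows the capturers' t -> c -> l order.
CONSISTENT_TOWER_FEEDS: Dict[str, Optional[str]] = {"t": None, "c": "t", "l": "c"}


def chain_order(feeds: Dict[str, Optional[str]]) -> List[str]:
    """Return the order in which tasks must be computed, assuming each task
    consumes the output of at most one other task and the tasks form one chain."""
    # The chain starts at the only task that consumes no other task's output.
    roots = [task for task, src in feeds.items() if src is None]
    if len(roots) != 1:
        raise ValueError("expected exactly one task with no upstream input")
    order = [roots[0]]
    while len(order) < len(feeds):
        # The next task is the one that consumes the output of the last task.
        nxt = [task for task, src in feeds.items() if src == order[-1]]
        if len(nxt) != 1:
            raise ValueError("dependencies do not form a single chain")
        order.append(nxt[0])
    return order


capturer_order = chain_order(CAPTURER_FEEDS)
tower_order = chain_order(TOWER_FEEDS)
print(capturer_order, tower_order, capturer_order == tower_order)
# prints: ['t', 'c', 'l'] ['c', 't', 'l'] False
```

The capturers chain t → c → l while the tower chains c → t → l; only the assumed reordering gives the same order in both stages.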
