image segmentation with multiple classes #12

@karliesama

Thanks for the amazing project! My dataset is for multi-class segmentation. Each mask image has shape (H, W), where H is the height and W is the width, and each pixel is an integer class label. For example, tree: 0, ... car: 8, sky: 9, so a mask looks like [[0,3,9],[3,4,5]].
There are 10 classes in total.

I'm wondering how to train on this dataset. Should the model be set up like this?

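from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

def get_model(num_classes=10):
    model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True)
    # Replace the classifier head so it outputs one channel per class
    model.classifier = DeepLabHead(2048, num_classes)
    model.train()
    return model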

However, the prediction size seems to be wrong. I get the prediction with
y_pred = model(inputs)['out']
Its shape is torch.Size([8, 38, 256, 456]), but y_truth is torch.Size([8, 256, 456]), where 8 is the batch size, 256 is H, and 456 is W.

The shapes don't match, so I can't feed them into the loss function. Moreover, every element of y_pred is a float, but I expected integer class labels such as 0, 1, 2, 3.

May I ask how to deal with it? Thanks a lot for helping!
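
For reference, here is a minimal sketch of how shapes like these are commonly handled in PyTorch. It is not this project's own training code. It assumes `nn.CrossEntropyLoss` is the loss, and it replaces the model output and the masks with random stand-in tensors of small, made-up sizes. That loss accepts float scores of shape (N, C, H, W) together with integer targets of shape (N, H, W), so the two shapes do not need to be identical. Here C is the number of classes the head was built with. Integer class labels can be recovered afterwards with an argmax over the class dimension.

```python
import torch
import torch.nn as nn

num_classes = 10                  # assumed: matches the 10 classes described above
batch, height, width = 2, 8, 12   # small stand-in sizes, chosen only for illustration

# Stand-in for model(inputs)['out']: one float score (logit) per class per pixel.
y_pred = torch.randn(batch, num_classes, height, width)    # (N, C, H, W)

# Stand-in for the ground-truth masks: integer class ids, dtype torch.long.
y_truth = torch.randint(0, num_classes, (batch, height, width))  # (N, H, W)

# CrossEntropyLoss takes (N, C, H, W) logits and (N, H, W) long targets directly.
criterion = nn.CrossEntropyLoss()
loss = criterion(y_pred, y_truth)

# To get a class id per pixel, take the highest-scoring class along dim 1.
pred_classes = y_pred.argmax(dim=1)                        # (N, H, W), dtype torch.long
```

With a real mask loaded from an image, the tensor would typically need to be converted to `torch.long` (for example with `.long()`) before being passed as the target.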
