From 9b21061489f02dea9191a5ee4c4790de16d1e85b Mon Sep 17 00:00:00 2001
From: Mikael Simard
Date: Tue, 26 Sep 2023 18:35:57 +0100
Subject: [PATCH] Minor changes

Corrected dataloader issues, updated the inference code (still needs to be
unified with the DigitalPathologyAI framework), added the ColourAugment
scripts for data augmentation, and adjusted the transforms used at the
training stage.
---
 Dataloader/Dataloader.py                 |  33 +++--
 Inference/Preprocess.py                  | 101 ++++++-------
 QA/Normalization/Colour/ColourAugment.py | 138 ++++++++++++++++++
 Training/Preprocess.py                   |  82 +++++------
 Utils/PreprocessingTools.py              |   2 +-
 5 files changed, 238 insertions(+), 118 deletions(-)
 create mode 100644 QA/Normalization/Colour/ColourAugment.py

diff --git a/Dataloader/Dataloader.py b/Dataloader/Dataloader.py
index f03f674..b63d8b0 100644
--- a/Dataloader/Dataloader.py
+++ b/Dataloader/Dataloader.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from QA.StainNormalization import ColourNorm
 from Utils import npyExportTools
+from PIL import Image


 class DataGenerator(torch.utils.data.Dataset):
@@ -33,7 +34,14 @@ def __getitem__(self, id):
         # load image
         svs_path = self.tile_dataset['SVS_PATH'].iloc[id]
         svs_file = openslide.open_slide(svs_path)
-        data = np.array(svs_file.read_region([self.tile_dataset["coords_x"].iloc[id], self.tile_dataset["coords_y"].iloc[id]], self.vis, self.dim).convert("RGB"))
+
+        # Todo: unify with the DigitalPathologyAI framework. This code is old and has not been revisited recently.
+        try:
+            data = np.array(svs_file.read_region([self.tile_dataset["coords_x"].iloc[id], self.tile_dataset["coords_y"].iloc[id]], self.vis, self.dim).convert("RGB"))
+        except Exception as e:
+            print("An error '{}' occurred with SVS '{}'; replacing patch with zeroes.".format(e, svs_path))
+            data = np.array(Image.new('RGB', self.dim, (0, 0, 0)))  # match the array type of the success branch
+
         for transform_step in self.transform.transforms:
             if(isinstance(transform_step,ColourNorm.Macenko)):
                 HE, maxC = ColourNorm.Macenko().find_HE(data, get_maxC=True)
@@ -114,13 +122,13 @@ def __init__(self, tile_dataset, train_transform=None, val_transform=None, batch
         self.test_data = DataGenerator(tile_dataset_test , transform=val_transform , target=target, **kwargs)

     def train_dataloader(self):
-        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=60, pin_memory=True, shuffle=True)
+        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=24, pin_memory=True, shuffle=True)

     def val_dataloader(self):
-        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=60, pin_memory=True)
+        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=24, pin_memory=True)

     def test_dataloader(self):
-        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=60, pin_memory=True)
+        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=24, pin_memory=True)


 def LoadFileParameter(config, dataset):
@@ -225,15 +233,16 @@ def QueryImageFromCriteria(config, **kwargs):
     result = conn.getQueryService().projection(query, params, {"omero.group": "-1"})

     df_criteria = pd.DataFrame()
-    for row in result:  ## Transform the results into a panda dataframe for each found match
-        temp = pd.DataFrame([[row[0].val, Path(row[1].val).stem, row[2].val, *row[3].val.getMapValueAsMap().values()]],
-                            columns=["id_omero", "id_external", "Size", *row[3].val.getMapValueAsMap().keys()])
-        df_criteria = pd.concat([df_criteria, temp])
-
-    df_criteria['SVS_PATH'] = [os.path.join(config['DATA']['SVS_Folder'], image_id+'.svs') for image_id in df_criteria['id_external']]
-    df_criteria['NPY_PATH'] = [os.path.join(config['DATA']['SVS_Folder'], 'patches', image_id + '.npy') for image_id in df_criteria['id_external']]
+    if len(result) > 0:
+        for row in result:  ## Transform the results into a pandas DataFrame for each found match
+            temp = pd.DataFrame([[row[0].val, Path(row[1].val).stem, row[2].val, *row[3].val.getMapValueAsMap().values()]],
+                                columns=["id_omero", "id_external", "Size", *row[3].val.getMapValueAsMap().keys()])
+            df_criteria = pd.concat([df_criteria, temp])
+
+        df_criteria['SVS_PATH'] = [os.path.join(config['DATA']['SVS_Folder'], image_id + '.svs') for image_id in df_criteria['id_external']]
+        df_criteria['NPY_PATH'] = [os.path.join(config['DATA']['SVS_Folder'], 'patches', image_id + '.npy') for image_id in df_criteria['id_external']]

-    df = pd.concat([df, df_criteria])
+        df = pd.concat([df, df_criteria])

     conn.close()
     return df
diff --git a/Inference/Preprocess.py b/Inference/Preprocess.py
index d8d1550..d470c69 100644
--- a/Inference/Preprocess.py
+++ b/Inference/Preprocess.py
@@ -1,59 +1,61 @@
-import sys
-sys.path.insert(0,'/home/cacof1/Software/DigitalPathologyPreprocessing/')
+from Dataloader.Dataloader import *
 from Utils.PreprocessingTools import Preprocessor
 import toml
+from torch import cuda
+from Utils import GetInfo
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
 import pytorch_lightning as pl
+from Utils import MultiGPUTools
 from torchvision import transforms
-from QA.StainNormalization import ColourNorm
+import torch
+from QA.Normalization.Colour import ColourAugment
 from Model.ConvNet import ConvNet
-from Dataloader.Dataloader import *
-from Utils import MultiGPUTools
+import datetime
+import multiprocessing as mp

-n_gpus = torch.cuda.device_count() # could go into config file
 config = toml.load(sys.argv[1])
+n_gpus = 1  # cuda.device_count(); inference currently runs on a single GPU

-########################################################################################################################
-# 1. Download all relevant files based on the configuration file
-SVS_dataset = QueryImageFromCriteria(config)
-SVS_dataset = SVS_dataset.reset_index()
-print(SVS_dataset)
+########################################################################################################################
+# 1. Download all relevant ROI based on the configuration file
+SVS_dataset = QueryImageFromCriteria(config).reset_index()
 SynchronizeSVS(config, SVS_dataset)
-SVS_dataset = SVS_dataset.iloc[0:1]
-
-print(SVS_dataset)
-

 ########################################################################################################################
-# 2. Pre-processing: create npy files
-print('Getting tiles without background')
+# 2. Pre-processing: create tile_dataset from annotations list
+
 preprocessor = Preprocessor(config)
 tile_dataset = preprocessor.getAllTiles(SVS_dataset, background_fraction_threshold=0.7)
-print('hrtr')
+

 ########################################################################################################################
-# 3. Model + dataloader
+# 3. Model

-# Pad tile_dataset such that the final batch size can be divided by n_gpus.
-n_pad = MultiGPUTools.pad_size(len(tile_dataset), n_gpus, config['BASEMODEL']['Batch_Size'])
-tile_dataset = MultiGPUTools.pad_dataframe(tile_dataset, n_pad)

-pl.seed_everything(config['ADVANCEDMODEL']['Random_Seed'], workers=True)
+# Data transformation
 val_transform = transforms.Compose([
-    transforms.ToTensor(),  # this also normalizes to [0,1].
-    #ColourNorm.Macenko(),
+    transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])

+# Pad tile_dataset such that the final batch size can be divided by n_gpus.
+n_pad = MultiGPUTools.pad_size(len(tile_dataset), n_gpus, config['BASEMODEL']['Batch_Size'])
+tile_dataset = MultiGPUTools.pad_dataframe(tile_dataset, n_pad)
+
 data = DataLoader(DataGenerator(tile_dataset, transform=val_transform, target=config['DATA']['Label'], inference=True),
                   batch_size=config['BASEMODEL']['Batch_Size'],
-                  num_workers=60,
+                  num_workers=int(0.8 * mp.cpu_count()),  # use ~80% of the available cores
                   persistent_workers=True,
                   shuffle=False,
-                  #prefetch_factor = 10,
                   pin_memory=True)

-trainer = pl.Trainer(gpus=n_gpus, strategy='bagua', benchmark=False, precision=config['BASEMODEL']['Precision'])#,
-                     #callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=1)])
+trainer = pl.Trainer(devices=n_gpus,
+                     accelerator="gpu",
+                     strategy=pl.strategies.DDPStrategy(timeout=datetime.timedelta(seconds=10800)),
+                     benchmark=False,
+                     precision=config['BASEMODEL']['Precision'],
+                     callbacks=[TQDMProgressBar(refresh_rate=1)])

 model = ConvNet.load_from_checkpoint(config['CHECKPOINT']['Model_Save_Path'])
 model.eval()
@@ -69,8 +71,14 @@
 tile_dataset = tile_dataset.iloc[:-n_pad]
 predicted_classes_prob = predicted_classes_prob[:-n_pad]

+
 ########################################################################################################################
-# 5. Save
+# 5. Save (kept separate from the Omero upload in case it crashes)
+
+# Save the dataframe first in case the code below fails; in our experience, the code sometimes
+# hangs here on multi-GPU runs...
+tile_dataset = tile_dataset.drop(columns="SVS_PATH")  # drop SVS path before saving.
+tile_dataset.to_csv('full_tile_dataset_after_inference.csv')

 tissue_names = model.LabelEncoder.inverse_transform(np.arange(predicted_classes_prob.shape[1]))
 for tissue_no, tissue_name in enumerate(tissue_names):
@@ -78,34 +86,13 @@
     tile_dataset['prob_' + config['DATA']['Label'] + '_' + tissue_name] = predicted_classes_prob[:, tissue_no]

 tile_dataset = tile_dataset.fillna(0)
-# todo: remove both following lines once the SaveFileParameter below works.
 for id_external, df_split in tile_dataset.groupby(tile_dataset.id_external):
-    npy_file = SaveFileParameter(config, df_split, id_external)
+    npy_file = SaveFileParameter(config, df_split, str(id_external))
+
+print('Npy export complete.')

-########################################################################################################################
-# 6. Send back to OMERO
-conn = connect(config['OMERO']['Host'], config['OMERO']['User'], config['OMERO']['Pw'])
-conn.SERVICE_OPTS.setOmeroGroup('-1')
-for id_external, df_split in tile_dataset.groupby(tile_dataset.id_external):
-    image = conn.getObject("Image", SVS_dataset.loc[SVS_dataset["id_internal"] == id_external].iloc[0]['id_omero'])
-    group_id = image.getDetails().getGroup().getId()
-    conn.SERVICE_OPTS.setOmeroGroup(group_id)
-    print("Current group: ", group_id)
-    npy_file = SaveFileParameter(config, df_split, id_external)
-    print("\nCreating an OriginalFile and FileAnnotation")
-    file_ann = conn.createFileAnnfromLocalFile(npy_file, mimetype="text/plain", desc=None)
-    print("Attaching FileAnnotation to Dataset: ", "File ID:", file_ann.getId(), ",", file_ann.getFile().getName(),
-          "Size:", file_ann.getFile().getSize())
-
-    ## delete because Omero methods are moronic
-    to_delete = []
-    for ann in image.listAnnotations():
-        if isinstance(ann, omero.gateway.FileAnnotationWrapper): to_delete.append(ann.id)
-    conn.deleteObjects('Annotation', to_delete, wait=True)
-    if len(to_delete)>0: image.linkAnnotation(file_ann)  # link it to dataset.
-
-    print('{}.npy uploaded'.format(id_external))
-
-conn.close()

diff --git a/QA/Normalization/Colour/ColourAugment.py b/QA/Normalization/Colour/ColourAugment.py
new file mode 100644
index 0000000..ccb0b83
--- /dev/null
+++ b/QA/Normalization/Colour/ColourAugment.py
@@ -0,0 +1,138 @@
+import torch.nn as nn
+import torch
+import numpy as np
+from matplotlib import pyplot as plt
+from skimage import data
+
+
+def random_uniform(r1, r2):
+    return (r1 - r2) * torch.rand(3) + r2
+
+
+def make_3d(x):
+    return x.unsqueeze(-1).unsqueeze(-1)
+
+
+class ColourAugment(nn.Module):
+    """
+    Colour augmentation based on:
+    (1) A. C. Ruifrok and D. A. Johnston, "Quantification of histochemical staining by color deconvolution".
+    (2) the scikit-image routines rgb2hed and hed2rgb (reimplemented here for torch tensors).
+    (3) DOI: 10.1109/TMI.2018.2820199 for the perturbation scheme.
+    """
+
+    def __init__(self, sigma=0.05, mode='uniform'):
+        super(ColourAugment, self).__init__()
+
+        # In Ruifrok and Johnston's original paper, a 3-stain deconvolution is done with DAB as the
+        # last stain, using the stain vector [0.27, 0.57, 0.78]. In our case, we do H&E only.
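Note on the padding logic above: with distributed inference on n_gpus devices, every rank must
receive the same number of batches, so the tile table is padded to a multiple of
n_gpus * Batch_Size and the dummy rows and predictions are trimmed off after inference (step 4).
A minimal sketch of the behaviour assumed of MultiGPUTools.pad_size and pad_dataframe (a
hypothetical reimplementation, not the actual library code):

    import math
    import pandas as pd

    def pad_size(n_rows: int, n_gpus: int, batch_size: int) -> int:
        # smallest pad that makes n_rows divisible by n_gpus * batch_size
        chunk = n_gpus * batch_size
        return math.ceil(n_rows / chunk) * chunk - n_rows

    def pad_dataframe(df: pd.DataFrame, n_pad: int) -> pd.DataFrame:
        # repeat the last row n_pad times; these rows are dropped again after inference
        return pd.concat([df, df.iloc[[-1] * n_pad]], ignore_index=True) if n_pad else df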
+        # The third vector is the one orthogonal to the H and E vectors, obtained with a
+        # cross product between the H and E stain vectors:
+        H_stain_vector = [0.65, 0.70, 0.29]
+        E_stain_vector = [0.07, 0.99, 0.11]
+        residual = list(np.cross(np.array(H_stain_vector), np.array(E_stain_vector)))
+
+        # Some references on the residual (although it's basic linear algebra):
+        # https://blog.bham.ac.uk/intellimic/g-landini-software/colour-deconvolution-2/,
+        # https://forum.image.sc/t/on-the-math-behind-colour-deconvolution-ruifrok-and-johnston-2001/66325
+
+        self.rgb_from_hed = torch.tensor([H_stain_vector, E_stain_vector, residual], dtype=torch.float32)
+        self.hed_from_rgb = torch.linalg.inv(self.rgb_from_hed)
+        self.sigma = sigma
+        self.mode = mode
+
+    def rgb_to_stain(self, img, conv_matrix):
+
+        c, h, w = img.shape
+        img = img.reshape(img.shape[0], -1)  # collapse (C, H, W) to (C, H*W)
+        torch.maximum(img, torch.tensor(1E-6), out=img)  # avoiding log artifacts
+        log_adjust = torch.log(torch.tensor(1E-6))  # normalization factor for the log transform (as in skimage)
+        stains = conv_matrix @ (torch.log(img) / log_adjust)
+
+        return torch.maximum(stains, torch.tensor(0), out=stains).reshape(c, h, w)
+
+    def stain_to_rgb(self, stains, conv_matrix):
+
+        c, h, w = stains.shape
+        stains = stains.reshape(stains.shape[0], -1)  # collapse (C, H, W) to (C, H*W)
+        log_adjust = -torch.log(torch.tensor(1E-6))  # compensates the normalization applied in rgb_to_stain()
+        log_rgb = -conv_matrix @ (stains * log_adjust)
+        rgb = torch.exp(log_rgb)
+        return torch.clamp(rgb, min=0, max=1).reshape(c, h, w)
+
+    def forward(self, img):  # rgb -> stains -> perturbed rgb
+        # input img: float32 torch tensor (intensities in [0, 1]) of size (c, h, w)
+        # output: colour-perturbed float32 torch tensor of the same size and range.
+        alpha, beta = torch.tensor(1.0), torch.tensor(0.0)
+
+        conv_matrix_forward = torch.transpose(self.hed_from_rgb, 0, 1)
+        stains = self.rgb_to_stain(img=img, conv_matrix=conv_matrix_forward)
+
+        if self.mode == 'uniform':
+            alpha = make_3d(random_uniform(r1=1 - self.sigma, r2=1 + self.sigma))
+            beta = make_3d(random_uniform(r1=-self.sigma, r2=self.sigma))
+        elif self.mode == 'normal':
+            alpha = make_3d(torch.normal(mean=1.0, std=torch.tensor([self.sigma, self.sigma, self.sigma])))
+            beta = make_3d(torch.normal(mean=0.0, std=torch.tensor([self.sigma, self.sigma, self.sigma])))
+
+        stains_perturbed = alpha * stains + beta
+        conv_matrix_backward = torch.transpose(self.rgb_from_hed, 0, 1)
+        rgb_perturbed = self.stain_to_rgb(stains_perturbed, conv_matrix_backward)
+
+        return rgb_perturbed
+
+    def backward(self, stain):
+
+        conv_matrix_backward = torch.transpose(self.rgb_from_hed, 0, 1)
+
+        return self.stain_to_rgb(stains=stain, conv_matrix=conv_matrix_backward)
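+
+# Usage example (illustrative; sigma and mode normally come from the AUGMENTATION section
+# of the config, as in Training/Preprocess.py). With mode='uniform' and sigma=0.05, each
+# stain channel i is perturbed as stains'[i] = alpha[i] * stains[i] + beta[i], with
+# alpha[i] ~ U(0.95, 1.05) and beta[i] ~ U(-0.05, 0.05):
+#
+#     from torchvision import transforms
+#     transform = transforms.Compose([transforms.ToTensor(),
+#                                     ColourAugment(sigma=0.05, mode='uniform')])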
+
+
+########################################################################################################################
+
+
+if __name__ == '__main__':
+
+    # Sample image from scikit-image (not a great example, as it is Hematoxylin & DAB, with no Eosin).
+    ihc_rgb = data.immunohistochemistry() / 255.0
+    img = torch.permute(torch.tensor(ihc_rgb, dtype=torch.float32), (2, 0, 1))
+
+    # get stains
+    m = ColourAugment(sigma=0.005, mode='uniform')
+    c_f = torch.transpose(m.hed_from_rgb, 0, 1)
+    c_b = torch.transpose(m.rgb_from_hed, 0, 1)
+    ihc_hed = torch.permute(m.rgb_to_stain(img=img, conv_matrix=c_f), (1, 2, 0))
+    null = torch.zeros_like(ihc_hed[:, :, 0])
+    ihc_h = m.stain_to_rgb(stains=torch.permute(torch.stack((ihc_hed[:, :, 0], null, null), dim=-1), (2, 0, 1)),
+                           conv_matrix=c_b)
+    ihc_e = m.stain_to_rgb(stains=torch.permute(torch.stack((null, ihc_hed[:, :, 1], null), dim=-1), (2, 0, 1)),
+                           conv_matrix=c_b)
+    ihc_d = m.stain_to_rgb(stains=torch.permute(torch.stack((null, null, ihc_hed[:, :, 2]), dim=-1), (2, 0, 1)),
+                           conv_matrix=c_b)
+
+    # Display the decomposition of the target image
+    fig, axes = plt.subplots(2, 2, figsize=(7, 6), sharex=True, sharey=True)
+    ax = axes.ravel()
+    ax[0].imshow(img.numpy().transpose(1, 2, 0))
+    ax[0].set_title("Original image")
+    ax[1].imshow(ihc_h.numpy().transpose(1, 2, 0))
+    ax[1].set_title("Hematoxylin")
+    ax[2].imshow(ihc_e.numpy().transpose(1, 2, 0))
+    ax[2].set_title("Eosin")  # Note that there is no Eosin stain in this image
+    ax[3].imshow(ihc_d.numpy().transpose(1, 2, 0))
+    ax[3].set_title("Residual")
+    for a in ax.ravel():
+        a.axis('off')
+    fig.tight_layout()
+
+    # Display colour augmentation examples of the target image
+    fig, axes = plt.subplots(5, 5, figsize=(7, 7), sharex=True, sharey=True)
+    AX = axes.ravel()
+    for j in range(25):
+        img_CA = torch.permute(m.forward(img=img), (1, 2, 0))
+        AX[j].imshow(img_CA.numpy())
+    for a in AX.ravel():
+        a.axis('off')
+    fig.tight_layout()
+    plt.show()
diff --git a/Training/Preprocess.py b/Training/Preprocess.py
--- a/Training/Preprocess.py
+++ b/Training/Preprocess.py
@@ ... @@
-    train_transform = transforms.Compose([
-        transforms.ToTensor(),  ## Convert to [0,1] by dividing by 255
-        ColourNorm.Macenko(),
-        transforms.ToPILImage(),
-        transforms.RandAugment(num_ops=config['AUGMENTATION']['Rand_Operations'],
-                               magnitude=config['AUGMENTATION']['Rand_Magnitude']),
-        transforms.ToTensor(),  ## Convert to [0,1] by dividing by 255
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-    ])
+pl.seed_everything(config['ADVANCEDMODEL']['Random_Seed'], workers=True)

-else:
-    train_transform = transforms.Compose([
-        transforms.ToTensor(),  # this also normalizes to [0,1].
-        ColourNorm.Macenko(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-    ])
+# Data transformation
+train_transform = transforms.Compose([
+    transforms.ToTensor(),
+    ColourAugment.ColourAugment(sigma=config['AUGMENTATION']['Colour_Sigma'], mode=config['AUGMENTATION']['Colour_Mode']),
+    transforms.ToPILImage(),
+    transforms.RandomHorizontalFlip(p=0.4),
+    transforms.RandomVerticalFlip(p=0.4),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])

-# transforms: colour norm only on validation set
 val_transform = transforms.Compose([
-    transforms.ToTensor(),  # this also normalizes to [0,1].
-    ColourNorm.Macenko(),
+    transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])

@@ -93,9 +79,10 @@
 label_encoder.fit(tile_dataset[config['DATA']['Label']])

 # Load model and train
-print("N GPUs: ",torch.cuda.device_count())
-trainer = pl.Trainer(gpus=torch.cuda.device_count(),  # could go into config file
-                     strategy='bagua',
+print("N GPUs: ", torch.cuda.device_count())
+trainer = pl.Trainer(devices=torch.cuda.device_count(),  # could go into config file
+                     accelerator="gpu",
+                     strategy=pl.strategies.DDPStrategy(timeout=datetime.timedelta(seconds=10800)),
                      benchmark=True,
                      max_epochs=config['ADVANCEDMODEL']['Max_Epochs'],
                      precision=config['BASEMODEL']['Precision'],
@@ -125,14 +112,13 @@
 # Load model and train/validate
 trainer.fit(model, data)

-## Test
+# Test
 trainer.test(model, data.test_dataloader())

-## Write config file in logging folder for safekeeping
-with open(logger.log_dir+"/Config.ini", "w+") as toml_file:
+# Write config file in logging folder for safekeeping
+with open(logger.log_dir + "/Config.ini", "w+") as toml_file:
     toml.dump(config, toml_file)
     toml_file.write("Train transform:\n")
     toml_file.write(str(train_transform))
     toml_file.write("Val/Test transform:\n")
-    toml_file.write(str(val_transform))
-
+    toml_file.write(str(val_transform))
\ No newline at end of file
diff --git a/Utils/PreprocessingTools.py b/Utils/PreprocessingTools.py
index 210994a..f6274d0 100644
--- a/Utils/PreprocessingTools.py
+++ b/Utils/PreprocessingTools.py
@@ -267,7 +267,7 @@ def getTilesFromAnnotations(self, dataset):
         print(df_final.shape)
         return df

-    def getAllTiles(self, dataset, background_fraction_threshold=0):
+    def getAllTiles(self, dataset, background_fraction_threshold=0.7):
         df = pd.DataFrame()

         for idx, row in dataset.iterrows():
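Note on the new background_fraction_threshold default: the intent (an assumption about
Preprocessor.getAllTiles, whose internals are not shown in this patch) is that a tile is kept
only if its fraction of background pixels is at most the threshold. A minimal sketch of that
filtering rule, with a hypothetical near-white background test:

    import numpy as np

    def background_fraction(tile_rgb: np.ndarray, white_level: int = 220) -> float:
        # tile_rgb: (H, W, 3) uint8 patch; pixels bright in all three channels count as background
        background = (tile_rgb > white_level).all(axis=-1)
        return float(background.mean())

    def keep_tile(tile_rgb: np.ndarray, background_fraction_threshold: float = 0.7) -> bool:
        return background_fraction(tile_rgb) <= background_fraction_threshold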