diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..3641cd5 Binary files /dev/null and b/.DS_Store differ diff --git a/CC_40D_2_1103-0917.mat b/CC_40D_2_1103-0917.mat new file mode 100644 index 0000000..c7baee9 Binary files /dev/null and b/CC_40D_2_1103-0917.mat differ diff --git a/IDS_COLORCHECK_1020-1215-1.mat b/IDS_COLORCHECK_1020-1215-1.mat new file mode 100644 index 0000000..26f142c Binary files /dev/null and b/IDS_COLORCHECK_1020-1215-1.mat differ diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..2d4d03e --- /dev/null +++ b/clustering.py @@ -0,0 +1,172 @@ +import os +import sys +import torch +import numpy as np +import matplotlib.pyplot as plt +import scipy.io as sio + +from sklearn.cluster import KMeans +import h5py + +from hydra import initialize, compose +from omegaconf import OmegaConf + +# Ajouter le dossier courant pour que Python trouve le module 't3sc' +sys.path.append(os.path.abspath('.')) +from t3sc.models.multilayer import MultilayerModel + +# ------------------------------------------------------------------ +# ÉTAPE 1 : CHARGEMENT DU MODÈLE +# ------------------------------------------------------------------ +print("Chargement de la configuration...") +with initialize(config_path="t3sc/config", version_base=None): + cfg = compose(config_name="config", overrides=[ + "data=icvl", + "noise=constant", + "noise.params.sigma=25", + "model.ckpt=icvl_constant_25.ckpt" + ]) + +print("Instanciation du modèle PyTorch...") +model_config = OmegaConf.to_container(cfg.model, resolve=True) +kwargs = model_config.get('params', model_config) + +model = MultilayerModel(**kwargs) +model.eval() +print("Modèle chargé et prêt !") + +# ------------------------------------------------------------------ +# ÉTAPE 2 : L'ESPION (HOOK) +# ------------------------------------------------------------------ +alphas_capture = {} + +def get_activation(name): + def hook(model, input, output): + alphas_capture[name] = output.detach().cpu() + return hook + +model.layers[0].register_forward_hook(get_activation('alpha_1')) + + +# ------------------------------------------------------------------ +# ÉTAPE 3 : CHARGEMENT DE VOTRE IMAGE ICVL (Robuste) +# ------------------------------------------------------------------ +print("\nChargement de la vraie image ICVL...") +#chemin_image = "CC_40D_2_1103-0917.mat" # <-- Votre fichier à la racine +#chemin_image = "nachal_0823-1222.mat" # <-- Votre fichier à la racine +chemin_image = "nachal_0823-1214.mat" # <-- Votre fichier à la racine + +# Vérification rapide de la taille du fichier +taille_mo = os.path.getsize(chemin_image) / (1024 * 1024) +print(f"Taille du fichier : {taille_mo:.2f} Mo") +if taille_mo < 1.0: + print("ATTENTION : Le fichier pèse moins de 1 Mo. Il est très probablement corrompu ou c'est une page HTML déguisée !") + +try: + # Tentative 1 : Format MATLAB récent (HDF5) + with h5py.File(chemin_image, "r") as f: + image_np = np.array(f["rad"], dtype=np.float32) + print("Fichier lu avec succès (Format HDF5 via h5py).") + +except OSError: + # Tentative 2 : Format MATLAB classique + print("Le format n'est pas HDF5, tentative de lecture avec scipy.io...") + try: + mat_data = sio.loadmat(chemin_image) + + # On liste toutes les variables contenues dans le fichier + cles = [k for k in mat_data.keys() if not k.startswith('_')] + print(f"Variables trouvées dans le fichier : {cles}") + + if 'rad' in cles: + nom_variable = 'rad' + else: + # S'il n'y a pas 'rad', on prend la plus grande variable trouvée + nom_variable = cles[0] + print(f"La clé 'rad' est absente. Utilisation automatique de '{nom_variable}'") + + image_np = mat_data[nom_variable].astype(np.float32) + print("Fichier lu avec succès (Format classique via scipy.io).") + + except Exception as e: + print(f"Erreur fatale : Impossible de lire le fichier. Détail : {e}") + sys.exit(1) + +# --- CORRECTION DES DIMENSIONS --- +# PyTorch veut (Canaux, Hauteur, Largeur). +# Selon les datasets, l'image est parfois (H, W, Canaux) ou (Canaux, H, W) +if image_np.shape[-1] == 31 or image_np.shape[-1] > 100: + # Le dernier chiffre est le nombre de canaux (ex: 31). On doit faire tourner le cube. + image_tensor = torch.tensor(image_np, dtype=torch.float32).permute(2, 0, 1) +else: + # Le cube est déjà dans le bon sens + image_tensor = torch.tensor(image_np, dtype=torch.float32) + +# Normalisation Min-Max (entre 0 et 1) +img_min = image_tensor.min() +img_max = image_tensor.max() +image_tensor = (image_tensor - img_min) / (img_max - img_min) + +# --- PRÉPARATION DE LA VISUALISATION (Pseudo-RGB) --- +# Sécurité : on vérifie combien de canaux on a vraiment +nb_canaux = image_tensor.shape[0] +if nb_canaux >= 31: + R = image_tensor[27, :, :].numpy() + G = image_tensor[14, :, :].numpy() + B = image_tensor[8, :, :].numpy() +else: + # Si c'est un autre dataset avec moins de canaux, on prend le début, milieu et fin + R = image_tensor[int(nb_canaux*0.9), :, :].numpy() + G = image_tensor[int(nb_canaux*0.5), :, :].numpy() + B = image_tensor[int(nb_canaux*0.1), :, :].numpy() + +image_rgb_visu = np.stack([R, G, B], axis=-1) +# Pour que l'image soit bien lumineuse à l'écran, on peut booster un peu le contraste +image_rgb_visu = np.clip(image_rgb_visu * 1.5, 0, 1) + +# Format pour le réseau : (Batch, Canaux, Hauteur, Largeur) +image_tensor = image_tensor.unsqueeze(0) +print(f"Image prête pour le réseau ! Taille : {image_tensor.shape}") + +# ------------------------------------------------------------------ +# ÉTAPE 4 : PASSAGE DANS LE RÉSEAU +# ------------------------------------------------------------------ +print("Extraction des caractéristiques (Alphas)...") +with torch.no_grad(): + _ = model.encode(image_tensor, img_id=None, sigmas=None, ssl_idx=None) + +alpha_1 = alphas_capture['alpha_1'].squeeze(0) # (Nombre_atomes, H, W) +nb_atomes, h, w = alpha_1.shape + +# ------------------------------------------------------------------ +# ÉTAPE 5 : CLUSTERING K-MEANS +# ------------------------------------------------------------------ +print("Lancement du clustering...") +alpha_features = alpha_1.permute(1, 2, 0).numpy() +pixels_flat = alpha_features.reshape(h * w, nb_atomes) + +n_clusters = 2 +kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) +labels_flat = kmeans.fit_predict(pixels_flat) +segmentation_map = labels_flat.reshape(h, w) + +# ------------------------------------------------------------------ +# ÉTAPE 6 : VISUALISATION COMPLÈTE +# ------------------------------------------------------------------ +plt.figure(figsize=(12, 6)) + +# Affichage de l'image originale (Pseudo-RGB) +plt.subplot(1, 2, 1) +plt.imshow(image_rgb_visu) +plt.title("Image Originale (Pseudo-RGB : Bandes 28, 15, 9)") +plt.axis('off') + +# Affichage de la carte de clustering +plt.subplot(1, 2, 2) +plt.imshow(segmentation_map, cmap='Set1') +plt.title(f"Segmentation via les Alphas ({n_clusters} clusters)") +plt.colorbar(label="Numéro du cluster (Matériau)") +plt.axis('off') + +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/icvl_constant_25.ckpt b/icvl_constant_25.ckpt new file mode 100644 index 0000000..2d5bc3c Binary files /dev/null and b/icvl_constant_25.ckpt differ diff --git a/nachal_0823-1214.mat b/nachal_0823-1214.mat new file mode 100644 index 0000000..228a9d3 Binary files /dev/null and b/nachal_0823-1214.mat differ diff --git a/nachal_0823-1222.mat b/nachal_0823-1222.mat new file mode 100644 index 0000000..9b34dd4 Binary files /dev/null and b/nachal_0823-1222.mat differ diff --git a/t3sc/models/multilayer.py b/t3sc/models/multilayer.py index b32fd93..ec651a8 100644 --- a/t3sc/models/multilayer.py +++ b/t3sc/models/multilayer.py @@ -35,7 +35,7 @@ def __init__( self.ckpt = ckpt if self.ckpt is not None: logger.info(f"Loading ckpt {self.ckpt!r}") - d = torch.load(to_absolute_path(self.ckpt)) + d = torch.load(to_absolute_path(self.ckpt), map_location='cpu') self.load_state_dict(d["state_dict"]) def init_layers(self): diff --git a/visualisation_dataset.py b/visualisation_dataset.py new file mode 100644 index 0000000..104e6ac --- /dev/null +++ b/visualisation_dataset.py @@ -0,0 +1,101 @@ +import os +import glob +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.widgets import Slider +import h5py +import scipy.io as sio +import kagglehub + +# --------------------------------------------------------- +# 1. RÉCUPÉRATION DU DATASET +# --------------------------------------------------------- +print("Vérification du dataset Kaggle...") +path = kagglehub.dataset_download("simhadrisadaram/icvl-test-512") +mat_files = glob.glob(os.path.join(path, "**", "*.mat"), recursive=True) +mat_files.sort() # Pour avoir un ordre logique + +if len(mat_files) == 0: + print("Aucun fichier .mat n'a été trouvé.") + exit() + +print(f"{len(mat_files)} images trouvées. Lancement de l'interface...") + +# --------------------------------------------------------- +# 2. FONCTION DE CHARGEMENT ROBUSTE +# --------------------------------------------------------- +def load_pseudo_rgb(file_path): + """Charge un .mat à la volée et renvoie une image RGB.""" + try: + with h5py.File(file_path, "r") as f: + keys = list(f.keys()) + key = 'rad' if 'rad' in keys else keys[0] + image_np = np.array(f[key], dtype=np.float32) + except OSError: + mat_data = sio.loadmat(file_path) + keys = [k for k in mat_data.keys() if not k.startswith('_')] + key = 'rad' if 'rad' in keys else keys[0] + image_np = mat_data[key].astype(np.float32) + + if image_np.shape[-1] == 31 or image_np.shape[-1] > 100: + image_np = np.transpose(image_np, (2, 0, 1)) + + nb_canaux = image_np.shape[0] + + if nb_canaux >= 31: + R, G, B = image_np[27, :, :], image_np[14, :, :], image_np[8, :, :] + else: + R = image_np[int(nb_canaux * 0.9), :, :] + G = image_np[int(nb_canaux * 0.5), :, :] + B = image_np[int(nb_canaux * 0.1), :, :] + + rgb = np.stack([R, G, B], axis=-1) + + rgb_min, rgb_max = rgb.min(), rgb.max() + if rgb_max > rgb_min: + rgb = (rgb - rgb_min) / (rgb_max - rgb_min) + + return np.clip(rgb * 1.5, 0, 1) + +# --------------------------------------------------------- +# 3. INTERFACE GRAPHIQUE INTERACTIVE +# --------------------------------------------------------- +# Création de la figure avec un peu d'espace en bas pour le slider +fig, ax = plt.subplots(figsize=(8, 8)) +plt.subplots_adjust(bottom=0.2) + +# Premier affichage (Image 0) +current_idx = 0 +rgb_initial = load_pseudo_rgb(mat_files[current_idx]) +img_display = ax.imshow(rgb_initial) +ax.set_title(f"[{current_idx + 1}/{len(mat_files)}] - {os.path.basename(mat_files[current_idx])}") +ax.axis('off') + +# Création de l'axe et du Slider en bas de la fenêtre +ax_slider = plt.axes([0.15, 0.05, 0.7, 0.03]) +slider = Slider( + ax=ax_slider, + label='Image n°', + valmin=0, + valmax=len(mat_files) - 1, + valinit=current_idx, + valstep=1 +) + +# Fonction appelée à chaque fois qu'on bouge le curseur +def update(val): + idx = int(slider.val) + try: + # On charge la nouvelle image + new_rgb = load_pseudo_rgb(mat_files[idx]) + # On met à jour l'affichage sans recréer toute la fenêtre + img_display.set_data(new_rgb) + ax.set_title(f"[{idx + 1}/{len(mat_files)}] - {os.path.basename(mat_files[idx])}") + fig.canvas.draw_idle() + except Exception as e: + print(f"Erreur lors du chargement de l'image {idx}: {e}") + +# Lier le slider à la fonction de mise à jour +slider.on_changed(update) + +plt.show() \ No newline at end of file