"""Source code for atomai.models.sam: SAM-based particle segmentation and analysis."""

import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import urllib.request


class ParticleAnalyzer:
    """
    A class to encapsulate an end-to-end particle segmentation and analysis
    workflow using the Segment Anything Model (SAM).

    This class handles:
    - Automatic downloading of SAM model checkpoints.
    - Image pre-processing, including normalization and optional contrast enhancement.
    - Running SAM with preset or custom parameters.
    - Advanced post-processing to filter masks by area and shape, and to remove
      duplicates (via bounding-box IoU).
    - Extraction of detailed properties for each detected particle.
    - Conversion of results to a pandas DataFrame and visualization of results.

    Example:
        >>> # 1. Initialize the analyzer (downloads model if needed)
        >>> analyzer = ParticleAnalyzer(model_type="vit_h")
        >>>
        >>> # 2. Load image and run the analysis
        >>> image = np.load(IMAGE_PATH)
        >>> result = analyzer.analyze(image)
        >>>
        >>> # 3. Print summary and visualize results
        >>> print(f"Found {result['total_count']} particles.")
        >>> df = ParticleAnalyzer.particles_to_dataframe(result)
        >>> print(df.head())
        >>>
        >>> # This will generate and show a side-by-side plot
        >>> ParticleAnalyzer.visualize_particles(
        ...     result,
        ...     original_image_for_plot=image,
        ...     show_plot=True
        ... )
    """

    # Official checkpoint URLs for each supported SAM variant.
    _MODEL_URLS = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }

    # Named keyword-argument presets passed to SamAutomaticMaskGenerator.
    _SAM_PRESETS = {
        "default": {},
        "sensitive": {
            "points_per_side": 96,
            "pred_iou_thresh": 0.80,
            "stability_score_thresh": 0.85,
        },
        "ultra-permissive": {
            "points_per_side": 96,
            "pred_iou_thresh": 0.60,
            "stability_score_thresh": 0.70,
        },
    }

    # Shared message for a missing optional dependency (used by two methods).
    _SAM_INSTALL_HINT = (
        "The 'segment-anything' package is required to use this feature.\n"
        "Please install it directly from the official repository:\n\n"
        "pip install git+https://github.com/facebookresearch/segment-anything.git"
    )

    def __init__(self, checkpoint_path=None, model_type="vit_h", device="auto"):
        """
        Initializes the ParticleAnalyzer by loading the SAM model.
        If the model checkpoint is not found, it will be downloaded automatically.

        Args:
            checkpoint_path (str, optional): Path to the SAM model checkpoint file.
                If None, ``./checkpoints/sam_<model_type>.pth`` is used.
            model_type (str): The type of SAM model ("vit_h", "vit_l", "vit_b").
            device (str): The device to run the model on ("auto", "cuda", "cpu").
        """
        print("Initializing Particle Analyzer...")
        self.device = self._get_device(device)
        # Determine the final checkpoint path and download if necessary
        final_checkpoint_path = self._download_model_if_needed(checkpoint_path, model_type)
        self.sam_model = self._load_model(final_checkpoint_path, model_type)
        print(f"SAM model loaded successfully on device: {self.device}")

    def _get_device(self, device):
        """Return "cuda"/"cpu" for ``device="auto"``, otherwise ``device`` unchanged."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _download_model_if_needed(self, checkpoint_path, model_type):
        """Return a path to an existing checkpoint, downloading it if absent.

        Args:
            checkpoint_path (str or None): Desired checkpoint location. If None,
                ``./checkpoints/sam_<model_type>.pth`` is used (the directory is created).
            model_type (str): SAM variant; must be a key of ``_MODEL_URLS`` when a
                download is required.

        Returns:
            str: Path to the checkpoint file.

        Raises:
            ValueError: If the file is missing and ``model_type`` has no known URL.
        """
        if checkpoint_path is None:
            # Create a default path if none is provided
            checkpoint_dir = "./checkpoints"
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"sam_{model_type}.pth")

        if os.path.exists(checkpoint_path):
            return checkpoint_path

        url = self._MODEL_URLS.get(model_type)
        if url is None:
            raise ValueError(f"Unknown model type: '{model_type}'. Cannot download.")

        print(f"SAM model checkpoint not found at '{checkpoint_path}'.")
        print(f"Downloading model for '{model_type}' from {url}...")
        # A user-supplied path may point into a directory that does not exist yet.
        parent_dir = os.path.dirname(checkpoint_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file at checkpoint_path
        # (the existence check above would otherwise accept it on the next run).
        tmp_path = checkpoint_path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise
        print(f"Download complete. Model saved to '{checkpoint_path}'.")
        return checkpoint_path

    def _load_model(self, checkpoint_path, model_type):
        """Load the SAM model from ``checkpoint_path`` and move it to ``self.device``.

        Raises:
            ImportError: If the ``segment_anything`` package is not installed.
            Exception: Any error raised while building/loading the model is
                printed and re-raised.
        """
        try:
            from segment_anything import sam_model_registry
        except ImportError as err:
            raise ImportError(self._SAM_INSTALL_HINT) from err
        try:
            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam.to(device=self.device)
            return sam
        except Exception as e:
            print(f"Error loading SAM model from '{checkpoint_path}': {e}")
            raise

    def analyze(self, image_array, params=None):
        """
        Runs the full analysis pipeline on a given image using a set of parameters.

        Args:
            image_array (np.array): The input 2D grayscale image.
            params (dict, optional): Parameters controlling the analysis. Recognized keys:
                "use_clahe" (bool), "sam_parameters" (preset name or dict of
                SamAutomaticMaskGenerator kwargs), "min_area", "max_area",
                "use_pruning" (bool), "pruning_iou_threshold" (float).
                If None, a default set is used (min_area=500, max_area=50000, no CLAHE,
                no pruning). Keys missing from a user-supplied dict fall back to the
                per-step defaults (no CLAHE, "default" preset, area 0..inf, no pruning,
                IoU threshold 0.5).

        Returns:
            dict: ``'particles'`` (list of property dicts, largest area first, ids 1..n),
            ``'original_image'`` (the pre-processed uint8 image), ``'rgb_image'``
            (its 3-channel version) and ``'total_count'``.
        """
        # If no parameters are provided, use a default set for baseline analysis.
        if params is None:
            print("No parameters provided. Using default analysis settings.")
            params = {
                "use_clahe": False,
                "sam_parameters": "default",
                "min_area": 500,
                "max_area": 50000,
                "use_pruning": False,
                "pruning_iou_threshold": 0.5
            }

        # 1. Pre-process the image
        processed_image = self._preprocess_image(image_array, params.get("use_clahe", False))
        image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB)

        # 2. Generate masks with SAM using specified parameters
        all_masks = self._run_sam(image_rgb, params.get("sam_parameters", "default"))
        print(f"Generated {len(all_masks)} raw masks.")

        # 3. Filter and prune masks
        final_masks_info = self._filter_and_prune(all_masks, params)
        print(f"Kept {len(final_masks_info)} masks after filtering and pruning.")

        # 4. Extract properties from final masks
        particles = [
            self._extract_particle_properties(mask, processed_image, i + 1)
            for i, mask in enumerate(final_masks_info)
        ]

        # Sort by area for consistent ordering, then reassign IDs in that order
        particles = sorted(particles, key=lambda x: x['area'], reverse=True)
        for i, particle in enumerate(particles):
            particle['id'] = i + 1

        return {
            'particles': particles,
            'original_image': processed_image,
            'rgb_image': image_rgb,
            'total_count': len(particles)
        }

    def _preprocess_image(self, image_array, use_clahe):
        """Convert the image to uint8 and optionally apply CLAHE.

        uint8 input is passed through. Other dtypes whose values all lie in [0, 1]
        are scaled by 255 (truncated); everything else is min-max normalized to
        0..255, and a constant image becomes all zeros.
        """
        # Normalize to uint8
        if image_array.dtype != np.uint8:
            if image_array.max() <= 1.0 and image_array.min() >= 0.0:
                image_array = (image_array * 255).astype(np.uint8)
            else:
                min_val, max_val = image_array.min(), image_array.max()
                if max_val > min_val:
                    image_array = ((image_array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                else:
                    image_array = np.zeros_like(image_array, dtype=np.uint8)

        # Apply CLAHE if requested
        if use_clahe:
            print("Applying CLAHE...")
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            image_array = clahe.apply(image_array)

        return image_array

    def _run_sam(self, image_rgb, preset_name):
        """Run SamAutomaticMaskGenerator on ``image_rgb`` and return its masks.

        Args:
            image_rgb (np.array): 3-channel image passed to SAM.
            preset_name (str or dict): A key of ``_SAM_PRESETS``, or a dict of
                custom SamAutomaticMaskGenerator keyword arguments. Unknown preset
                names fall back to SAM's default settings.
        """
        try:
            from segment_anything import SamAutomaticMaskGenerator
        except ImportError as err:
            raise ImportError(self._SAM_INSTALL_HINT) from err

        if isinstance(preset_name, dict):
            # Custom generator kwargs; previously a dict crashed the preset lookup.
            sam_params = dict(preset_name)
            print(f"Running SAM with custom parameters: {sam_params}")
        else:
            if preset_name not in self._SAM_PRESETS:
                print(f"Unknown SAM preset '{preset_name}'; using default SAM settings.")
            sam_params = dict(self._SAM_PRESETS.get(preset_name, {}))
            print(f"Running SAM with preset: '{preset_name}'")

        mask_generator = SamAutomaticMaskGenerator(self.sam_model, **sam_params)
        return mask_generator.generate(image_rgb)

    def _filter_and_prune(self, masks, params):
        """Keep non-empty masks within [min_area, max_area]; optionally prune duplicates."""
        min_area = params.get("min_area", 0)
        max_area = params.get("max_area", float('inf'))

        # Empty masks have no pixels; keeping them (possible when min_area is 0)
        # would crash property extraction on np.min of an empty array.
        area_filtered_masks = [
            m for m in masks if m['area'] > 0 and min_area <= m['area'] <= max_area
        ]

        if params.get("use_pruning", False):
            print("Applying shape-based pruning...")
            iou_threshold = params.get("pruning_iou_threshold", 0.5)
            return self._prune_by_shape_and_iou(area_filtered_masks, iou_threshold)
        return area_filtered_masks

    def _extract_particle_properties(self, mask, image, particle_id):
        """Build the property dict for a single SAM mask.

        ``mask['bbox']`` is read as (x, y, width, height), so aspect_ratio is
        height / width (1 when the width is 0).
        """
        # Force a boolean mask so image[binary_mask] is boolean (not fancy) indexing.
        binary_mask = np.asarray(mask['segmentation'], dtype=bool)
        area = mask['area']
        bbox = mask['bbox']

        y_coords, x_coords = np.where(binary_mask)
        centroid = (np.mean(x_coords), np.mean(y_coords))
        particle_pixels = image[binary_mask]
        perimeter = self._calculate_perimeter(binary_mask)

        # Reuse the solidity computed during pruning when present. dict.get with a
        # default would evaluate the expensive fallback even when it is not needed.
        solidity = mask.get('solidity')
        if solidity is None:
            solidity = self._calculate_solidity(mask)

        return {
            'id': particle_id,
            'area': area,
            'centroid': centroid,
            'bbox': bbox,
            'mean_intensity': np.mean(particle_pixels),
            'std_intensity': np.std(particle_pixels),
            'min_intensity': np.min(particle_pixels),
            'max_intensity': np.max(particle_pixels),
            'perimeter': perimeter,
            'circularity': 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0,
            'equiv_diameter': 2 * np.sqrt(area / np.pi),
            'aspect_ratio': bbox[3] / bbox[2] if bbox[2] > 0 else 1,
            'solidity': solidity,
            'mask': binary_mask
        }

    @staticmethod
    def _find_external_contours(binary_mask):
        """Return the outer contours of a binary mask (OpenCV 3.x and 4.x compatible)."""
        found = cv2.findContours(np.asarray(binary_mask).astype(np.uint8),
                                 cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
        # (image, contours, hierarchy). Index -2 is the contours in both.
        return found[-2]

    def _calculate_perimeter(self, binary_mask):
        """Total length of all outer contours (0 for an empty mask)."""
        contours = self._find_external_contours(binary_mask)
        # Summing over every component, rather than an arbitrary contours[0],
        # keeps the perimeter consistent with the full mask area used in circularity.
        return sum(cv2.arcLength(contour, True) for contour in contours)

    def _calculate_solidity(self, mask):
        """Contour area divided by the area of the convex hull of all outer contours."""
        contours = self._find_external_contours(mask['segmentation'])
        if not contours:
            return 0
        area = sum(cv2.contourArea(contour) for contour in contours)
        # One hull over all components (instead of only contours[0]), so
        # fragmented masks get a correspondingly low solidity.
        hull_area = cv2.contourArea(cv2.convexHull(np.vstack(contours)))
        return area / hull_area if hull_area > 0 else 0

    def _calculate_iou(self, mask1, mask2):
        """IoU of the two masks' bounding boxes (x, y, width, height), not of their pixels."""
        bbox1, bbox2 = mask1['bbox'], mask2['bbox']
        x_left = max(bbox1[0], bbox2[0])
        y_top = max(bbox1[1], bbox2[1])
        x_right = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2])
        y_bottom = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3])

        if x_right < x_left or y_bottom < y_top:
            return 0.0

        intersection_area = (x_right - x_left) * (y_bottom - y_top)
        area1, area2 = bbox1[2] * bbox1[3], bbox2[2] * bbox2[3]
        union_area = area1 + area2 - intersection_area
        return intersection_area / union_area if union_area > 0 else 0.0

    def _prune_by_shape_and_iou(self, masks, iou_threshold):
        """Greedy de-duplication: keep masks in order of area * solidity**2 and drop
        any whose bbox IoU with an already-kept mask exceeds ``iou_threshold``.

        Note: adds 'solidity' and 'score' keys to the input mask dicts in place.
        """
        if not masks:
            return []

        for m in masks:
            m['solidity'] = self._calculate_solidity(m)
            m['score'] = m['area'] * (m['solidity'] ** 2)

        sorted_masks = sorted(masks, key=lambda x: x['score'], reverse=True)
        pruned_masks = []
        for mask in sorted_masks:
            is_duplicate = any(self._calculate_iou(mask, kept_mask) > iou_threshold
                               for kept_mask in pruned_masks)
            if not is_duplicate:
                pruned_masks.append(mask)
        return pruned_masks

    @staticmethod
    def particles_to_dataframe(result):
        """Converts the 'particles' list from the result into a pandas DataFrame.

        The 'mask' array is dropped; 'centroid' and 'bbox' are expanded into
        centroid_x/centroid_y and bbox_x/bbox_y/bbox_width/bbox_height columns.
        """
        particles = result.get('particles', [])
        if not particles:
            return pd.DataFrame()

        data = []
        for p in particles:
            row = {k: v for k, v in p.items() if k != 'mask'}
            row['centroid_x'], row['centroid_y'] = p['centroid']
            row['bbox_x'], row['bbox_y'], row['bbox_width'], row['bbox_height'] = p['bbox']
            del row['centroid'], row['bbox']
            data.append(row)
        return pd.DataFrame(data)

    @staticmethod
    def visualize_particles(result, original_image_for_plot=None, show_plot=False,
                            show_labels=True, show_centroids=True):
        """
        Creates an RGB image visualizing the detected particles and optionally displays a plot.

        Args:
            result (dict): The output dictionary from the analyze method.
            original_image_for_plot (np.array, optional): The raw, unprocessed image for
                side-by-side comparison. If None, the processed image from the result is used.
            show_plot (bool): If True, displays a matplotlib plot comparing original and
                segmented images.
            show_labels (bool): If True, shows particle ID labels on the overlay.
            show_centroids (bool): If True, shows particle centroids on the overlay.

        Returns:
            np.array: The RGB overlay image with particles drawn on it.
        """
        overlay = result['rgb_image'].copy()

        for particle in result.get('particles', []):
            contours = ParticleAnalyzer._find_external_contours(particle['mask'])
            cv2.drawContours(overlay, contours, -1, (255, 0, 0), 2)

            cx, cy = int(particle['centroid'][0]), int(particle['centroid'][1])
            if show_centroids:
                cv2.circle(overlay, (cx, cy), 5, (0, 255, 0), -1)
            if show_labels:
                cv2.putText(overlay, str(particle['id']), (cx + 5, cy + 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)

        if show_plot:
            fig, axes = plt.subplots(1, 2, figsize=(16, 8))
            # Use the provided original image for the 'before' plot, otherwise use
            # the processed one from results
            display_image = (original_image_for_plot if original_image_for_plot is not None
                             else result['original_image'])
            axes[0].imshow(display_image, cmap='gray')
            axes[0].set_title('Original Input')
            axes[1].imshow(overlay)
            axes[1].set_title(f"Detected Particles (n={result['total_count']})")
            for ax in axes:
                ax.set_axis_off()
            plt.tight_layout()
            plt.show()

        return overlay