diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..1b16b82 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,213 @@ +# Contributing to Prompt-CAM + +Thank you for your interest in Prompt-CAM! This guide explains how to use the +repository and how to contribute to it. No advanced Python knowledge is required +to run the visualization tools — we've designed the workflow so that everyone can +explore the model's interpretability features. + +--- + +## Table of Contents + +1. [Using the Tools Without Python Knowledge](#1-using-the-tools-without-python-knowledge) +2. [Environment Setup](#2-environment-setup) +3. [Downloading Checkpoints](#3-downloading-checkpoints) +4. [Running the Web App (Recommended for New Users)](#4-running-the-web-app-recommended-for-new-users) +5. [Running the Demo Notebook](#5-running-the-demo-notebook) +6. [Running Visualization from the Command Line](#6-running-visualization-from-the-command-line) +7. [Training a New Model](#7-training-a-new-model) +8. [Extending the Codebase](#8-extending-the-codebase) +9. [Reporting Issues](#9-reporting-issues) +10. [Code Style](#10-code-style) + +--- + +## 1. Using the Tools Without Python Knowledge + +You do **not** need to write or edit any Python code to use the visualization +tools. 
Choose one of the following entry points: + +| Entry point | Python knowledge needed | What it does | +|---|---|---| +| **Google Colab** [![Colab](https://img.shields.io/badge/Google_Colab-blue)](https://colab.research.google.com/drive/1co1P5LXSVb-g0hqv8Selfjq4WGxSpIFe?usp=sharing) | None (runs in the browser) | Interactive demo in the cloud — no local install needed | +| **`python app.py`** (web app) | None after setup | Point-and-click web interface for visualization | +| **`python visualize.py`** (CLI) | Basic command line | Visualize an entire class from a dataset | +| **`demo.ipynb`** (notebook) | Basic Jupyter skills | Interactive Python notebook | + +For first-time users we recommend starting with the **Google Colab** link or +the **web app** described in [Section 4](#4-running-the-web-app-recommended-for-new-users). + +--- + +## 2. Environment Setup + +Perform these steps once on your local machine: + +```bash +# 1. Create and activate a conda environment +conda create -n prompt_cam python=3.10 +conda activate prompt_cam + +# 2. Install all dependencies +source env_setup.sh +``` + +> **No conda?** You can also create a plain virtual environment: +> ```bash +> python -m venv prompt_cam_env +> source prompt_cam_env/bin/activate # Windows: prompt_cam_env\Scripts\activate +> bash env_setup.sh +> ``` + +--- + +## 3. Downloading Checkpoints + +Pre-trained checkpoints are hosted on Google Drive. Download the checkpoint for +the model/dataset pair you want to use and place it at: + +``` +checkpoints/{model}/{dataset}/model.pt +``` + +For example, the DINO + CUB checkpoint goes in `checkpoints/dino/cub/model.pt`. 
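If you want to confirm your checkpoints are in the right place before launching anything, the layout above can be sanity-checked with a short standalone script. This is only a sketch — `checkpoint_path` and `missing_checkpoints` are illustrative helpers, not functions that ship with the repository:

```python
from pathlib import Path


def checkpoint_path(model: str, dataset: str, root: str = "checkpoints") -> Path:
    """Expected checkpoint location, e.g. checkpoints/dino/cub/model.pt."""
    return Path(root) / model / dataset / "model.pt"


def missing_checkpoints(pairs, root: str = "checkpoints"):
    """Return the (model, dataset) pairs whose checkpoint file is absent."""
    return [(m, d) for m, d in pairs if not checkpoint_path(m, d, root).is_file()]


if __name__ == "__main__":
    wanted = [("dino", "cub"), ("dinov2", "pet")]
    for model, dataset in missing_checkpoints(wanted):
        print(f"missing: {checkpoint_path(model, dataset)}")
```

Run it from the repository root; any pair it prints still needs its `model.pt` downloaded and placed as described above.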
+ +Available checkpoints: + +| Backbone | Dataset | Accuracy (top-1) | Download | +|---|---|---|---| +| dino | CUB-200 | 73.2% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dino | Stanford Cars | 83.2% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dino | Stanford Dogs | 81.1% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dino | Oxford Pets | 91.3% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dino | Birds-525 | 98.8% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dinov2 | CUB-200 | 74.1% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dinov2 | Stanford Dogs | 81.3% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | +| dinov2 | Oxford Pets | 92.7% | [Google Drive](https://drive.google.com/drive/folders/1UmHdGx4OtWCQ1GhHCrBArQeeX14FqwyY?usp=sharing) | + +--- + +## 4. Running the Web App (Recommended for New Users) + +After completing [Section 2](#2-environment-setup) and [Section 3](#3-downloading-checkpoints): + +```bash +conda activate prompt_cam +python app.py +``` + +Open the URL displayed in your terminal (typically `http://localhost:7860`). + +### Step-by-step walkthrough + +1. **Select a Model / Dataset** from the drop-down (e.g., *DINO / CUB-200*). +2. **Upload a checkpoint** — click the file picker and select the `.pt` file you + downloaded in Step 3. +3. **Choose an image** — either upload your own photo or select one of the + built-in sample images (CUB birds) and click *Load sample →*. +4. **Set the target class index** — this is the class whose traits you want to + visualize. For example, in CUB-200, class `97` is *Scott's Oriole*. 
+ - Class numbering starts at **0**. + - The total number of classes per dataset is shown in the drop-down label. +5. **Set the number of traits** (attention heads) to highlight — 3 to 4 is a + good starting point. +6. Click **▶ Run Visualization**. + +The output panel shows the original image alongside the top-ranked trait +heatmaps overlaid on the image. + +--- + +## 5. Running the Demo Notebook + +If you prefer an interactive notebook environment: + +```bash +conda activate prompt_cam +jupyter notebook demo.ipynb +``` + +The notebook walks through: +- Loading a model checkpoint +- Visualizing traits for a single image +- Comparing how different classes "see" the same image +- Trait manipulation examples + +Edit the variables in the **first cell** to point to your chosen model, dataset +config, and checkpoint path. + +--- + +## 6. Running Visualization from the Command Line + +```bash +conda activate prompt_cam +python visualize.py \ + --config ./experiment/config/prompt_cam/dino/cub/args.yaml \ + --checkpoint ./checkpoints/dino/cub/model.pt \ + --vis_cls 23 +``` + +Key arguments: + +| Argument | Default | Description | +|---|---|---| +| `--config` | — | Path to the YAML config file for the model/dataset | +| `--checkpoint` | — | Path to the trained `.pt` checkpoint | +| `--vis_cls` | `23` | Class index to visualize | +| `--top_traits` | `4` | Number of top attention heads to highlight | +| `--nmbr_samples` | `10` | Number of test images to process | +| `--vis_outdir` | `./visualization` | Output directory for saved images | + +Output images are saved to `visualization/{model}/{dataset}/class_{N}/`. + +--- + +## 7. Training a New Model + +Training requires a GPU and the full dataset. See the [README Training +section](README.md#fire-training) for detailed instructions. + +--- + +## 8. Extending the Codebase + +### Adding a new dataset + +1. Prepare your data as described in [Data Preparation](README.md#data-preparation). +2. 
Create a new dataset file in `data/dataset/` modelled on + [`data/dataset/cub.py`](data/dataset/cub.py). +3. Register it in [`experiment/build_loader.py`](experiment/build_loader.py). +4. Create a config file at + `experiment/config/prompt_cam/{model}/{dataset}/args.yaml` (copy and adapt + an existing one). + +### Adding a new backbone + +1. Modify `get_base_model()` in [`experiment/build_model.py`](experiment/build_model.py). +2. Register the architecture in [`model/vision_transformer.py`](model/vision_transformer.py). +3. Add the new `--pretrained_weights` and `--model` choices to `setup_parser()` + in [`main.py`](main.py). + +--- + +## 9. Reporting Issues + +Found a bug or have a question? Please [open an issue](https://github.com/Imageomics/Prompt_CAM/issues/new) and include: + +- Your operating system and Python version (`python --version`) +- The exact command or steps you ran +- The full error message / stack trace (if any) +- Which model/dataset combination you were using + +--- + +## 10. Code Style + +- Python 3.10+ +- Follow [PEP 8](https://peps.python.org/pep-0008/) for formatting. +- Use descriptive variable names; avoid single-letter names outside loop + counters. +- Add docstrings to new public functions and classes. +- Do not commit large binary files (checkpoints, datasets) — these belong on + external storage (Google Drive, Hugging Face Hub, etc.). diff --git a/README.md b/README.md index 66ae672..f26c01f 100644 --- a/README.md +++ b/README.md @@ -29,10 +29,16 @@ Witness the important traits of different class through the lens of Prompt-CAM w 👉 Try our demo locally in [![](https://img.shields.io/badge/notebook-orange )](demo.ipynb) -- Setup the [envoiroment](#environment-setup) +- Setup the [environment](#environment-setup) - download the pre-trained model from link below! - run the demo. +👉 Or launch the interactive **web app** (no Python editing required): +```bash +python app.py +``` +Then open the URL shown in your terminal (e.g. 
`http://localhost:7860`). + 👉 You can extend this code base to include: [New datasets](#to-add-a-new-dataset) and [New backbones](#to-add-a-new-backbone) diff --git a/app.py b/app.py new file mode 100644 index 0000000..0e0ac09 --- /dev/null +++ b/app.py @@ -0,0 +1,321 @@ +""" +Prompt-CAM Interactive Visualization App + +Run this app to explore Prompt-CAM visualizations through a web interface +without writing any Python code: + + python app.py + +Then open the URL shown in your terminal (typically http://localhost:7860). +""" + +import io +import os + +import matplotlib +matplotlib.use("Agg") # Non-interactive backend – must be set before pyplot import +import matplotlib.pyplot as plt + +import gradio as gr +import torch +from dotwiz import DotWiz +from PIL import Image + +from experiment.build_model import get_model +from utils.misc import load_yaml, set_seed +from utils.visual_utils import load_image, prune_and_plot_ranked_heads + +# --------------------------------------------------------------------------- +# Static configuration +# --------------------------------------------------------------------------- + +# Map human-readable labels to YAML config paths +CONFIGS = { + "DINO / CUB-200 (200 bird species)": "experiment/config/prompt_cam/dino/cub/args.yaml", + "DINO / Stanford Cars (196 classes)": "experiment/config/prompt_cam/dino/car/args.yaml", + "DINO / Stanford Dogs (120 breeds)": "experiment/config/prompt_cam/dino/dog/args.yaml", + "DINO / Oxford Pets (37 breeds)": "experiment/config/prompt_cam/dino/pet/args.yaml", + "DINO / Birds-525 (525 species)": "experiment/config/prompt_cam/dino/birds_525/args.yaml", + "DINOv2 / CUB-200 (200 bird species)": "experiment/config/prompt_cam/dinov2/cub/args.yaml", + "DINOv2 / Stanford Dogs (120 breeds)": "experiment/config/prompt_cam/dinov2/dog/args.yaml", + "DINOv2 / Oxford Pets (37 breeds)": "experiment/config/prompt_cam/dinov2/pet/args.yaml", +} + +# Dataset class counts (for informational display) 
+DATASET_CLASS_COUNTS = { + "cub": 200, + "car": 196, + "dog": 120, + "pet": 37, + "birds_525": 525, +} + +# Sample images bundled with the repository (CUB dataset, DINO backbone) +SAMPLE_IMAGES = { + "Scott's Oriole (class 97)": ("samples/Scott_Oriole.jpg", 97), + "Baltimore Oriole (class 94)": ("samples/Baltimore_Oriole.jpg", 94), + "Orchard Oriole (class 96)": ("samples/Orchard_Oriole.jpg", 96), + "Rusty Blackbird (class 10)": ("samples/rusty_Blackbird.jpg", 10), + "Brewer's Blackbird (class 11)": ("samples/Brewer_Blackbird.jpg", 11), + "Yellow-headed Blackbird (class 195)": ("samples/yellow_headed_blackbird.jpg", 195), +} + +# --------------------------------------------------------------------------- +# Core visualization function +# --------------------------------------------------------------------------- + +def visualize_traits( + config_label: str, + checkpoint_path: str, + image_input, + target_class: int, + top_traits: int, +): + """Load a checkpoint, run Prompt-CAM on *image_input*, and return the plot. + + Parameters + ---------- + config_label: + Human-readable label from the CONFIGS dictionary. + checkpoint_path: + Path (str) to a saved ``.pt`` checkpoint produced by ``main.py``. + image_input: + Either a PIL Image (from the upload widget) or a file path string. + target_class: + Class index whose prompts should be visualised. + top_traits: + Number of top attention heads (traits) to highlight. + + Returns + ------- + result_image : PIL.Image | None + The visualisation plot, or None on failure. + status_message : str + A short success / error message to display in the UI. + """ + if checkpoint_path is None: + return None, "❌ Please upload a model checkpoint (.pt file)." + if image_input is None: + return None, "❌ Please upload an image or select a sample image." + if config_label not in CONFIGS: + return None, "❌ Invalid model/dataset selection." 
+ + config_path = CONFIGS[config_label] + if not os.path.exists(config_path): + return None, f"❌ Config file not found: {config_path}" + + try: + # Build args from the YAML config + yaml_config = load_yaml(config_path) + args = DotWiz(yaml_config) + + args.checkpoint = checkpoint_path + args.vis_cls = int(target_class) + args.top_traits = int(top_traits) + args.test_batch_size = 1 + args.random_seed = 42 + set_seed(args.random_seed) + + # Load the model (visualize=True skips loading raw pre-trained weights) + model, _, _ = get_model(args, visualize=True) + state = torch.load(args.checkpoint, map_location=args.device) + model.load_state_dict(state["model_state_dict"]) + model.eval() + + # Prepare the input image + if isinstance(image_input, str): + # File path (e.g. from sample image selection) + sample = load_image(image_input) + else: + # PIL Image from the upload widget – save temporarily + tmp_path = "/tmp/_prompt_cam_input.jpg" + image_input.save(tmp_path) + sample = load_image(tmp_path) + sample = sample.to(args.device, non_blocking=True) + + # Run Prompt-CAM and capture the matplotlib figure + plt.close("all") + with torch.no_grad(): + prune_and_plot_ranked_heads(model, sample, int(target_class), args) + fig = plt.gcf() + + buf = io.BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight", dpi=150) + buf.seek(0) + result_image = Image.open(buf).copy() + plt.close("all") + + return result_image, "✅ Visualisation complete!" + + except FileNotFoundError as exc: + return None, f"❌ File not found: {exc}" + except KeyError as exc: + return None, ( + f"❌ Checkpoint is missing key {exc}. " + "Make sure you are using a checkpoint saved by this codebase." 
+ ) + except Exception as exc: # noqa: BLE001 + return None, f"❌ Unexpected error: {exc}" + + +# --------------------------------------------------------------------------- +# Gradio interface helpers +# --------------------------------------------------------------------------- + +def load_sample(sample_label: str): + """Return the image path and suggested class index for a sample image.""" + if sample_label in SAMPLE_IMAGES: + path, cls = SAMPLE_IMAGES[sample_label] + if os.path.exists(path): + return Image.open(path), cls + return None, 0 + + +# --------------------------------------------------------------------------- +# Build the Gradio UI +# --------------------------------------------------------------------------- + +def build_interface(): + with gr.Blocks(title="Prompt-CAM Visualizer") as demo: + gr.Markdown( + """ +# 🔍 Prompt-CAM Interactive Visualizer + +**Prompt-CAM** makes Vision Transformers interpretable for fine-grained analysis. +This app lets you explore *which traits* the model focuses on for any class — +no Python knowledge required! + +### Quick start +1. Select a **Model / Dataset** combination. +2. Upload the matching **checkpoint** (`.pt` file downloaded from the links in the README). +3. Upload an **image** or pick one of the built-in sample images. +4. Set the **target class index** and **number of traits** to show. +5. Click **▶ Run Visualization**. 
+""" + ) + + with gr.Row(): + # ---- Left column: inputs ---------------------------------------- + with gr.Column(scale=1): + gr.Markdown("### ⚙️ Model & Checkpoint") + config_dropdown = gr.Dropdown( + choices=list(CONFIGS.keys()), + value="DINO / CUB-200 (200 bird species)", + label="Model / Dataset", + info="Choose the backbone and dataset that matches your checkpoint.", + ) + checkpoint_upload = gr.File( + label="Checkpoint file (.pt)", + file_types=[".pt"], + type="filepath", + ) + + gr.Markdown("### 🖼️ Input Image") + with gr.Tab("Upload your own image"): + image_upload = gr.Image( + label="Upload image", + type="pil", + ) + with gr.Tab("Use a sample image"): + sample_dropdown = gr.Dropdown( + choices=list(SAMPLE_IMAGES.keys()), + value=list(SAMPLE_IMAGES.keys())[0], + label="Sample image (CUB dataset)", + ) + sample_preview = gr.Image( + label="Preview", + type="pil", + interactive=False, + ) + load_sample_btn = gr.Button("Load sample →", size="sm") + + gr.Markdown("### 🎛️ Visualisation Parameters") + target_class = gr.Number( + label="Target class index", + value=97, + precision=0, + info=( + "Zero-based class index to visualise. " + "For CUB-200: 0–199, Stanford Cars: 0–195, etc." 
+ ), + ) + top_traits = gr.Slider( + minimum=1, + maximum=12, + step=1, + value=3, + label="Top traits to show", + info="Number of most important attention heads to highlight.", + ) + + run_btn = gr.Button("▶ Run Visualization", variant="primary") + + # ---- Right column: outputs -------------------------------------- + with gr.Column(scale=1): + gr.Markdown("### 📊 Results") + output_image = gr.Image( + label="Trait visualisation", + type="pil", + interactive=False, + ) + status_text = gr.Textbox( + label="Status", + interactive=False, + lines=2, + ) + + # ---- Wire up callbacks ------------------------------------------ + load_sample_btn.click( + fn=load_sample, + inputs=[sample_dropdown], + outputs=[sample_preview, target_class], + ) + + # Visualisation can use either the uploaded image or the sample preview + def run_with_either_image( + config_label, checkpoint_path, uploaded_img, sample_img, target_cls, top_k + ): + image = uploaded_img if uploaded_img is not None else sample_img + return visualize_traits(config_label, checkpoint_path, image, target_cls, top_k) + + run_btn.click( + fn=run_with_either_image, + inputs=[ + config_dropdown, + checkpoint_upload, + image_upload, + sample_preview, + target_class, + top_traits, + ], + outputs=[output_image, status_text], + ) + + gr.Markdown( + """ +--- +### 📖 Tips + +- **Class index** – class numbering starts at **0**. For CUB-200, class 0 is + *001.Black_footed_Albatross*, class 97 is *Scott's Oriole*, etc. +- **Checkpoint** – download checkpoints from the Google Drive links in the + [README](README.md) and place them in `checkpoints/{model}/{dataset}/`. +- **GPU vs CPU** – a GPU speeds up inference considerably, but the app works on + CPU as well (expect slower runtimes). +- **Sample images** – the bundled samples are from the CUB-200 validation set + and work best with a CUB checkpoint. + +See the [README](README.md) for full documentation and training instructions. 
+""" + ) + + return demo + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + interface = build_interface() + interface.launch(share=False) diff --git a/env_setup.sh b/env_setup.sh index 742458a..b28f9f5 100644 --- a/env_setup.sh +++ b/env_setup.sh @@ -28,3 +28,4 @@ pip install ftfy regex tqdm --no-cache-dir pip install pandas --no-cache-dir pip install matplotlib --no-cache-dir pip install ipykernel --no-cache-dir +pip install gradio --no-cache-dir diff --git a/experiment/build_model.py b/experiment/build_model.py index 3dca657..e57b187 100644 --- a/experiment/build_model.py +++ b/experiment/build_model.py @@ -37,7 +37,7 @@ def get_model(params,visualize=False): model_grad_params_no_head = log_model_info(model, logger) - model = model.cuda(device=params.device) + model = model.to(params.device) return model, tune_parameters, model_grad_params_no_head