feat: add visualization functions and integrate figure logging#16
Conversation
📝 WalkthroughWalkthroughThe PR adds comprehensive visualization functionality for CIAO explanation results and integrates it with MLflow logging. A new visualization module provides three plotting functions to display results, while the explanation result now exposes the replacement image tensor. The main CLI conditionally generates and logs these figures to MLflow when enabled via configuration. Changes
Sequence DiagramsequenceDiagram
participant Main
participant Explainer
participant Visualization
participant MLflow
Main->>Explainer: explain(image, ...)
Explainer->>Explainer: compute replacement_image
Explainer-->>Main: ExplanationResult (with replacement_image)
alt log_figures enabled
Main->>Visualization: plot_overview(result)
Visualization-->>Main: Figure
Main->>MLflow: log_figure(fig, 'figures/overview.png')
Main->>Visualization: plot_regions(result)
Visualization-->>Main: Figure
Main->>MLflow: log_figure(fig, 'figures/regions.png')
Main->>Visualization: plot_region_scores(result)
Visualization-->>Main: Figure
Main->>MLflow: log_figure(fig, 'figures/region_scores.png')
Main->>Main: close all figures
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~28 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces visualization capabilities for CIAO explanation results, adding a new module to generate overview plots, region-specific images, and score heatmaps. These visualizations are integrated into the main execution flow with support for MLflow logging. Additionally, the PR centralizes ImageNet normalization constants and updates the explanation result structure to include replacement images. The review feedback suggests enhancing the robustness of tensor-to-numpy conversions by using .detach() and recommends implementing a grid layout for subplots to handle cases with a large number of regions more effectively.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (4)
ciao/data/replacement.py (1)
22-23: Consider matching dtype of the input tensor.
torch.tensor(IMAGENET_MEAN, device=device)will default tofloat32. Ifinput_tensor/imageis everfloat16/bfloat16(mixed precision), the arithmetic at lines 26/138 will upcast or error. Minor/defensive — passdtype=input_tensor.dtype(resp.image.dtype) to be safe and consistent with howcolor_tensoris built on line 129.♻️ Proposed fix
- imagenet_mean = torch.tensor(IMAGENET_MEAN, device=device).view(3, 1, 1) - imagenet_std = torch.tensor(IMAGENET_STD, device=device).view(3, 1, 1) + imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=input_tensor.dtype, device=device).view(3, 1, 1) + imagenet_std = torch.tensor(IMAGENET_STD, dtype=input_tensor.dtype, device=device).view(3, 1, 1)And similarly in
make_solid_color_replacement:- imagenet_mean = torch.tensor(IMAGENET_MEAN, device=image.device).view(3, 1, 1) - imagenet_std = torch.tensor(IMAGENET_STD, device=image.device).view(3, 1, 1) + imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=image.dtype, device=image.device).view(3, 1, 1) + imagenet_std = torch.tensor(IMAGENET_STD, dtype=image.dtype, device=image.device).view(3, 1, 1)Also applies to: 135-136
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/data/replacement.py` around lines 22 - 23, The ImageNet mean/std tensors are created without matching the input dtype which can cause upcasts or errors under mixed precision; update the creation of imagenet_mean and imagenet_std to pass dtype=input_tensor.dtype (and for the other occurrence in make_solid_color_replacement use dtype=image.dtype) in addition to device so they match the tensor used for arithmetic (refer to the imagenet_mean/imagenet_std variables and the make_solid_color_replacement function and the color_tensor built near where color_tensor is created).ciao/__main__.py (1)
129-137: Use.get()for backward-compat onlog_figures.Accessing
cfg.logger.log_figuresas an attribute will raise on existing/user configs that predate this flag (e.g., overrides or custom logger configs that don't set it). Prefer a defaulted lookup so this remains opt-in and non-breaking.♻️ Proposed fix
- if cfg.logger.log_figures and results.regions: + if cfg.logger.get("log_figures", False) and results.regions:Also note: the
results.regionsguard here is what preventsplot_regions/plot_region_scoresfrom crashing on empty regions (see comment invisualization.py). Keep them in lockstep.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/__main__.py` around lines 129 - 137, Replace the direct attribute access cfg.logger.log_figures with a safe, defaulted lookup (e.g., cfg.logger.get("log_figures", False)) so missing/older logger configs don't raise; leave the existing results.regions guard in place so plot_overview, plot_regions and plot_region_scores are only called when regions exist, and continue to call mlflow.log_figure(...) and plt.close(fig) as before.ciao/visualization/visualization.py (2)
83-99: Guard against emptyresult.regionswhen called directly.
plot_regionsandplot_region_scoreswill fail if invoked with an empty regions list:plt.subplots(1, 0, ...)raises, andmax(abs(s) for s in all_scores)at line 112 raisesValueErroron an empty iterable. The__main__.pycaller currently guards this, but as a public API inciao.visualizationit's worth an early-return or explicitValueErrorso direct users get a clear contract.♻️ Proposed fix (example)
def plot_regions(result: ExplanationResult) -> Figure: """One subplot per region: region pixels replaced, rest is original.""" + if not result.regions: + raise ValueError("plot_regions requires at least one region") img = _to_hwc(result.input_batch)def plot_region_scores(result: ExplanationResult) -> Figure: ... + if not result.regions: + raise ValueError("plot_region_scores requires at least one region") img = _to_hwc(result.input_batch)Also applies to: 102-128
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/visualization/visualization.py` around lines 83 - 99, Both plot_regions and plot_region_scores must guard against an empty result.regions list; add an explicit early check at the top of each function (e.g., if not result.regions:) and raise a ValueError with a clear message like "result.regions is empty" (so callers get a clear contract) before calling plt.subplots or computing max(abs(s) for s in all_scores); reference plot_regions, plot_region_scores, result.regions, the plt.subplots(1, n, ...) call, and the max(abs(s) for s in all_scores) expression when applying the fix.
37-42: Optional: vectorize segment→value lookups.Both
_region_maskand the score-map build at lines 55–57 loop over segment IDs in Python. For larger segmentations these can be vectorized withnp.isin/ fancy indexing, which is both shorter and faster:♻️ Example
# _region_mask return np.isin(segments, np.fromiter(region, dtype=segments.dtype)) # score_map ids = np.fromiter(result.segment_scores.keys(), dtype=segs.dtype) vals = np.fromiter(result.segment_scores.values(), dtype=np.float32) lut = np.zeros(int(segs.max()) + 1, dtype=np.float32) lut[ids] = vals score_map = lut[segs]Not a correctness concern — leave as-is if current performance is fine.
Also applies to: 55-57
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/visualization/visualization.py` around lines 37 - 42, The loops over segment IDs should be vectorized for performance: replace the Python loop in _region_mask to use numpy membership testing (np.isin) on segments/region, and replace the per-segment Python loop that builds score_map from result.segment_scores by creating an index array from the keys, a values array from the values, build a LUT (zero-initialized with length = segs.max()+1), assign LUT[ids] = vals and then index into segs to produce score_map; refer to the functions/variables _region_mask, segments/segs, result.segment_scores, ids/vals/LUT and score_map when making the changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ciao/visualization/visualization.py`:
- Around line 45-46: The docstring for plot_overview is outdated: it describes a
3-panel layout but the function now creates four panels (original | segmentation
| segment-score heatmap | replacement). Update the docstring of
plot_overview(ExplanationResult) -> Figure to accurately list all four panels in
the correct order (original, segmentation, score heatmap, replacement) and
briefly state that the figure returns a 4-panel side-by-side comparison.
---
Nitpick comments:
In `@ciao/__main__.py`:
- Around line 129-137: Replace the direct attribute access
cfg.logger.log_figures with a safe, defaulted lookup (e.g.,
cfg.logger.get("log_figures", False)) so missing/older logger configs don't
raise; leave the existing results.regions guard in place so plot_overview,
plot_regions and plot_region_scores are only called when regions exist, and
continue to call mlflow.log_figure(...) and plt.close(fig) as before.
In `@ciao/data/replacement.py`:
- Around line 22-23: The ImageNet mean/std tensors are created without matching
the input dtype which can cause upcasts or errors under mixed precision; update
the creation of imagenet_mean and imagenet_std to pass dtype=input_tensor.dtype
(and for the other occurrence in make_solid_color_replacement use
dtype=image.dtype) in addition to device so they match the tensor used for
arithmetic (refer to the imagenet_mean/imagenet_std variables and the
make_solid_color_replacement function and the color_tensor built near where
color_tensor is created).
In `@ciao/visualization/visualization.py`:
- Around line 83-99: Both plot_regions and plot_region_scores must guard against
an empty result.regions list; add an explicit early check at the top of each
function (e.g., if not result.regions:) and raise a ValueError with a clear
message like "result.regions is empty" (so callers get a clear contract) before
calling plt.subplots or computing max(abs(s) for s in all_scores); reference
plot_regions, plot_region_scores, result.regions, the plt.subplots(1, n, ...)
call, and the max(abs(s) for s in all_scores) expression when applying the fix.
- Around line 37-42: The loops over segment IDs should be vectorized for
performance: replace the Python loop in _region_mask to use numpy membership
testing (np.isin) on segments/region, and replace the per-segment Python loop
that builds score_map from result.segment_scores by creating an index array from
the keys, a values array from the values, build a LUT (zero-initialized with
length = segs.max()+1), assign LUT[ids] = vals and then index into segs to
produce score_map; refer to the functions/variables _region_mask, segments/segs,
result.segment_scores, ids/vals/LUT and score_map when making the changes.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8069b51f-69d5-439e-b34c-2cbe8ccb8e0d
📒 Files selected for processing (7)
ciao/__main__.pyciao/data/preprocessing.pyciao/data/replacement.pyciao/explainer/ciao_explainer.pyciao/visualization/__init__.pyciao/visualization/visualization.pyconfigs/logger/mlflow.yaml
There was a problem hiding this comment.
Pull request overview
Adds a visualization suite for CIAO explanation results and hooks it into the main run loop to optionally log generated figures to MLflow.
Changes:
- Introduces
ciao.visualizationplotting helpers (plot_overview,plot_regions,plot_region_scores) built onmatplotlib. - Extends
ExplanationResultto includereplacement_image, enabling “masked/replaced” visual renderings. - Adds an MLflow config flag (
log_figures) and logs generated figures as MLflow artifacts when enabled.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| configs/logger/mlflow.yaml | Adds log_figures toggle to control artifact generation. |
| ciao/visualization/visualization.py | Implements plotting utilities for overview, regions, and region score tinting. |
| ciao/visualization/init.py | Exposes visualization functions as public package API. |
| ciao/explainer/ciao_explainer.py | Adds replacement_image to ExplanationResult and populates it in explain(). |
| ciao/data/replacement.py | Centralizes ImageNet normalization constants usage via preprocessing module. |
| ciao/data/preprocessing.py | Defines shared IMAGENET_MEAN/STD constants and reuses them in transforms. |
| ciao/main.py | Integrates figure generation and MLflow logging, ensuring figures are closed. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
c525361 to
f848a80
Compare
f848a80 to
7c46f49
Compare
7c46f49 to
d8df2e0
Compare
d8df2e0 to
c70b5e9
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
ciao/visualization/visualization.py (1)
49-49: Reduce shape coupling aroundreplacement_imageconversion.On Line 49 and Line 86, unconditional
unsqueeze(0)assumesreplacement_imageis CHW. If it is already BCHW,_to_hwc()will receive 5D and fail. Making_to_hwc()accept both CHW/BCHW avoids this fragility.Proposed refactor
def _to_hwc(tensor: torch.Tensor) -> np.ndarray: """Denormalize an image tensor to a displayable float32 [H, W, 3] array.""" - img = tensor.detach().squeeze(0).cpu().float().numpy().transpose(1, 2, 0) + arr = tensor.detach().cpu().float() + if arr.ndim == 4: + if arr.shape[0] != 1: + raise ValueError(f"Expected batch size 1, got shape {tuple(arr.shape)}") + arr = arr[0] + if arr.ndim != 3: + raise ValueError(f"Expected CHW or BCHW tensor, got shape {tuple(arr.shape)}") + img = arr.numpy().transpose(1, 2, 0) return np.clip(img * _IMAGENET_STD + _IMAGENET_MEAN, 0.0, 1.0)- repl = _to_hwc(result.replacement_image.unsqueeze(0)) + repl = _to_hwc(result.replacement_image)Also applies to: 86-86
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ciao/visualization/visualization.py` at line 49, The code currently unconditionally calls replacement_image.unsqueeze(0) before _to_hwc, which breaks if replacement_image is already BCHW; modify either the call sites or _to_hwc: best fix is to update _to_hwc to accept both CHW and BCHW by checking tensor.dim() and if dim==4 treat it as BCHW (select the first batch or iterate), if dim==3 treat it as CHW, and raise a clear error for other ranks; then remove forced unsqueeze(0) or make the callers (where repl is created from replacement_image) perform a conditional unsqueeze only when dim()==3 so _to_hwc receives a 3D CHW tensor or a 4D BCHW tensor consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ciao/__main__.py`:
- Around line 141-148: The loop calling
plot_overview/plot_regions/plot_region_scores and mlflow.log_figure can raise
and skip plt.close(fig), so wrap each iteration in a try/except/finally: call
fig = plot_fn(results) and mlflow.log_figure(...) inside try, catch and log any
exception (but do not re-raise) so a single failing plot doesn't abort the run,
and in finally always call plt.close(fig) if fig is defined; reference
plot_overview, plot_regions, plot_region_scores, mlflow.log_figure, plt.close,
and the loop variable fig to locate the code to change.
In `@ciao/visualization/visualization.py`:
- Around line 88-90: The plotting helpers currently call plt.subplots(1, n, ...)
with n = len(result.regions) which crashes when n == 0; update the functions
that compute n (where result.regions is referenced) to explicitly handle n == 0
by raising a clear ValueError (e.g. "no regions to plot") or returning an
empty/trivial figure instead of calling plt.subplots; apply this check in both
places where n is used (the block around the existing n = len(result.regions)
and the similar block later) so the code either returns a valid empty
matplotlib.Figure or raises a descriptive error before calling plt.subplots.
---
Nitpick comments:
In `@ciao/visualization/visualization.py`:
- Line 49: The code currently unconditionally calls
replacement_image.unsqueeze(0) before _to_hwc, which breaks if replacement_image
is already BCHW; modify either the call sites or _to_hwc: best fix is to update
_to_hwc to accept both CHW and BCHW by checking tensor.dim() and if dim==4 treat
it as BCHW (select the first batch or iterate), if dim==3 treat it as CHW, and
raise a clear error for other ranks; then remove forced unsqueeze(0) or make the
callers (where repl is created from replacement_image) perform a conditional
unsqueeze only when dim()==3 so _to_hwc receives a 3D CHW tensor or a 4D BCHW
tensor consistently.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ba0943a3-ad4a-47d1-8815-6633bae68141
📒 Files selected for processing (5)
ciao/__main__.pyciao/explainer/ciao_explainer.pyciao/visualization/__init__.pyciao/visualization/visualization.pyconfigs/logger/mlflow.yaml
✅ Files skipped from review due to trivial changes (1)
- configs/logger/mlflow.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
- ciao/visualization/init.py
- ciao/explainer/ciao_explainer.py
Context:
This PR introduces the core visual evaluation tools for the CIAO explanation pipeline. These visualizations are automatically generated and logged to MLflow as artifacts if configured.
What's Changed / Added:
ciao/visualization/visualization.py: Added a new visualization module containing three core plotting functions usingmatplotlib:plot_overview: A 4-panel figure showing the original image, segmentation boundaries, a global segment-score heatmap (RdBucolormap), and the chosen replacement background.plot_regions: Generates subplots for each top region, blending the original image with the replacement background specifically within the masked region.plot_region_scores: Generates subplots where each identified region is tinted red (positive score/drop) or blue (negative score), proportional to its impact.ciao/explainer/ciao_explainer.py: UpdatedExplanationResultdataclass to carry thereplacement_imagetensor, providing the necessary context for the visualization functions to render the "masked" state.ciao/__main__.py: Integrated the visualization suite into the main execution loop. Automatically generates and logs figures (overview.png,regions.png,region_scores.png) viamlflow.log_figure()if the run yields valid regions and figure logging is enabled.ciao/data/preprocessing.py&ciao/data/replacement.py: ExtractedIMAGENET_MEANandIMAGENET_STDinto centralized constants to prevent code duplication between preprocessing and the new visualization denormalization logic.configs/logger/mlflow.yaml: Added thelog_figures: trueflag to easily toggle artifact generation.Related Task:
XAI-29
Summary by CodeRabbit
Release Notes