Skip to content

feat: add visualization functions and integrate figure logging#16

Merged
dhalmazna merged 5 commits into
masterfrom
feat/visualization
Apr 28, 2026
Merged

feat: add visualization functions and integrate figure logging#16
dhalmazna merged 5 commits into
masterfrom
feat/visualization

Conversation

@dhalmazna

@dhalmazna dhalmazna commented Apr 23, 2026

Copy link
Copy Markdown
Collaborator

Context:
This PR introduces the core visual evaluation tools for the CIAO explanation pipeline. These visualizations are automatically generated and logged to MLflow as artifacts if configured.

What's Changed / Added:

  • ciao/visualization/visualization.py: Added a new visualization module containing three core plotting functions using matplotlib:
    • plot_overview: A 4-panel figure showing the original image, segmentation boundaries, a global segment-score heatmap (RdBu colormap), and the chosen replacement background.
    • plot_regions: Generates subplots for each top region, blending the original image with the replacement background specifically within the masked region.
    • plot_region_scores: Generates subplots where each identified region is tinted red (positive score/drop) or blue (negative score), proportional to its impact.
  • ciao/explainer/ciao_explainer.py: Updated ExplanationResult dataclass to carry the replacement_image tensor, providing the necessary context for the visualization functions to render the "masked" state.
  • ciao/__main__.py: Integrated the visualization suite into the main execution loop. Automatically generates and logs figures (overview.png, regions.png, region_scores.png) via mlflow.log_figure() if the run yields valid regions and figure logging is enabled.
  • ciao/data/preprocessing.py & ciao/data/replacement.py: Extracted IMAGENET_MEAN and IMAGENET_STD into centralized constants to prevent code duplication between preprocessing and the new visualization denormalization logic.
  • configs/logger/mlflow.yaml: Added the log_figures: true flag to easily toggle artifact generation.

Related Task:
XAI-29

Summary by CodeRabbit

Release Notes

  • New Features
    • Added visualization module with three figure types for explanation results: overview displaying original/segmentation/scores/replacement, regional comparisons, and regional score heatmaps.
    • Explanation results now include the replacement image for direct access.
    • Automatic visualization logging to MLflow (configurable).

@dhalmazna dhalmazna self-assigned this Apr 23, 2026
@coderabbitai

coderabbitai Bot commented Apr 23, 2026

Copy link
Copy Markdown
📝 Walkthrough

Walkthrough

The PR adds comprehensive visualization functionality for CIAO explanation results and integrates it with MLflow logging. A new visualization module provides three plotting functions to display results, while the explanation result now exposes the replacement image tensor. The main CLI conditionally generates and logs these figures to MLflow when enabled via configuration.

Changes

Cohort / File(s) Summary
Visualization Module
ciao/visualization/visualization.py, ciao/visualization/__init__.py
Introduces three new plotting functions (plot_overview, plot_regions, plot_region_scores) that visualize CIAO explanation results with segment boundaries, heatmaps, and region-specific overlays. Functions handle image denormalization, mask aggregation, and matplotlib figure rendering.
MLflow Integration
ciao/__main__.py, configs/logger/mlflow.yaml
Adds conditional figure logging to MLflow after explanation. The _log_figures helper generates and logs visualization artifacts when cfg.logger.log_figures is enabled, with explicit figure cleanup.
Explanation Result
ciao/explainer/ciao_explainer.py
Exposes the computed replacement_image tensor in ExplanationResult dataclass, allowing callers to access the obfuscated replacement image directly.

Sequence Diagram

sequenceDiagram
    participant Main
    participant Explainer
    participant Visualization
    participant MLflow

    Main->>Explainer: explain(image, ...)
    Explainer->>Explainer: compute replacement_image
    Explainer-->>Main: ExplanationResult (with replacement_image)
    
    alt log_figures enabled
        Main->>Visualization: plot_overview(result)
        Visualization-->>Main: Figure
        Main->>MLflow: log_figure(fig, 'figures/overview.png')
        Main->>Visualization: plot_regions(result)
        Visualization-->>Main: Figure
        Main->>MLflow: log_figure(fig, 'figures/regions.png')
        Main->>Visualization: plot_region_scores(result)
        Visualization-->>Main: Figure
        Main->>MLflow: log_figure(fig, 'figures/region_scores.png')
        Main->>Main: close all figures
    end
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~28 minutes

Possibly related PRs

Suggested reviewers

  • vojtech-kur
  • Adames4
  • vejtek

Poem

🐰 Plots blossom bright in Matplotlib's glow,
Regions dance with colors to let insights show,
MLflow captures each figure with care,
Replacement images float through the air,
Visualization magic is finally here! ✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main changes: adding visualization functions (plot_overview, plot_regions, plot_region_scores) and integrating figure logging to MLflow.
Docstring Coverage ✅ Passed Docstring coverage is 85.71% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/visualization

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces visualization capabilities for CIAO explanation results, adding a new module to generate overview plots, region-specific images, and score heatmaps. These visualizations are integrated into the main execution flow with support for MLflow logging. Additionally, the PR centralizes ImageNet normalization constants and updates the explanation result structure to include replacement images. The review feedback suggests enhancing the robustness of tensor-to-numpy conversions by using .detach() and recommends implementing a grid layout for subplots to handle cases with a large number of regions more effectively.

Comment thread ciao/visualization/visualization.py Outdated
Comment thread ciao/visualization/visualization.py
Comment thread ciao/visualization/visualization.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
ciao/data/replacement.py (1)

22-23: Consider matching dtype of the input tensor.

torch.tensor(IMAGENET_MEAN, device=device) will default to float32. If input_tensor/image is ever float16/bfloat16 (mixed precision), the arithmetic at lines 26/138 will upcast or error. Minor/defensive — pass dtype=input_tensor.dtype (resp. image.dtype) to be safe and consistent with how color_tensor is built on line 129.

♻️ Proposed fix
-    imagenet_mean = torch.tensor(IMAGENET_MEAN, device=device).view(3, 1, 1)
-    imagenet_std = torch.tensor(IMAGENET_STD, device=device).view(3, 1, 1)
+    imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=input_tensor.dtype, device=device).view(3, 1, 1)
+    imagenet_std = torch.tensor(IMAGENET_STD, dtype=input_tensor.dtype, device=device).view(3, 1, 1)

And similarly in make_solid_color_replacement:

-        imagenet_mean = torch.tensor(IMAGENET_MEAN, device=image.device).view(3, 1, 1)
-        imagenet_std = torch.tensor(IMAGENET_STD, device=image.device).view(3, 1, 1)
+        imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=image.dtype, device=image.device).view(3, 1, 1)
+        imagenet_std = torch.tensor(IMAGENET_STD, dtype=image.dtype, device=image.device).view(3, 1, 1)

Also applies to: 135-136

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/data/replacement.py` around lines 22 - 23, The ImageNet mean/std tensors
are created without matching the input dtype which can cause upcasts or errors
under mixed precision; update the creation of imagenet_mean and imagenet_std to
pass dtype=input_tensor.dtype (and for the other occurrence in
make_solid_color_replacement use dtype=image.dtype) in addition to device so
they match the tensor used for arithmetic (refer to the
imagenet_mean/imagenet_std variables and the make_solid_color_replacement
function and the color_tensor built near where color_tensor is created).
ciao/__main__.py (1)

129-137: Use .get() for backward-compat on log_figures.

Accessing cfg.logger.log_figures as an attribute will raise on existing/user configs that predate this flag (e.g., overrides or custom logger configs that don't set it). Prefer a defaulted lookup so this remains opt-in and non-breaking.

♻️ Proposed fix
-                if cfg.logger.log_figures and results.regions:
+                if cfg.logger.get("log_figures", False) and results.regions:

Also note: the results.regions guard here is what prevents plot_regions/plot_region_scores from crashing on empty regions (see comment in visualization.py). Keep them in lockstep.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/__main__.py` around lines 129 - 137, Replace the direct attribute access
cfg.logger.log_figures with a safe, defaulted lookup (e.g.,
cfg.logger.get("log_figures", False)) so missing/older logger configs don't
raise; leave the existing results.regions guard in place so plot_overview,
plot_regions and plot_region_scores are only called when regions exist, and
continue to call mlflow.log_figure(...) and plt.close(fig) as before.
ciao/visualization/visualization.py (2)

83-99: Guard against empty result.regions when called directly.

plot_regions and plot_region_scores will fail if invoked with an empty regions list: plt.subplots(1, 0, ...) raises, and max(abs(s) for s in all_scores) at line 112 raises ValueError on an empty iterable. The __main__.py caller currently guards this, but as a public API in ciao.visualization it's worth an early-return or explicit ValueError so direct users get a clear contract.

♻️ Proposed fix (example)
 def plot_regions(result: ExplanationResult) -> Figure:
     """One subplot per region: region pixels replaced, rest is original."""
+    if not result.regions:
+        raise ValueError("plot_regions requires at least one region")
     img = _to_hwc(result.input_batch)
 def plot_region_scores(result: ExplanationResult) -> Figure:
     ...
+    if not result.regions:
+        raise ValueError("plot_region_scores requires at least one region")
     img = _to_hwc(result.input_batch)

Also applies to: 102-128

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/visualization/visualization.py` around lines 83 - 99, Both plot_regions
and plot_region_scores must guard against an empty result.regions list; add an
explicit early check at the top of each function (e.g., if not result.regions:)
and raise a ValueError with a clear message like "result.regions is empty" (so
callers get a clear contract) before calling plt.subplots or computing
max(abs(s) for s in all_scores); reference plot_regions, plot_region_scores,
result.regions, the plt.subplots(1, n, ...) call, and the max(abs(s) for s in
all_scores) expression when applying the fix.

37-42: Optional: vectorize segment→value lookups.

Both _region_mask and the score-map build at lines 55–57 loop over segment IDs in Python. For larger segmentations these can be vectorized with np.isin / fancy indexing, which is both shorter and faster:

♻️ Example
# _region_mask
return np.isin(segments, np.fromiter(region, dtype=segments.dtype))

# score_map
ids = np.fromiter(result.segment_scores.keys(), dtype=segs.dtype)
vals = np.fromiter(result.segment_scores.values(), dtype=np.float32)
lut = np.zeros(int(segs.max()) + 1, dtype=np.float32)
lut[ids] = vals
score_map = lut[segs]

Not a correctness concern — leave as-is if current performance is fine.

Also applies to: 55-57

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/visualization/visualization.py` around lines 37 - 42, The loops over
segment IDs should be vectorized for performance: replace the Python loop in
_region_mask to use numpy membership testing (np.isin) on segments/region, and
replace the per-segment Python loop that builds score_map from
result.segment_scores by creating an index array from the keys, a values array
from the values, build a LUT (zero-initialized with length = segs.max()+1),
assign LUT[ids] = vals and then index into segs to produce score_map; refer to
the functions/variables _region_mask, segments/segs, result.segment_scores,
ids/vals/LUT and score_map when making the changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@ciao/visualization/visualization.py`:
- Around line 45-46: The docstring for plot_overview is outdated: it describes a
3-panel layout but the function now creates four panels (original | segmentation
| segment-score heatmap | replacement). Update the docstring of
plot_overview(ExplanationResult) -> Figure to accurately list all four panels in
the correct order (original, segmentation, score heatmap, replacement) and
briefly state that the figure returns a 4-panel side-by-side comparison.

---

Nitpick comments:
In `@ciao/__main__.py`:
- Around line 129-137: Replace the direct attribute access
cfg.logger.log_figures with a safe, defaulted lookup (e.g.,
cfg.logger.get("log_figures", False)) so missing/older logger configs don't
raise; leave the existing results.regions guard in place so plot_overview,
plot_regions and plot_region_scores are only called when regions exist, and
continue to call mlflow.log_figure(...) and plt.close(fig) as before.

In `@ciao/data/replacement.py`:
- Around line 22-23: The ImageNet mean/std tensors are created without matching
the input dtype which can cause upcasts or errors under mixed precision; update
the creation of imagenet_mean and imagenet_std to pass dtype=input_tensor.dtype
(and for the other occurrence in make_solid_color_replacement use
dtype=image.dtype) in addition to device so they match the tensor used for
arithmetic (refer to the imagenet_mean/imagenet_std variables and the
make_solid_color_replacement function and the color_tensor built near where
color_tensor is created).

In `@ciao/visualization/visualization.py`:
- Around line 83-99: Both plot_regions and plot_region_scores must guard against
an empty result.regions list; add an explicit early check at the top of each
function (e.g., if not result.regions:) and raise a ValueError with a clear
message like "result.regions is empty" (so callers get a clear contract) before
calling plt.subplots or computing max(abs(s) for s in all_scores); reference
plot_regions, plot_region_scores, result.regions, the plt.subplots(1, n, ...)
call, and the max(abs(s) for s in all_scores) expression when applying the fix.
- Around line 37-42: The loops over segment IDs should be vectorized for
performance: replace the Python loop in _region_mask to use numpy membership
testing (np.isin) on segments/region, and replace the per-segment Python loop
that builds score_map from result.segment_scores by creating an index array from
the keys, a values array from the values, build a LUT (zero-initialized with
length = segs.max()+1), assign LUT[ids] = vals and then index into segs to
produce score_map; refer to the functions/variables _region_mask, segments/segs,
result.segment_scores, ids/vals/LUT and score_map when making the changes.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8069b51f-69d5-439e-b34c-2cbe8ccb8e0d

📥 Commits

Reviewing files that changed from the base of the PR and between 9a6db19 and 54151c0.

📒 Files selected for processing (7)
  • ciao/__main__.py
  • ciao/data/preprocessing.py
  • ciao/data/replacement.py
  • ciao/explainer/ciao_explainer.py
  • ciao/visualization/__init__.py
  • ciao/visualization/visualization.py
  • configs/logger/mlflow.yaml

Comment thread ciao/visualization/visualization.py Outdated
@dhalmazna dhalmazna marked this pull request as ready for review April 23, 2026 16:36
Copilot AI review requested due to automatic review settings April 23, 2026 16:36

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a visualization suite for CIAO explanation results and hooks it into the main run loop to optionally log generated figures to MLflow.

Changes:

  • Introduces ciao.visualization plotting helpers (plot_overview, plot_regions, plot_region_scores) built on matplotlib.
  • Extends ExplanationResult to include replacement_image, enabling “masked/replaced” visual renderings.
  • Adds an MLflow config flag (log_figures) and logs generated figures as MLflow artifacts when enabled.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
configs/logger/mlflow.yaml Adds log_figures toggle to control artifact generation.
ciao/visualization/visualization.py Implements plotting utilities for overview, regions, and region score tinting.
ciao/visualization/init.py Exposes visualization functions as public package API.
ciao/explainer/ciao_explainer.py Adds replacement_image to ExplanationResult and populates it in explain().
ciao/data/replacement.py Centralizes ImageNet normalization constants usage via preprocessing module.
ciao/data/preprocessing.py Defines shared IMAGENET_MEAN/STD constants and reuses them in transforms.
ciao/main.py Integrates figure generation and MLflow logging, ensuring figures are closed.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread ciao/visualization/visualization.py
Comment thread ciao/visualization/visualization.py
Comment thread ciao/visualization/visualization.py
Comment thread ciao/__main__.py Outdated
Adames4
Adames4 previously approved these changes Apr 28, 2026
Base automatically changed from feat/main-and-hydra to master April 28, 2026 16:35
@dhalmazna dhalmazna dismissed Adames4’s stale review April 28, 2026 16:35

The base branch was changed.

@dhalmazna dhalmazna requested a review from a team April 28, 2026 16:35
@dhalmazna dhalmazna force-pushed the feat/visualization branch from 7c46f49 to d8df2e0 Compare April 28, 2026 16:50
vejtek
vejtek previously approved these changes Apr 28, 2026
Adames4
Adames4 previously approved these changes Apr 28, 2026
@dhalmazna dhalmazna dismissed stale reviews from Adames4 and vejtek via c70b5e9 April 28, 2026 18:20
@dhalmazna dhalmazna force-pushed the feat/visualization branch from d8df2e0 to c70b5e9 Compare April 28, 2026 18:20

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
ciao/visualization/visualization.py (1)

49-49: Reduce shape coupling around replacement_image conversion.

On Line 49 and Line 86, unconditional unsqueeze(0) assumes replacement_image is CHW. If it is already BCHW, _to_hwc() will receive 5D and fail. Making _to_hwc() accept both CHW/BCHW avoids this fragility.

Proposed refactor
 def _to_hwc(tensor: torch.Tensor) -> np.ndarray:
     """Denormalize an image tensor to a displayable float32 [H, W, 3] array."""
-    img = tensor.detach().squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
+    arr = tensor.detach().cpu().float()
+    if arr.ndim == 4:
+        if arr.shape[0] != 1:
+            raise ValueError(f"Expected batch size 1, got shape {tuple(arr.shape)}")
+        arr = arr[0]
+    if arr.ndim != 3:
+        raise ValueError(f"Expected CHW or BCHW tensor, got shape {tuple(arr.shape)}")
+    img = arr.numpy().transpose(1, 2, 0)
     return np.clip(img * _IMAGENET_STD + _IMAGENET_MEAN, 0.0, 1.0)
-    repl = _to_hwc(result.replacement_image.unsqueeze(0))
+    repl = _to_hwc(result.replacement_image)

Also applies to: 86-86

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ciao/visualization/visualization.py` at line 49, The code currently
unconditionally calls replacement_image.unsqueeze(0) before _to_hwc, which
breaks if replacement_image is already BCHW; modify either the call sites or
_to_hwc: best fix is to update _to_hwc to accept both CHW and BCHW by checking
tensor.dim() and if dim==4 treat it as BCHW (select the first batch or iterate),
if dim==3 treat it as CHW, and raise a clear error for other ranks; then remove
forced unsqueeze(0) or make the callers (where repl is created from
replacement_image) perform a conditional unsqueeze only when dim()==3 so _to_hwc
receives a 3D CHW tensor or a 4D BCHW tensor consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@ciao/__main__.py`:
- Around line 141-148: The loop calling
plot_overview/plot_regions/plot_region_scores and mlflow.log_figure can raise
and skip plt.close(fig), so wrap each iteration in a try/except/finally: call
fig = plot_fn(results) and mlflow.log_figure(...) inside try, catch and log any
exception (but do not re-raise) so a single failing plot doesn't abort the run,
and in finally always call plt.close(fig) if fig is defined; reference
plot_overview, plot_regions, plot_region_scores, mlflow.log_figure, plt.close,
and the loop variable fig to locate the code to change.

In `@ciao/visualization/visualization.py`:
- Around line 88-90: The plotting helpers currently call plt.subplots(1, n, ...)
with n = len(result.regions) which crashes when n == 0; update the functions
that compute n (where result.regions is referenced) to explicitly handle n == 0
by raising a clear ValueError (e.g. "no regions to plot") or returning an
empty/trivial figure instead of calling plt.subplots; apply this check in both
places where n is used (the block around the existing n = len(result.regions)
and the similar block later) so the code either returns a valid empty
matplotlib.Figure or raises a descriptive error before calling plt.subplots.

---

Nitpick comments:
In `@ciao/visualization/visualization.py`:
- Line 49: The code currently unconditionally calls
replacement_image.unsqueeze(0) before _to_hwc, which breaks if replacement_image
is already BCHW; modify either the call sites or _to_hwc: best fix is to update
_to_hwc to accept both CHW and BCHW by checking tensor.dim() and if dim==4 treat
it as BCHW (select the first batch or iterate), if dim==3 treat it as CHW, and
raise a clear error for other ranks; then remove forced unsqueeze(0) or make the
callers (where repl is created from replacement_image) perform a conditional
unsqueeze only when dim()==3 so _to_hwc receives a 3D CHW tensor or a 4D BCHW
tensor consistently.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ba0943a3-ad4a-47d1-8815-6633bae68141

📥 Commits

Reviewing files that changed from the base of the PR and between 54151c0 and c70b5e9.

📒 Files selected for processing (5)
  • ciao/__main__.py
  • ciao/explainer/ciao_explainer.py
  • ciao/visualization/__init__.py
  • ciao/visualization/visualization.py
  • configs/logger/mlflow.yaml
✅ Files skipped from review due to trivial changes (1)
  • configs/logger/mlflow.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
  • ciao/visualization/init.py
  • ciao/explainer/ciao_explainer.py

Comment thread ciao/__main__.py
Comment thread ciao/visualization/visualization.py
@dhalmazna dhalmazna requested a review from vejtek April 28, 2026 18:29
@dhalmazna dhalmazna requested a review from Adames4 April 28, 2026 18:29
@dhalmazna dhalmazna merged commit d31cf6d into master Apr 28, 2026
3 checks passed
@dhalmazna dhalmazna deleted the feat/visualization branch April 28, 2026 18:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants