
Add CaloClusterGNN training pipeline #7

Open
zwl0331 wants to merge 3 commits into Mu2e:main from zwl0331:add-calo-cluster-gnn

Conversation

@zwl0331 commented May 4, 2026

Summary

Adds a new CaloClusterGNN/ subdirectory containing the full training
pipeline for a Graph Neural Network calorimeter-clustering algorithm,
intended to run alongside the existing seed+BFS CaloClusterMaker in
Mu2e Offline. The deployed recipe is CCN+BFS10: CaloClusterNet
edge classifier followed by BFS-style traversal at ExpandCut = 10 MeV.
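
For orientation, the BFS side of the recipe is conceptually the following
(a sketch under assumed semantics — hits above ExpandCut keep the frontier
growing, hits below it are only absorbed; the names bfs_cluster, kept_edges,
and expand_cut are illustrative, and the deployed code may differ in detail):

    from collections import deque

    def bfs_cluster(seed, kept_edges, energy, expand_cut=10.0):
        # Walk the graph along edges the classifier kept. Hits at or above
        # expand_cut (MeV) keep expanding the frontier; hits below it are
        # absorbed into the cluster but not expanded further.
        cluster, frontier = {seed}, deque([seed])
        while frontier:
            i = frontier.popleft()
            for j in kept_edges.get(i, ()):
                if j in cluster:
                    continue
                cluster.add(j)
                if energy[j] >= expand_cut:
                    frontier.append(j)
        return cluster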

Layout follows the convention set by TrkQual/ and TrkPID/ — one
top-level subdirectory per algorithm, self-contained.

What's in here

CaloClusterGNN/
├── README.md            retrainer-facing docs
├── setup_env.sh         wraps setupmu2e-art.sh + ana 2.6.1
├── src/
│   ├── data/            graph builder, calo-entrant truth labels,
│   │                    normalisation, packed dataset
│   ├── geometry/        crystalId -> (x, y, disk) loader
│   ├── models/          SimpleEdgeNet, CaloClusterNet, layers, heads,
│   │                    deploy wrappers for ONNX export
│   ├── training/        losses, metrics, trainer
│   └── inference/       cluster_reco + postprocess (used by train-time
│                        evaluation scripts)
├── scripts/             build/pack/train/tune/evaluate pipeline,
│                        ONNX export + parity validation,
│                        failure audits, cluster-physics evaluation,
│                        Run1B no-field generalisation evaluation
├── configs/             5 YAML configs (one per training run)
├── tests/               110 unit tests (4 conditionally skipped on a
│                        fresh checkout for missing checkpoints)
├── splits/              frozen 35/7/8 v2 file split
└── data/                small geometry CSVs (crystal_geometry.csv etc.)

Two model classes train from the same pipeline:

Model            Params   Frozen tau_edge   Use
SimpleEdgeNet    215 K    0.26              Reference / A-B comparison
CaloClusterNet   676 K    0.20              Production model (CCN+BFS10)

Both share the input graph (one per calorimeter disk per event,
6 node features + 8 edge features) and z-score normalisation, so
swapping models in deployment is config-only.
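
To make the config-only swap concrete, here is a minimal sketch of the shared
z-score step (illustrative NumPy, not the repo's actual API; the real pipeline
reads frozen train-split statistics instead of recomputing them as done here):

    import numpy as np

    def zscore(features, mean, std):
        # Guard zero-variance features so the division never blows up.
        return (features - mean) / np.clip(std, 1e-8, None)

    rng = np.random.default_rng(0)
    x = rng.normal(size=(120, 6))          # node features: one row per calo hit
    edge_attr = rng.normal(size=(800, 8))  # edge features: one row per graph edge

    # Stand-ins for the frozen train-split stats, recomputed here only so the
    # example runs; both models consume the same normalised tensors.
    x_norm = zscore(x, x.mean(axis=0), x.std(axis=0))
    e_norm = zscore(edge_attr, edge_attr.mean(axis=0), edge_attr.std(axis=0))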

Headline result

On the MDC2025 mixed-pileup test set (276,688 events, 481,543
disk-graphs), CCN+BFS10 beats BFS on every downstream-relevant
cluster-physics metric for E_reco >= 50 MeV clusters (those that
matter for track finding):

Metric                    BFS     CCN+BFS10   Change
Mean abs(dE) / MeV        0.839   0.616       -27%
95th-pct abs(dE) / MeV    3.520   2.338       -34%
Mean centroid dr / mm     1.589   1.292       -19%
95th-pct dr / mm          3.606   2.294       -36%

In the 95-110 MeV signal region (47,279 clusters), mean abs(dE)
drops from 0.368 MeV (BFS) to 0.210 MeV (-43%) and mean dr drops from
0.559 mm to 0.460 mm (-18%).

Reproducibility

After source setup_env.sh:

# 1. Build per-disk graphs from EventNtuple ROOT files (~10 min CPU).
bash scripts/build_all_graphs.sh
python3 scripts/pack_graphs.py

# 2. Train (CCN; production model).
python3 scripts/train_gnn.py \
    --config configs/calo_cluster_net.yaml \
    --device cuda --run-name calo_cluster_net_v2_stage1

# 3. Tune the edge threshold on val.
python3 scripts/tune_threshold.py \
    --config configs/calo_cluster_net.yaml \
    --checkpoint outputs/runs/calo_cluster_net_v2_stage1/checkpoints/best_model.pt

# 4. Evaluate once on test.
OMP_NUM_THREADS=4 PYTHONUNBUFFERED=1 python3 -u scripts/evaluate_test.py
OMP_NUM_THREADS=4 PYTHONUNBUFFERED=1 python3 -u scripts/evaluate_cluster_physics.py

# 5. Export to ONNX for deployment.
python3 scripts/export_onnx.py --model ccn   # also --model sen
python3 scripts/export_norm_stats.py
python3 scripts/validate_onnx.py --model ccn

Frozen hyperparameters and the exact recipe values live in configs/
and are documented in the README.

Coordinated PRs

  • Store calo SimParticle ancestor chain in SimInfo::ancestorSimIds
    EventNtuple#366 — adds calomcsim.ancestorSimIds to SimInfo. The v2
    training data uses calo-entrant ancestor truth, which requires this
    branch.
  • Add GNN calorimeter clustering (split-module design) Offline#1823 —
    the C++ inference modules (CaloHitGraphMaker, CaloClusterMakerGNN)
    under Offline/CaloCluster/. Loads the .onnx exported by this repo
    via art::ConfigFileLookupPolicy, asserts metadata_props agreement
    (model_version, node_features, edge_features) against FHiCL
    (sketched below), and emits CaloClusterCollection under instance
    name "GNN" so existing BFS-reading analyses keep working.
    C++↔Python parity has been validated byte-exactly on the val
    split (100/100 disk-graphs, 8,502 hits) using a parity-dump
    analyzer + Python comparison harness.
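
On the Python side the metadata handshake amounts to the following (a hedged
sketch using onnxruntime; the file name and expected values are illustrative —
in deployment the C++ module reads the expectations from FHiCL):

    import onnxruntime as ort

    sess = ort.InferenceSession("calo_cluster_net.onnx")
    meta = sess.get_modelmeta().custom_metadata_map

    # Illustrative expectations, standing in for the FHiCL-configured values.
    expected = {"model_version": "v2", "node_features": "6", "edge_features": "8"}
    for key, want in expected.items():
        got = meta.get(key)
        assert got == want, f"metadata_props mismatch for {key}: {got!r} != {want!r}"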

Tests

$ source setup_env.sh
$ python3 -m unittest discover -s tests -p "test_*.py" -v
...
Ran 110 tests in 0.16s
OK (skipped=4)

The 4 skipped tests are conditional — they exercise loading a real
trained checkpoint or the exported .onnx and self-skip with a clear
message when those files aren't in the local checkout (the case for a
fresh clone).

Acknowledgement

Implementation, refactoring, and documentation drafting in this
subdirectory were assisted by Anthropic's Claude (Claude Code). All
scientific decisions, hyperparameter choices, validation results, and
the v1→v2 truth-definition campaign are my own work.

zwl0331 added 3 commits May 4, 2026 15:13
Adds a CaloClusterGNN/ subdirectory containing the full training
pipeline for the GNN calorimeter-clustering algorithm intended as a
parallel to the existing seed+BFS CaloClusterMaker in Mu2e Offline.
The deployed recipe is "CCN+BFS10": CaloClusterNet edge classifier +
BFS-style traversal with ExpandCut = 10 MeV.

Layout (modelled on TrkQual/, but Python-package shaped):

  CaloClusterGNN/
    README.md                   how to retrain, frozen hyperparams,
                                deployment cross-link
    setup_env.sh                wraps setupmu2e-art.sh + ana 2.6.1
    src/
      data/                     graph builder, calo-entrant truth labels,
                                normalisation, packed dataset
      geometry/                 crystalId -> (x, y, disk) loader
      models/                   SimpleEdgeNet, CaloClusterNet, layers, heads
      training/                 losses, metrics, trainer
      inference/                cluster reconstruction (cluster_reco.py),
                                postprocess (kept here so train-time eval
                                scripts work end to end)
    scripts/                    build/pack/train/tune/evaluate pipeline,
                                failure audits, cluster-physics eval, ancestry
                                validation, Run1B no-field eval, plotting
    configs/                    five YAML configs (one per training run)
    tests/                      88 unit tests covering all of src/ above
    splits/                     frozen 35/7/8 v2 split file lists
    data/                       crystal_geometry.csv + crystal_neighbors.csv +
                                crystal_map_raw.csv (small lookup tables)

What does NOT live here:
* The deployment-side ONNX export / parity scripts (export_onnx.py,
  export_norm_stats.py, validate_onnx.py, dump_parity_payloads.py,
  compare_parity_dump.py) and the deploy wrappers
  (calo_cluster_net_deploy.py, simple_edge_net_deploy.py) -- those
  belong with the Mu2e/Offline integration PR, not the training repo.
* The `.onnx` artifacts themselves (shipped via Mu2e data area, not
  versioned in MLTrain -- same convention TrkQual follows).
* Large run outputs and processed graphs (regenerable from
  EventNtuple ROOT files via scripts/build_all_graphs.sh).

The v2 training data requires the `calomcsim.ancestorSimIds` branch
added in Mu2e/EventNtuple (PR pending). README cross-links there
once the EventNtuple PR has a number.

Test suite: 88/88 passing in this layout via
`python3 -m unittest discover -s tests -p "test_*.py" -v` after
`source setup_env.sh`.

Both trained models in CaloClusterGNN/ now have a complete training-
to-ONNX path inside MLTrain (consistent with the TrkQual pattern of
shipping conversion scripts alongside training).

New / restored:
* src/models/calo_cluster_net_deploy.py   tensor-API wrapper around
  CaloClusterNet (no PyG Data, no node-saliency head); used by ONNX
  export so torch.onnx.export can trace it.
* src/models/simple_edge_net_deploy.py    same shape for SimpleEdgeNet.
  No node head to bypass, so it's a thin pass-through.
* scripts/export_onnx.py                  --model {ccn,sen} flag with
  per-model presets (checkpoint, output path, model_version). Stamps
  metadata_props {model_version, node_features, edge_features} into
  the .onnx after export.
* scripts/export_norm_stats.py            writes the train-split z-score
  stats next to the .onnx as a flat JSON sidecar so the C++ side
  doesn't need a LibTorch dep to read 28 floats.
* scripts/validate_onnx.py                --model flag with per-model
  preset for tau_edge and tolerance. Asserts (sketched below):
    - max abs-diff edge_logits within tol on the full val split
    - zero per-edge threshold flips at tau_edge (proxy for cluster-
      reco byte-equivalence with the deployed C++ pipeline)
* tests/test_calo_cluster_net_deploy.py   (9 tests)
* tests/test_export_onnx.py               (5 tests)
* tests/test_export_norm_stats.py         (8 tests)
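
The two assertions reduce to something like this (illustrative sketch with
synthetic scores; only tau_edge = 0.20 and the two checks come from the text
above — the real script compares PyTorch and ONNX Runtime outputs over the
full val split):

    import numpy as np

    rng = np.random.default_rng(1)
    scores_torch = rng.normal(size=10_000)                            # stand-in PyTorch edge scores
    scores_onnx = scores_torch + rng.normal(scale=1e-7, size=10_000)  # stand-in ONNX outputs

    tau_edge, tol = 0.20, 1e-5   # per-model frozen preset (CCN values shown)
    assert np.max(np.abs(scores_torch - scores_onnx)) <= tol
    assert np.sum((scores_torch > tau_edge) != (scores_onnx > tau_edge)) == 0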

README extended with an "Exporting a Trained Model to ONNX" section
that documents the full chain for both models, the
metadata_props deployment contract, and the per-model frozen
tau_edge/tol values used by validate_onnx.py.

Test count goes from 88 to 110 (4 conditionally skipped on a fresh
checkout when no trained checkpoint is present locally; this is by
design and the skip messages name the missing file).

Also acknowledges Claude assistance in README.

@oksuzian (Contributor) commented

PR #7 Review: Add CaloClusterGNN training pipeline

Scope: +18,029 lines — full training pipeline added under CaloClusterGNN/. This is a large new subdirectory (data builder, two GNN models, training loop, ONNX export, evaluation, tests). Below are the concrete issues I found while reading the diff.


🔴 Bugs / Correctness Issues

1. Hardcoded absolute path in committed shell script

CaloClusterGNN/scripts/build_all_graphs.sh starts with:

cd /exp/mu2e/app/users/wzhou2/projects/calorimeter/GNN
source setup_env.sh
ROOT_DIR=/exp/mu2e/data/users/wzhou2/GNN/root_files_v2

Anyone else who clones the repo and follows the README's instructions (bash scripts/build_all_graphs.sh) will cd into a user-specific path that doesn't exist. Same issue with ROOT_DIR. Both should be configurable (env var or CLI args), or at minimum default to $(dirname "$0")/.. and a documented data location.

The same hardcoded /exp/mu2e/data/users/wzhou2/... paths leak into:

  • splits/{train,val,test}_files.txt (50 lines of wzhou2 paths committed)
  • scripts/build_graphs.py line ~39: pattern = "/exp/mu2e/data/users/wzhou2/GNN/root_files_v2/*.root" for the --split all branch
  • scripts/evaluate_cluster_physics.py default --root-dir
  • scripts/analysis_for_sophie.py (no CLI override at all)

These should at least be parameterized; ideally the split files should hold filenames only and the root directory should be a config/CLI option (a sketch follows).
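
A minimal de-hardcoding pattern, e.g. for build_graphs.py (the --root-dir flag
and CALO_GNN_ROOT_DIR env var are illustrative names, not prescriptive):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root-dir",
        default=os.environ.get("CALO_GNN_ROOT_DIR", "data/root_files_v2"),
        help="Directory holding the EventNtuple ROOT files",
    )
    args = parser.parse_args()
    pattern = os.path.join(args.root_dir, "*.root")  # replaces the user-specific glob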

2. extract_events_from_file references undefined sys, and the has_pos branch is dead code

In src/data/dataset.py:

has_pos = "calohits.crystalPos_.fCoordinates.fX" in arrays.fields
if has_pos:
    xs = np.array(arrays["calohits.crystalPos_.fCoordinates.fX"][ev], ...)
    ys = np.array(arrays["calohits.crystalPos_.fCoordinates.fY"][ev], ...)

But that branch is unreachable: the branch is unconditionally added to _BRANCHES and passed to tree.arrays(_BRANCHES, ...) — if the file doesn't have it, uproot will raise before this check runs. Either drop the optional handling, or remove the position branches from _BRANCHES and try-load them (sketched below).
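
A sketch of the try-load variant (the load_arrays helper and required_branches
argument are illustrative; only the two position branch names come from the
snippet above):

    import uproot

    _OPTIONAL = [
        "calohits.crystalPos_.fCoordinates.fX",
        "calohits.crystalPos_.fCoordinates.fY",
    ]

    def load_arrays(tree, required_branches):
        available = set(tree.keys())
        optional = [b for b in _OPTIONAL if b in available]
        arrays = tree.arrays(required_branches + optional, library="np")
        has_pos = len(optional) == len(_OPTIONAL)  # now actually meaningful
        return arrays, has_pos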

3. _cap_degree is asymmetric — produces a directed, non-symmetric edge set

build_graph returns symmetric pairs via query_pairs (adds both directions), but _cap_degree keeps the k_max nearest per source node only:

for node in range(n):
    node_mask = src == node
    if node_mask.sum() <= k_max:
        continue
    ...
    drop = node_indices[sorted_idx[k_max:]]
    keep[drop] = False

If node i has high degree and gets its outgoing edge to j pruned, the reverse edge j→i may survive. The resulting edge_index no longer represents an undirected graph (which the rest of the code — symmetrize_edge_scores, MC truth labeling, and message passing — all assumes). Suggest pruning by undirected pair and re-emitting both directions, or doing the kNN cap once on the undirected pair set; a sketch of the first option follows.
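
One possible shape for the symmetric fix (a sketch, not the repo's code; it
keeps a pair only when it survives the cap at both endpoints — one choice
among several — then re-emits both directions):

    import numpy as np

    def cap_degree_symmetric(edge_index, dist, k_max):
        src, dst = edge_index
        # Collapse to one canonical copy of each undirected pair.
        first = src < dst
        pairs = np.stack([src[first], dst[first]])
        pair_dist = dist[first]

        keep = np.ones(pairs.shape[1], dtype=bool)
        for node in np.unique(pairs):
            touches = np.flatnonzero((pairs[0] == node) | (pairs[1] == node))
            if touches.size > k_max:
                order = np.argsort(pair_dist[touches])
                keep[touches[order[k_max:]]] = False  # drop the farthest pairs at this node

        kept = pairs[:, keep]
        # Re-emit both directions so edge_index is undirected again.
        return np.concatenate([kept, kept[::-1]], axis=1)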

4. _deduplicate would crash on an empty dst, and its key encoding can collide

combined = edge_index[0] * (edge_index[1].max() + 1) + edge_index[1]

build_graph guards if not src_list: before calling, so it's safe today, but more importantly: edge_index[1].max() + 1 is only (max_dst + 1), while src can take values up to n-1 even when dst.max() < n-1. The encoding src * (max_dst + 1) + dst will then collide whenever src ≥ max_dst + 1. Use n_nodes (or max(src.max(), dst.max()) + 1) as the multiplier, not dst.max() + 1 — a sketch follows.
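
A minimal sketch of the collision-free variant (illustrative helper name):

    import numpy as np

    def deduplicate(edge_index, n_nodes):
        # With multiplier n_nodes, (src, dst) -> src * n_nodes + dst is
        # injective for all 0 <= src, dst < n_nodes, so no collisions.
        key = edge_index[0].astype(np.int64) * n_nodes + edge_index[1]
        _, unique_idx = np.unique(key, return_index=True)
        return edge_index[:, np.sort(unique_idx)]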

5. Cluster time definition is inconsistent across files

  • src/inference/postprocess.py defines cluster time as energy-weighted mean: ct = np.dot(w, t)
  • scripts/evaluate_cluster_physics.py defines it as seed-hit time: seed_time = float(t[np.argmax(e)]) with the comment "Offline convention"

Both are reasonable, but the evaluation script and the deployed postprocess MUST agree, or the residual numbers in the README are not comparable to what the Offline C++ side will see. Pick one and document it (both definitions are shown below).
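
The two definitions, side by side on a toy cluster (sketch; e and t are
per-hit energies and times, numbers are made up):

    import numpy as np

    e = np.array([1.2, 40.0, 3.3])        # hit energies / MeV
    t = np.array([710.0, 712.5, 709.0])   # hit times / ns

    seed_time = float(t[np.argmax(e)])             # evaluate_cluster_physics.py convention
    weighted_time = float(np.dot(e / e.sum(), t))  # postprocess.py: ct = np.dot(w, t)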

6. reconstruct_clusters cleanup pass contradicts the bfs_expand_cut docstring

With bfs_expand_cut, _bfs_expand_cut is called before min_hits / min_energy cleanup, and every node gets a label ≥ 0. The cleanup loop then sets cluster_labels[mask] = -1 for clusters failing the cuts — fine. But the docstring says BFS "preserves completeness (all hits join clusters)", which conflicts with this cleanup pass. Either skip cleanup when bfs_expand_cut is set, or update the docstring.

7. node_saliency_metrics interprets labels inconsistently with the rest of the code

targets = (y_node >= 0).cpu().numpy().astype(int)

But y_node is built in dataset.py as 0/1 (singleton/multi-hit), never -1. So y_node >= 0 is always 1 → trivially everything is "signal", and the metric is meaningless. The metric must use (y_node == 1), as demonstrated below.
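
A two-line demonstration of the degeneracy (toy labels, as built in dataset.py):

    import torch

    y_node = torch.tensor([0, 1, 1, 0])         # 0 = singleton, 1 = multi-hit; never -1
    broken = (y_node >= 0).numpy().astype(int)  # [1, 1, 1, 1] -- every node counts as signal
    fixed = (y_node == 1).numpy().astype(int)   # [0, 1, 1, 0] -- the intended targets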

8. compute_class_weights pos_weight may double-count with negative subsampling at the trainer callsite

Not visible in the diff snippet I have, but worth verifying: the pos_weight is described in the docstring as neg/pos, which upweights positives. The config for Stage 1 has lambda_edge=1.0, neg_pos_ratio=5. If the trainer applies pos_weight = neg/pos AND subsamples negatives 5:1, the positive class will be heavily overweighted. Please confirm the trainer doesn't double-count.


🟡 Maintainability / Polish

9. Two truth-label modules with overlapping names

Both src/data/truth_labels.py and src/data/truth_labels_primary.py are added, but dataset.py only imports truth_labels_primary. If truth_labels.py (the legacy version) is unused, remove it. If it's kept for a v1 reproduction path, document that in its module docstring.

10. CSV files committed with CRLF line endings

crystal_geometry.csv and crystal_neighbors.csv have \r\n line endings (visible as \r\n in the patch). The README and Python code parse them fine, but mixing CRLF into a repo that's otherwise LF will trigger spurious diffs. Add a .gitattributes entry (*.csv text eol=lf) or convert in place.

11. crystal_map_raw.csv is 2,741 lines and serves only as a "source of truth"

The geometry/README.md says crystal_geometry.csv and crystal_neighbors.csv are derived from it, but the derivation script isn't in this PR. Either include the derivation script or drop the raw file (it's never read by any code in the PR).

12. analysis_for_sophie.py is checkpoint-dependent and has no CLI

The script hardcodes:

"outputs/runs/simple_edge_net_v2/checkpoints/best_model.pt"
"outputs/runs/calo_cluster_net_v2_stage1/checkpoints/best_model.pt"

…and data/processed/test.pt. It will fail at import time for anyone without those files. Either make paths CLI-configurable or gate behind argparse with clear errors. The filename also bakes in a colleague's name — consider renaming to e.g. analysis_pileup_comparison.py.

13. README references missing scripts/docs

The README mentions:

  • scripts/make_splits.py (not in PR)
  • scripts/pack_graphs.py (not in PR)
  • scripts/tune_threshold.py (not in PR)
  • scripts/evaluate_test.py (not in PR)
  • scripts/failure_audit.py (not in PR)
  • scripts/validate_onnx.py (not in PR)
  • scripts/smoke_test_env.py (not in PR)
  • scripts/export_norm_stats.py (referenced by tests, but file itself not in this PR's file list)
  • scripts/export_onnx.py (referenced by tests, same)
  • setup_env.sh (not in PR)
  • docs/onnx_deployment.md (not in PR)

These create dangling links in docs and a non-functional pipeline end-to-end. Either include them in this PR or mark them clearly as TODO/future-work.

14. Tests import scripts.export_onnx / scripts.export_norm_stats that aren't in the PR

tests/test_export_onnx.py and tests/test_export_norm_stats.py will fail collection with ImportError. CI will go red on day one. Either land the script files in this PR or remove the tests until they do.

15. pack_graphs.py is missing even though the packed-file loader exists

dataset.py has logic for packed_path (train.pt / val.pt / test.pt) but the script that produces these isn't in the PR. Add it or document how to produce them.

16. README "Versions and Provenance" table has TBD for both commits

Once this PR merges, the commit SHAs are known — fill them in before / after merge so the table is useful.

17. min_energy_mev=10.0 in default.yaml but 0.0 in analysis_for_sophie.py

Different scripts pass different defaults for the same physics cut. Centralize this (read from config) so the analysis matches deployment.

18. Acknowledgements section discloses Claude usage

This is fine and laudable — just flagging that the Mu2e collaboration may want a policy statement. Not a code issue.


🟢 Things I like

  • Clear module-level docstrings on every file.
  • Sensible separation: data/, geometry/, models/, inference/, training/.
  • Deploy wrappers (*_deploy.py) explicitly designed for ONNX tracing — good idea.
  • Unit tests for graph builder cover empty / single-hit / time-filter / distance-filter edge cases.
  • Class-balanced loss with edge mask is properly implemented (modulo the pos_weight question above).
  • README is unusually thorough; documents the deployed recipe (CCN+BFS10 at 10 MeV) and the physics metrics with concrete numbers.

Suggested blockers before merge

  1. Fix CI: either add the missing scripts/export_onnx.py and scripts/export_norm_stats.py, or delete the tests that import them.
  2. De-hardcode /exp/mu2e/.../wzhou2/... from build_all_graphs.sh, build_graphs.py --split all, evaluate_cluster_physics.py, and analysis_for_sophie.py.
  3. Fix _cap_degree symmetry bug (item 3) — this affects training data and is silent.
  4. Fix node_saliency_metrics label interpretation (item 7) — the metric is currently uninformative.
  5. Resolve the time-definition mismatch (item 5) so quoted physics residuals are reproducible against the C++ side.

Happy to dig deeper into any specific file if useful.
