Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions contrastive-pretraining/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,23 @@ Column names must match the pathology labels in `--pathologies_file`. The pre-co

A supervised baseline that freezes the pretrained MR-RATE encoder and trains only a linear classifier on top of its pooled visual features. This is the standard probe used to measure how well the contrastive representation captures the downstream pathology labels, separate from the zero-shot prompt quality.

The probe is two scripts run in order:
The probe is two (optionally three) scripts run in order:

1. **`extract_features.py`** — encode every subject once with the frozen encoder and dump features to disk. This is the slow step; it runs the same model + preprocessing path as `inference.py`. Run it once per split.
2. **`linear_probe.py`** — train `nn.Linear(dim_latent, num_classes)` with `BCEWithLogitsLoss` on the cached features, pick the best epoch by val mean-AUROC, report test metrics through `eval.evaluate_internal` (identical AUROC pipeline to inference).
3. **`relabel_features.py`** *(optional)* — swap in a different label set on top of already-cached features **without re-encoding**. See [Reusing features for a different label set](#reusing-features-for-a-different-label-set-no-re-extract).

Why precompute features: the 3D encoder is the expensive part. Once it is frozen, every training epoch sees the same features, so running the encoder once and then training the linear head on cached `.npy` files is orders of magnitude faster than re-encoding each epoch — the standard CLIP / SimCLR / DINO linear-probe recipe.
Why precompute features: the 3D encoder is the expensive part. Once it is frozen, every training epoch sees the same features, so running the encoder once and then training the linear head on cached `.npy` files is orders of magnitude faster than re-encoding each epoch — the standard CLIP / SimCLR / DINO linear-probe recipe. The corollary: **the labels are not part of the features**, so changing the label set never requires re-extraction — only the cheap `linear_probe.py` step is repeated (see step 3).

### Label sets

The class count is **derived from the labels CSV at runtime** — `extract_features.py` sets `num_classes = len(label_columns)` and writes `label_names.json` from the CSV header, and `linear_probe.py` sizes `nn.Linear(dim, n_classes)` from that. Nothing is hardcoded, so any labels CSV (`study_uid` + binary class columns) is a drop-in `--labels_file` with no code change. The ready-made set lives under `scripts/eval_labels/`:

| Folder | Classes | Ground-truth rule |
|--------|--------:|-------------------|
| `splits_merged_majority/` (`mrrate_merged_labels.csv`) | 14 | 3-model majority (Claude Opus 4.7 + GPT-5.5 + Nemotron-3 Super 120B; positive when ≥2 of the available votes agree), then collapsed into the neuroradiologist's clinical groups (8 pathophysiology `PP_*` + 6 imaging-phenotype `BP_*`) |

It pairs a labels CSV (`study_uid` + 14 binary class columns) with a `splits.csv` (`study_uid,split`). Reproduce it with `scripts/eval_labels/build_merged_group_labels.py --source majority` (the script also supports `--source {raw,csv32}` for 2-model agreement variants if you need them).

### Step 1 — Cache features

Expand Down Expand Up @@ -590,9 +601,38 @@ Outputs in `--results_dir`:

The best epoch by val mean-AUROC is restored before test evaluation. Single-class columns (all 0 or all 1 in the test split) are gracefully reported as `NaN` and excluded from the macro mean.

### Reusing features for a different label set (no re-extract)

Because labels are not baked into the encoder features, you only ever run `extract_features.py` **once**. To probe a different label CSV — a different ground-truth rule, or a different grouping — relabel the cached features instead of re-encoding:

```bash
# (once) cache features
for SPLIT in train val test; do
python scripts/extract_features.py \
--encoder vjepa2 --fusion_mode late --pooling_strategy simple_attn \
--weights_path ./mr_rate_results/MrRate.5000.pt \
--data_folder /path/to/data --jsonl_file /path/to/reports.jsonl \
--labels_file scripts/eval_labels/splits_merged_majority/mrrate_merged_labels.csv \
--splits_csv scripts/eval_labels/splits_merged_majority/splits.csv \
--split $SPLIT --normalizer zscore --out_dir ./lp_features
done

# train the head on those features
python scripts/linear_probe.py --features_dir ./lp_features --results_dir ./lp_results

# later: probe ANY other labels CSV on the SAME features — instant, no encoder pass
python scripts/relabel_features.py \
--features_dir ./lp_features \
--labels_file /path/to/other_labels.csv \
--out_dir ./lp_features_other
python scripts/linear_probe.py --features_dir ./lp_features_other --results_dir ./lp_results_other
```

`relabel_features.py` symlinks `features_<split>.npy` + `subject_ids_<split>.txt` from `--features_dir` and rebuilds only `labels_<split>.npy` (in the exact subject order of the source) and `label_names.json` from `--labels_file`. If a cached subject is missing from the new labels CSV it **errors out** rather than silently misaligning rows. Use `--copy` to copy the feature files instead of symlinking (e.g. to move the dir to another machine).

### Notes

- Both scripts are single-process. The linear head is `Linear(512, 32) ≈ 16K params` over ~180 MB of cached features DDP overhead would dominate. For feature extraction, shard externally by running one job per split, or by splitting `splits.csv` into chunks and concatenating the resulting `.npy` files.
- Both scripts are single-process. The linear head is tiny — `Linear(512, num_classes)`, e.g. ~16K params for 32 classes or ~7K for the 14 merged groups — over ~180 MB of cached features, so DDP overhead would dominate. For feature extraction, shard externally by running one job per split, or by splitting `splits.csv` into chunks and concatenating the resulting `.npy` files.
- `--encoder`, `--fusion_mode`, `--pooling_strategy`, and `--dim_latent` must match the values used when the checkpoint was trained; if they don't, `_load_and_verify` will surface it loudly instead of silently loading garbage.

## Testing
Expand Down
7 changes: 7 additions & 0 deletions contrastive-pretraining/scripts/eval_labels/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Large derived / convenience CSVs — the linear probe only consumes
# mrrate_*labels.csv + splits.csv. These are regenerable with
# build_merged_group_labels.py, so they are kept out of git to avoid bloat.
**/all.csv
**/train.csv
**/val.csv
**/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
"""Build the neuroradiologist's merged-group labels in a SINGLE combined CSV,
recomputing the Opus(Claude) AND GPT(OpenAI) agreement from the RAW prediction
JSONs over ALL 37 pathologies (so the 5 pathologies dropped in
build_agreement_splits.py are recovered, not ignored).

Two grouping schemes (from Bene, neuroradiology):
1) Pathophysiologie -> 8 groups, columns prefixed "PP_"
2) Bildphaenotyp -> 6 groups, columns prefixed "BP_"

A study is positive for a group iff it is positive for ANY member pathology
(logical OR). Per-pathology agreement uses the same strict AND rule as
build_agreement_splits.py:
both 1 -> 1 | any 0 -> 0 | one side missing -> present side | both missing -> 0(blank)

The two pathologies ungrouped in BOTH schemes (Empty sella syndrome,
Hyperostosis of skull) are intentionally excluded.

Inputs (same sources as build_agreement_splits.py, but NOTHING is dropped):
eval_set_predictions_5k.json -> claude_labels + gpt_labels
../../../remaining_eval/eval_set_predictions_chunk_0{0..4}.json
-> nvidia_opus47_labels + nvidia_gpt55_labels
splits_hf.csv -> study_uid -> split

Outputs (under splits_merged/):
mrrate_merged_labels.csv -- study_uid + 14 binary group cols
pathologies.json -- same 14 names, ready for --pathologies_file
group_definitions.json -- group -> members (all present now)
splits.csv -- batch_id,patient_uid,study_uid,split (kept UIDs only)
"""
from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path

HERE = Path(__file__).parent
PRED_5K = HERE / "eval_set_predictions_5k.json"
CHUNK_DIR = Path("/hnvme/workspace/b180dc51-sezgin/MR-RATE/remaining_eval")
# 92k chunks augmented with a 3rd model (Nemotron) for the "majority" mode.
CHUNK_DIR_NEM = Path("/hnvme/workspace/b180dc51-sezgin/remaining_eval_with_nemotron-2")
SPLITS_CSV = HERE / "splits_hf.csv"
# 32-column agreement labels (5 pathologies dropped) for the legacy "csv32" mode.
CSV32_LABELS = HERE / "splits_agreement" / "mrrate_labels.csv"
CSV32_SPLITS = HERE / "splits_agreement" / "splits.csv"

PP_GROUPS = [
("PP_Cerebrovascular", ["Cerebral infarction", "Cerebral hemorrhage", "Lacunar infarct",
"Silent micro-hemorrhage of brain", "Cavernous hemangioma",
"Subdural intracranial hemorrhage", "Intracranial aneurysm", "Watershed infarct"]),
("PP_Neoplastic", ["Metastatic malignant neoplasm to brain", "Intracranial meningioma",
"Glioma", "Pituitary adenoma", "Schwannoma"]),
("PP_Neurodegenerative", ["Cerebral atrophy", "Ventriculomegaly", "Cerebellar degeneration"]),
("PP_Spinal", ["Herniation of nucleus pulposus", "Spinal cord compression",
"Foraminal Spinal Stenosis", "Spinal stenosis", "Hemangioma of vertebral column"]),
("PP_Cystic_developmental", ["Arachnoid cyst", "Cyst of pineal gland",
"Structure of cave of septum pellucidum", "Mega cisterna magna", "Chiari malformation",
"Rathke's pouch cyst", "Choroid plexus cyst", "Lipoma of brain"]),
("PP_Infectious", ["Mastoiditis", "Chronic mastoiditis"]),
("PP_Inflammatory", ["Demyelinating disease of central nervous system"]),
("PP_Unspecific_bucket", ["Gliosis", "Cerebral edema", "Encephalomalacia"]),
]

BP_GROUPS = [
("BP_Atrophies", ["Cerebral atrophy", "Ventriculomegaly", "Cerebellar degeneration"]),
("BP_Contrast_enhancing_intracranial", ["Metastatic malignant neoplasm to brain",
"Intracranial meningioma", "Glioma", "Pituitary adenoma", "Schwannoma"]),
("BP_Infectious_lesions", ["Mastoiditis", "Chronic mastoiditis"]),
("BP_Edematous_lesions", ["Cerebral infarction", "Lacunar infarct", "Watershed infarct",
"Demyelinating disease of central nervous system"]),
("BP_Hemorrhagic_lesions", ["Cerebral hemorrhage", "Silent micro-hemorrhage of brain",
"Cavernous hemangioma", "Encephalomalacia"]),
("BP_Cystic_lesions", ["Arachnoid cyst", "Cyst of pineal gland",
"Structure of cave of septum pellucidum", "Mega cisterna magna",
"Rathke's pouch cyst", "Choroid plexus cyst"]),
]

ALL_GROUPS = PP_GROUPS + BP_GROUPS


def agreement(a: dict | None, b: dict | None, p: str) -> int:
"""Strict AND per pathology. Returns 1 (agree positive) else 0.

both1 -> 1 | any0 -> 0 | one missing -> present side | both missing -> 0.
"""
a = a or {}
b = b or {}
x, y = a.get(p), b.get(p)
if x is None and y is None:
return 0
if x is None:
return 1 if y == 1 else 0
if y is None:
return 1 if x == 1 else 0
return 1 if (x == 1 and y == 1) else 0


def load_split_map(path: Path) -> dict[str, dict]:
m: dict[str, dict] = {}
with path.open() as f:
for r in csv.DictReader(f):
m[r["study_uid"]] = {
"split": r["split"],
"batch_id": r.get("batch_id", ""),
"patient_uid": r.get("patient_uid", ""),
}
return m


def write_outputs(out_dir: Path, rows: list[tuple[str, list[int]]], split_map: dict) -> None:
"""rows = [(study_uid, [group values...])]; writes the 4 artifacts."""
out_dir.mkdir(exist_ok=True)
out_cols = [name for name, _ in ALL_GROUPS]
counts = {c: 0 for c in out_cols}
kept_uids: set[str] = set()
n_split = {"train": 0, "val": 0, "test": 0}

out_path = out_dir / "mrrate_merged_labels.csv"
per_split: dict[str, list[list]] = {"train": [], "val": [], "test": []}
with open(out_path, "w", newline="") as f:
w = csv.writer(f)
w.writerow(["study_uid"] + out_cols)
for uid, vals in rows:
w.writerow([uid] + vals)
kept_uids.add(uid)
split = split_map[uid]["split"]
n_split[split] += 1
per_split[split].append([uid] + vals)
for name, v in zip(out_cols, vals):
counts[name] += v

# Per-split CSVs (train.csv / val.csv / test.csv), same columns as the
# combined labels file.
for split, srows in per_split.items():
fname = "val.csv" if split == "val" else f"{split}.csv"
with open(out_dir / fname, "w", newline="") as f:
w = csv.writer(f)
w.writerow(["study_uid"] + out_cols)
w.writerows(srows)

paths_json = {}
for name, _ in ALL_GROUPS:
label = name.replace("PP_", "").replace("BP_", "").replace("_", " ")
paths_json[name] = {"positive": f"There is {label}", "negative": f"There is no {label}"}
(out_dir / "pathologies.json").write_text(json.dumps({"pathologies": paths_json}, indent=2))
(out_dir / "group_definitions.json").write_text(
json.dumps({name: ms for name, ms in ALL_GROUPS}, indent=2)
)
with open(out_dir / "splits.csv", "w", newline="") as f:
w = csv.writer(f)
w.writerow(["batch_id", "patient_uid", "study_uid", "split"])
for uid, meta in split_map.items():
if uid in kept_uids:
w.writerow([meta["batch_id"], meta["patient_uid"], uid, meta["split"]])

n = len(kept_uids)
print(f"\nWrote {out_path} ({n:,} studies, {len(out_cols)} label columns)")
print(f" kept per split: {n_split}")
print(f"{'column':40s} {'pos':>7s} {'%':>6s}")
for name, _ in ALL_GROUPS:
print(f"{name:40s} {counts[name]:7d} {100*counts[name]/n:5.1f}")


def build_from_csv32(out_dir: Path) -> None:
"""Legacy mode: OR-collapse the existing 32-column agreement labels.
The 5 dropped pathologies are absent and contribute nothing."""
split_map = load_split_map(CSV32_SPLITS)
with open(CSV32_LABELS) as f:
reader = csv.DictReader(f)
present = set(reader.fieldnames) - {"study_uid"}
rows = []
for row in reader:
uid = row["study_uid"]
if uid not in split_map:
continue
vals = []
for _, ms in ALL_GROUPS:
used = [m for m in ms if m in present]
vals.append(1 if any(float(row[m]) > 0 for m in used) else 0)
rows.append((uid, vals))
write_outputs(out_dir, rows, split_map)


def build_from_raw(out_dir: Path) -> None:
split_map = load_split_map(SPLITS_CSV)

pred_5k = json.loads(PRED_5K.read_text())
paths_all = pred_5k["pathologies"] # 37, in canonical order

# Validate every group member is a real pathology.
members = {m for _, ms in ALL_GROUPS for m in ms}
unknown = members - set(paths_all)
if unknown:
raise RuntimeError(f"Unknown group members: {sorted(unknown)}")

# study_uid -> (claude/opus labels, gpt labels). 5k wins over chunks.
pred: dict[str, tuple[dict, dict]] = {}
for it in pred_5k["items"]:
pred[it["study_uid"]] = (it.get("claude_labels") or {}, it.get("gpt_labels") or {})
print(f"loaded 5k preds: {len(pred):,}")

for i in range(5):
chunk = json.loads((CHUNK_DIR / f"eval_set_predictions_chunk_0{i}.json").read_text())
assert chunk["pathologies"] == paths_all, f"pathology mismatch in chunk {i}"
for it in chunk["items"]:
uid = it["study_uid"]
if uid in pred:
continue
pred[uid] = (it.get("nvidia_opus47_labels") or {}, it.get("nvidia_gpt55_labels") or {})
print(f"loaded chunk_0{i}: cumulative {len(pred):,}")

rows = []
no_pred = {"train": 0, "val": 0, "test": 0}
for uid, meta in split_map.items():
if uid not in pred:
no_pred[meta["split"]] += 1
continue
a, b = pred[uid]
agree = {p: agreement(a, b, p) for p in members}
vals = [1 if any(agree[m] for m in ms) else 0 for _, ms in ALL_GROUPS]
rows.append((uid, vals))
print(f"no-pred skipped: {no_pred}")
write_outputs(out_dir, rows, split_map)


def majority(votes: list, p: str) -> int:
"""Strict majority of PRESENT votes for pathology p.
pos*2 > present -> 1. 3 present -> need 2; 2 present -> need both; 1 -> need it."""
present = [d.get(p) for d in votes if (d or {}).get(p) is not None]
if not present:
return 0
pos = sum(1 for v in present if v == 1)
return 1 if pos * 2 > len(present) else 0


def build_majority(out_dir: Path) -> None:
"""3-way majority of Claude(Opus) + GPT + Nemotron, then OR into groups.
5k uses claude/gpt/nemotron; 92k uses the nemotron-augmented chunks."""
split_map = load_split_map(SPLITS_CSV)
pred_5k = json.loads(PRED_5K.read_text())
paths_all = pred_5k["pathologies"]
members = {m for _, ms in ALL_GROUPS for m in ms}

# study_uid -> (claude/opus, gpt, nemotron). 5k wins over chunks.
pred: dict[str, tuple[dict, dict, dict]] = {}
n_models = {0: 0, 1: 0, 2: 0, 3: 0}
for it in pred_5k["items"]:
pred[it["study_uid"]] = (
it.get("claude_labels") or {},
it.get("gpt_labels") or {},
it.get("nemotron_labels") or {},
)
print(f"loaded 5k preds: {len(pred):,}")

for i in range(5):
chunk = json.loads((CHUNK_DIR_NEM / f"eval_set_predictions_chunk_0{i}.json").read_text())
assert chunk["pathologies"] == paths_all, f"pathology mismatch in chunk {i}"
for it in chunk["items"]:
uid = it["study_uid"]
if uid in pred:
continue
pred[uid] = (
it.get("nvidia_opus47_labels") or {},
it.get("nvidia_gpt55_labels") or {},
it.get("nvidia_nemotron3_super_v3_labels") or {},
)
print(f"loaded chunk_0{i}: cumulative {len(pred):,}")

rows = []
no_pred = {"train": 0, "val": 0, "test": 0}
for uid, meta in split_map.items():
if uid not in pred:
no_pred[meta["split"]] += 1
continue
votes = pred[uid]
n_models[sum(1 for d in votes if d)] += 1
maj = {p: majority(votes, p) for p in members}
vals = [1 if any(maj[m] for m in ms) else 0 for _, ms in ALL_GROUPS]
rows.append((uid, vals))
print(f"no-pred skipped: {no_pred}")
print(f"models available per study: {n_models} (3 = full 3-way majority)")
write_outputs(out_dir, rows, split_map)


if __name__ == "__main__":
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--source", choices=["raw", "csv32", "majority"], default="raw",
help="raw = Claude AND GPT agreement over all 37 pathologies (default); "
"csv32 = collapse the existing 32-col labels (5 paths stay dropped); "
"majority = 3-way majority of Claude + GPT + Nemotron.")
ap.add_argument("--out-dir", default=None, help="output folder (default per source).")
a = ap.parse_args()
defaults = {"raw": "splits_merged", "csv32": "splits_merged_32col",
"majority": "splits_merged_majority"}
out = Path(a.out_dir) if a.out_dir else HERE / defaults[a.source]
{"raw": build_from_raw, "csv32": build_from_csv32, "majority": build_majority}[a.source](out)
Loading
Loading