Some similar tests (as per SEQTaRget) and a few optimizations#51

Merged
remlapmot merged 26 commits into main from devel-2026-06-04
Jun 9, 2026
@remlapmot remlapmot commented Jun 8, 2026

Contributor

This PR:

  • Bumps a workflow in CI
  • Warns when the numerator and denominator weight models use identical covariates
  • Adds real tests for several options (like those just implemented for SEQTaRget)
  • Adds another attempt to fix the readthedocs navbar height
  • Adds warm starts to the glum bootstrap fits
  • Caches the patsy design_info across bootstrap outcome fits
  • Uses integer IDs through the bootstrap path
  • Skips the pl.from_pandas round-trip in the weighted fit block
  • Uses a fixed default seed when none is supplied

Together these changes have cut the render time of our short course practical from about 7:30 min to roughly 6:00 min (still slower than the R version, which takes just over 4 min).

@remlapmot remlapmot requested a review from ryan-odea June 8, 2026 12:24

@ryan-odea ryan-odea left a comment

Collaborator


One stylistic comment - we might want to control how the class is validated rather than putting everything into __init__. Dataclasses have a native __post_init__ method. Perhaps we can integrate something similar into our primary class.
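
As a rough illustration of this suggestion, validation can live in `__post_init__` so that it runs after the generated `__init__` has assigned the fields. The `Options` class and its checks below are hypothetical stand-ins, not the project's actual class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Options:
    """Hypothetical options container; field names are illustrative only."""

    method: str = "ITT"
    weighted: bool = False
    weight_min: float = 0.0
    weight_max: Optional[float] = None

    def __post_init__(self) -> None:
        # Called automatically once the generated __init__ has set the fields,
        # which keeps validation separate from attribute assignment.
        if self.weight_min < 0:
            raise ValueError("weight_min must be non-negative")
        if self.weight_max is not None and self.weight_max < self.weight_min:
            raise ValueError("weight_max must be >= weight_min")
```

Constructing `Options(weight_min=1.0, weight_max=0.5)` raises a `ValueError` before the object is ever used.
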

Comment thread pySEQTarget/helpers/_bootstrap.py
Comment thread pySEQTarget/SEQuential.py Outdated
@remlapmot

Contributor Author

Thanks - will come back to this tomorrow.

Just added another small thing to make this more like R: R resets its seed after a run if none is specified, so I have added that to this branch.

remlapmot and others added 24 commits June 9, 2026 10:18
…iates

When weighted is True and method is not ITT, warn if the resolved numerator and denominator formulas are identical. In that case the stabilized weights all equal 1 (i.e. no weighting), which is almost always a typo in the denominator - the denominator should usually include the time-varying confounders that the numerator omits.

The check runs in SEQuential.__init__ after the numerator/denominator defaults are filled in by _numerator()/_denominator(), so it catches both user-supplied identical strings and the (rare) case where the defaults coincide. It is gated on method != "ITT" because under ITT no treatment-weight models are fit, so the warning would be misleading there.

Add UserWarning tests for the positive case (identical -> warns), differing num/denom, ITT (no warning even if identical), and weighted=False (no warning).
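
A minimal sketch of the gating described above, assuming a plain string comparison of the resolved formulas. The helper name `check_weight_formulas` and its signature are hypothetical; in the PR the check runs inside SEQuential.__init__:

```python
import warnings


def check_weight_formulas(numerator, denominator, weighted, method):
    """Warn when identical formulas would make every stabilized weight 1.

    Hypothetical helper; returns True if a warning was issued.
    """
    # No treatment-weight models are fit when unweighted or under ITT,
    # so a warning would be misleading in those cases.
    if not weighted or method == "ITT":
        return False
    if numerator == denominator:
        warnings.warn(
            "numerator and denominator weight formulas are identical; "
            "stabilized weights will all equal 1",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```
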
Add tests asserting selection_random=True keeps all treated trial-starts (tx_init_bas == 1) and subsamples control trial-starts (tx_init_bas == 0) to int(selection_sample * N_controls), and that the same seed produces an identical expanded DT.
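
The selection rule these tests assert can be sketched in plain Python. This is only an illustration on a list of dicts: the function name and row layout are assumptions, and the package itself works on polars frames:

```python
import random


def subsample_controls(rows, selection_sample, seed):
    """Keep every treated trial-start and a random fraction of the controls."""
    treated = [r for r in rows if r["tx_init_bas"] == 1]
    controls = [r for r in rows if r["tx_init_bas"] == 0]
    n_keep = int(selection_sample * len(controls))
    # A local generator makes the draw depend only on the seed passed in.
    rng = random.Random(seed)
    return treated + rng.sample(controls, n_keep)


# Toy data: 5 treated and 15 control trial-starts.
rows = [{"id": i, "tx_init_bas": 1 if i % 4 == 0 else 0} for i in range(20)]
first = subsample_controls(rows, selection_sample=0.5, seed=42)
second = subsample_controls(rows, selection_sample=0.5, seed=42)
```
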
Add a test verifying that weight_min/weight_max actually truncate the weight vector reaching the outcome GLM (it's applied inside _outcome_fit and doesn't show up in self.DT or weight_stats, so it has to be checked via the fit). Two clamp bands entirely above the real weight range (SEQdata weights span ~0.5-2) collapse every weight to a constant, and a GLM is invariant to a uniform scaling of its weights, so two such fits must be identical while a genuinely varying-weight fit must differ.
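
The invariance argument can be checked on a toy problem. The sketch below uses weighted least squares for a single covariate as a stand-in for the outcome GLM (an assumption made to keep it dependency-free), with made-up data and weights:

```python
def clip(weights, lo, hi):
    """Truncate each weight to the band [lo, hi]."""
    return [min(max(w, lo), hi) for w in weights]


def wls_fit(x, y, w):
    """Weighted least-squares (intercept, slope) for y ~ 1 + x."""
    sw = sum(w)
    xbar = sum(wi * xi for wi, xi in zip(w, x)) / sw
    ybar = sum(wi * yi for wi, yi in zip(w, y)) / sw
    sxy = sum(wi * (xi - xbar) * (yi - ybar) for wi, xi, yi in zip(w, x, y))
    sxx = sum(wi * (xi - xbar) ** 2 for wi, xi in zip(w, x))
    slope = sxy / sxx
    return ybar - slope * xbar, slope


x = [0, 1, 2, 3, 4, 5]
y = [0.1, 0.9, 2.3, 2.8, 4.4, 4.6]
w = [0.5, 0.8, 1.0, 1.3, 1.7, 2.0]  # made-up weights spanning roughly 0.5-2

fit_a = wls_fit(x, y, clip(w, 5.0, 6.0))    # every weight is raised to 5.0
fit_b = wls_fit(x, y, clip(w, 10.0, 12.0))  # every weight is raised to 10.0
fit_raw = wls_fit(x, y, w)                  # genuinely varying weights
```

Both clamped fits see constant weights that differ only by a factor of two, so their coefficients agree up to floating-point error, while `fit_raw` gives a different slope.
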
Add a test verifying that weight_p99=True is equivalent to explicitly setting weight_min/weight_max to the p01/p99 values reported in weight_stats, and that it differs from an untruncated weighted fit. _weight_stats computes those percentiles from the unclipped weight column and then mutates self.weight_min/self.weight_max when weight_p99=True, so the explicit run must produce identical outcome-model coefficients.
Add a test verifying that followup_include and trial_include each add or drop their corresponding outcome-model terms (and their squares) in the fitted coefficient names. These flags control whether the follow-up/trial terms enter the outcome-model formula, so the effect is directly observable in the params index of the fitted outcome model.
Add a test verifying that followup_class=True encodes follow-up as a categorical covariate: the outcome model loses the linear followup and followup_sq coefficients and instead gains one patsy dummy 'followup[T.<n>]' per non-reference follow-up level, with the dummy count equal to n_unique(followup) - 1. Set followup_include=False since it is exclusive with followup_class.
Add a test verifying that weight_lag_condition=True (default) restricts each treatment arm's denominator weight model to its own tx_lag stratum (per-arm nobs differ and partition the full data), while weight_lag_condition=False fits both arms on the full data (equal nobs per arm). Uses statsmodels' .nobs on each per-arm denominator model.
Add a test verifying that followup_min/followup_max actually filter the expanded data to the requested [followup_min, followup_max] window. Use expand_only=True so the DT is returned without any fit step, making the clamp directly visible on the polars frame: the unrestricted expansion spans past [3, 10] while the restricted run is clamped to exactly that interval and has fewer rows.
Add a test verifying that weight_eligible_colnames actually subsets each treatment arm's weight model to rows where its named indicator column is 1 (_get_subset_for_level). Adds a balanced welig column (N > median(N)) to the input data, runs with and without weight_eligible_colnames=["welig", "welig"], and asserts that both arms' denominator-model nobs drop below the unfiltered baseline.
_outcome_fit already caches the main fit's coefficients in (m.params.values, list(m.model.exog_names)) and replays them as start_params on bootstrap replicates - but the glum branch was dropping the argument on the floor, so every bootstrap outcome fit was paying full cold-start IRLS even though glum's GeneralizedLinearRegressor accepts a start_params array of shape (n_features + 1,) with the intercept first - the same shape statsmodels' params has.

Pass start_params through from _outcome_fit to _fit_glum, and inside _fit_glum apply it as the regressor's init only when the cached (values, names) tuple aligns column-for-column with the patsy design matrix just built for this replicate. A bootstrap resample can drop a categorical level and shift the column structure, in which case the cached coefs are meaningless and using them as init would derail coordinate descent - that mirrors the existing guard on the statsmodels path. Cold-start behaviour is unchanged (start_params=None continues to leave glum at its default init).
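
The alignment guard might look roughly like the following. The helper name `aligned_start_params` is hypothetical and the actual glum call is omitted; the sketch only shows when cached coefficients are reused:

```python
def aligned_start_params(cached, design_columns):
    """Return cached coefficients only if they match the new design matrix.

    `cached` is a (values, names) tuple from the main fit, or None.
    """
    if cached is None:
        return None  # cold start: leave the solver at its default init
    values, names = cached
    if list(names) != list(design_columns):
        # A resample dropped or reordered a column, so the cached
        # coefficients no longer line up with this design matrix.
        return None
    return list(values)


cached = ([0.1, -0.3, 0.7], ["Intercept", "sex[T.1]", "followup"])
same = aligned_start_params(cached, ["Intercept", "sex[T.1]", "followup"])
shifted = aligned_start_params(cached, ["Intercept", "followup"])
```
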
_fit_glum now accepts an optional design_cache dict keyed by formula. On a hit the cached (y_design_info, X_design_info) are re-applied to the bootstrap replicate via patsy.build_design_matrices, skipping the formula parse and the model.frame-style rebuild that patsy.dmatrices does on the first call. On a miss patsy.dmatrices runs as before and the result is stored. Cold-start behaviour with design_cache=None is unchanged.

_outcome_fit initialises self._patsy_design_cache = {} on the main-fit pass and reuses the same dict for every bootstrap replicate, mirroring the lifecycle of _outcome_start_params. The same caching pattern is already used on the predict path in _survival_pred.py (_build_design_matrix(outcome_dinfo, data)); this brings the fit path in line.

A useful side effect: the cached dinfo freezes the categorical column structure to the main fit's columns, so a bootstrap resample that drops a categorical level still produces a design matrix with the same columns (the missing dummy column becomes all-zero rows). That makes the warm-start guard from the previous commit trivially satisfied for every replicate and lets glum use the cached coefs as init unconditionally.
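
To see why a frozen column structure helps, here is a plain-Python imitation of treatment-coded dummies. It does not call patsy; the helpers are illustrative only and the column naming just mirrors the 'followup[T.<n>]' style:

```python
def dummy_columns(levels):
    """Column names for treatment coding: one per non-reference level."""
    return [f"followup[T.{lv}]" for lv in levels[1:]]


def encode(values, levels):
    """Dummy-encode values against a fixed, ordered list of levels."""
    return [[1 if v == lv else 0 for lv in levels[1:]] for v in values]


main_fit = [0, 1, 2, 3, 2, 1]
resample = [0, 1, 3, 3, 1, 0]  # level 2 happens to be missing here

frozen_levels = sorted(set(main_fit))  # captured once, on the main fit
fresh_levels = sorted(set(resample))   # what re-deriving per replicate gives

X_frozen = encode(resample, frozen_levels)
```

With `frozen_levels` the resample still yields three dummy columns, the one for level 2 being all zeros, whereas `dummy_columns(fresh_levels)` has only two and would no longer line up with the cached coefficients.
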

Add tests for the glum design-info cache

- test_glum_design_cache_avoids_reparsing_on_bootstrap: spy on patsy.dmatrices and patsy.build_design_matrices, run an ITT bootstrap so only the outcome formula reaches _fit_glum, assert the outcome formula is parsed exactly once on the main fit and that build_design_matrices is called for every bootstrap replicate, and confirm the cache survives on self._patsy_design_cache
- test_glum_design_cache_matches_no_cache_outcome_coefs: run the bootstrap pipeline with the cache enabled (default) and with it forcibly disabled (by monkeypatching _fit_glum to drop design_cache), assert per-replicate outcome coefficients agree within glum's coordinate-descent tolerance (rel=1e-2, abs=1e-4)
expand() previously cast id_col to Utf8 on both self.DT and self.data, which made every downstream join, groupby and over() partition take the string-hashing path - 3-5x slower per op than int64 on polars. The cast was there solely to support the bootstrap-resample ID, which used string concatenation ("{orig_id}_{replicate}") so that each duplicated subject row could be distinguished and then later recovered via str.replace in _weight_bind.

Drop the Utf8 cast in expand() so the user's native ID dtype (Int64 for the bundled data) flows through. _prepare_boot_data now builds the resampled ID as orig_id * id_mult + replicate in Int64 arithmetic, where id_mult = max(id_counts) + 1 makes the (orig, rep) pair bijective into a single int. The multiplier is stashed on self._boot_id_mult so _weight_bind can recover the original ID via `id // id_mult` to join back to the un-resampled WDT. Non-integer ID columns (e.g. user-supplied string IDs) fall back to the original "{orig}_{rep}" string-concat path and the existing str.replace recovery, gated on dtype.

Result on the bundled SEQdata censoring + 20-bootstrap workload: bootstrap iteration rate ~14 -> ~20.6 it/s, fit time 1.52 -> 1.04s (1.46x). The win scales with frame size.

Tests verify (1) expand() preserves Int64 dtype on DT and data, (2) the integer-ID path produces resampled IDs whose (id // mult, id - orig * mult) decomposition uniquely recovers the (orig, rep) pair, and (3) the string-ID fallback still works.
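
The integer encoding can be sketched in a few lines of plain Python. The function names are hypothetical, and the sketch assumes non-negative integer IDs and a replicate index that counts from 0 within each subject:

```python
def encode_ids(orig_ids):
    """Give each resampled copy of a subject a unique integer ID."""
    counts = {}
    for oid in orig_ids:
        counts[oid] = counts.get(oid, 0) + 1
    # Strictly larger than any replicate index, so the pair packs uniquely.
    id_mult = max(counts.values()) + 1

    seen = {}
    encoded = []
    for oid in orig_ids:
        rep = seen.get(oid, 0)  # 0, 1, 2, ... for repeated draws of oid
        seen[oid] = rep + 1
        encoded.append(oid * id_mult + rep)
    return encoded, id_mult


def decode_id(new_id, id_mult):
    """Recover (original ID, replicate index) from an encoded ID."""
    return new_id // id_mult, new_id % id_mult


sampled = [3, 7, 3, 3, 12, 7]  # subject 3 drawn three times, 7 twice
encoded, id_mult = encode_ids(sampled)
```

Because every replicate index is below `id_mult`, integer division recovers the original ID and the remainder recovers the replicate.
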
The weighted fit block in SEQuential.fit() was: WDT = WDT.to_pandas(); ...fits...; WDT = pl.from_pandas(WDT); _weight_predict(WDT). The weight-fit helpers (_fit_LTFU, _fit_visit, _fit_numerator, _fit_denominator) take a pandas frame because they use pandas-style indexing and pass the frame to glum/statsmodels, but they store the fitted models on `self` rather than mutating WDT - the pl.from_pandas() rebuild was reading data that hadn't changed.

Hold both WDT_pl (polars) and WDT_pd (pandas) instead. Pass WDT_pd to the fits, then drop it and pass the original WDT_pl to _weight_predict / _weight_bind. Eliminates one big polars-from-pandas conversion per replicate; pandas->polars is the slower direction so this is the larger half of the round-trip. The absolute saving scales with WDT size - small on SEQdata (a few tens of ms), more like 100-200ms per replicate on the short-course practical's expanded data.

Add a guard test that spies on _weight_predict and asserts it receives a polars.DataFrame, so a regression that re-introduced the round-trip would fail loudly.
The hazard ratio is estimated by g-formula Monte-Carlo simulation (rng.binomial in _hazard_handler, then a Cox fit), so it depends on the RNG. When no seed was given, __init__ set self._rng = np.random (the global, never-reseeded generator), so the hazard estimate changed silently between otherwise identical runs, even with selection_random disabled and bootstrap_nboot = 0.

Mirror SEQTaRget (R), where SEQopts captures .Random.seed (fixed in a fresh process) so an unseeded run is deterministic across runs. Fall back to a fixed default seed (_DEFAULT_SEED = 0) instead of the global np.random, record it on self.seed, and build self._rng from it. The `if self.seed is not None` reseed guards in _hazard.py and _bootstrap.py now always fire, making hazard and bootstrap estimates reproducible.

Add tests covering the unseeded path in tests/test_reproducibility.py: two unseeded runs are deterministic, the seed is concrete and stable before/after a run, and the recorded seed reproduces the hazard ratio.
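
The fallback pattern can be illustrated without NumPy. The sketch below uses the standard library's `random.Random` in place of a NumPy generator, and the `Simulation` class is a hypothetical stand-in; only `_DEFAULT_SEED = 0` comes from the commit message:

```python
import random

_DEFAULT_SEED = 0


class Simulation:
    """Hypothetical stand-in for a class that owns its random generator."""

    def __init__(self, seed=None):
        # Fall back to a fixed seed and record it, so an unseeded run
        # can be reproduced later from the recorded value.
        self.seed = _DEFAULT_SEED if seed is None else seed
        self._rng = random.Random(self.seed)

    def draw_events(self, p, n):
        """Simulate n Bernoulli(p) events from the instance's own generator."""
        return [1 if self._rng.random() < p else 0 for _ in range(n)]


run_a = Simulation().draw_events(0.3, 50)
run_b = Simulation().draw_events(0.3, 50)
```

Two unseeded instances produce the same draws, and passing the recorded `seed` explicitly reproduces them as well.
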

@ryan-odea ryan-odea left a comment

Collaborator


Looks good!

@remlapmot

Contributor Author

Brill - thanks Ryan.

Added tests on Linux (admittedly I forgot that the GHA Linux runners don't have GPUs - but I think JAX tests still run on CPU). Guarded the JAX tests to run just when it's installed.

@remlapmot remlapmot merged commit cfa35ef into main Jun 9, 2026
11 checks passed
@remlapmot remlapmot deleted the devel-2026-06-04 branch June 9, 2026 09:40
2 participants