Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Imports:
checkmate,
matrixStats (>= 0.52),
parallel,
posterior (>= 1.5.0),
posterior (>= 1.7.0),
stats
Suggests:
bayesplot (>= 1.7.0),
Expand Down
63 changes: 17 additions & 46 deletions R/psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -212,62 +212,33 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) {
S <- length(log_ratios_i)
# shift log ratios for safer exponentation
lw_i <- log_ratios_i - max(log_ratios_i)
khat <- Inf

if (enough_tail_samples(tail_len_i)) {
ord <- sort.int(lw_i, index.return = TRUE)
tail_ids <- seq(S - tail_len_i + 1, S)
lw_tail <- ord$x[tail_ids]
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) {
warning(
"Can't fit generalized Pareto distribution ",
"because all tail values are the same.",
call. = FALSE
)
} else {
cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values
smoothed <- psis_smooth_tail(lw_tail, cutoff)
khat <- smoothed$k
lw_i[ord$ix[tail_ids]] <- smoothed$tail
}
if (length(unique(tail(log_ratios_i, -tail_len_i))) == 1) {
Comment thread
jgabry marked this conversation as resolved.
Outdated
warning(
"Can't fit generalized Pareto distribution ",
"because all tail values are the same.",
call. = FALSE
)
}

smoothed <- posterior::ps_tail(
x = lw_i,
ndraws_tail = tail_len_i,
tail = "right",
are_log_weights = TRUE
)

lw_i <- smoothed$x
khat <- smoothed$k

# truncate at max of raw wts (i.e., 0 since max has been subtracted)
lw_i[lw_i > 0] <- 0
# shift log weights back so that the smallest log weights remain unchanged
lw_i <- lw_i + max(log_ratios_i)

list(log_weights = lw_i, pareto_k = khat)
list(log_weights = lw_i, pareto_k = if (is.na(khat)) Inf else khat)
}

#' PSIS tail smoothing for a single vector
#'
#' @noRd
#' @param x Vector of tail elements already sorted in ascending order.
#' @return A named list containing:
#' * `tail`: vector same size as `x` containing the logs of the
#' order statistics of the generalized pareto distribution.
#' * `k`: scalar shape parameter estimate.
#'
psis_smooth_tail <- function(x, cutoff) {
len <- length(x)
exp_cutoff <- exp(cutoff)

# save time not sorting since x already sorted
fit <- gpdfit(exp(x) - exp_cutoff, sort_x = FALSE)
k <- fit$k
sigma <- fit$sigma
if (is.finite(k)) {
p <- (seq_len(len) - 0.5) / len
qq <- qgpd(p, k, sigma) + exp_cutoff
tail <- log(qq)
} else {
tail <- x
}
list(tail = tail, k = k)
}


#' Calculate tail lengths to use for fitting the GPD
#'
#' The number of weights (i.e., tail length) used to fit the generalized Pareto
Expand Down
64 changes: 64 additions & 0 deletions tests/testthat/_snaps/psis.md
Original file line number Diff line number Diff line change
Expand Up @@ -4801,6 +4801,70 @@
Warning:
Not enough tail samples to fit the generalized Pareto distribution in some or all columns of matrix of log importance ratios. Skipping the following columns: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... [22 more not printed].
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Comment thread
VisruthSK marked this conversation as resolved.
Outdated
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Can't fit generalized Pareto distribution because ndraws_tail is less than 5.
Warning:
Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
Output
Computed from 10 by 32 log-weights matrix.
Expand Down
10 changes: 0 additions & 10 deletions tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,3 @@ test_that("do_psis_i throws warning if all tail values the same", {
)
expect_equal(val$pareto_k, Inf)
})

test_that("psis_smooth_tail returns original tail values if k is infinite", {
# skip on M1 Mac until we figure out why this test fails only on M1 Mac
skip_if(Sys.info()[["sysname"]] == "Darwin" && R.version$arch == "aarch64")

xx <- c(1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4)
val <- suppressWarnings(psis_smooth_tail(xx, 3))
expect_equal(val$tail, xx)
expect_equal(val$k, Inf)
})
Loading