Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Generated by roxygen2: do not edit by hand

export(apply_target_location_exclusions)
export(assert_data_up_to_date)
export(check_authorized_users)
export(check_changes_for_autoapproval)
export(count_designated_models)
export(excluded_locations)
export(generate_hub_baseline)
export(generate_hub_ensemble)
export(generate_oracle_output)
Expand All @@ -23,7 +23,6 @@ export(get_nssp_col_name)
export(get_round_ids_vec)
export(get_target_data_type)
export(get_target_label)
export(included_locations)
export(is_ed_target)
export(is_hosp_target)
export(summarize_ref_date_forecasts)
Expand Down
75 changes: 1 addition & 74 deletions R/constants.R
Original file line number Diff line number Diff line change
@@ -1,74 +1 @@
#' Two digits FIPS codes for locations excluded from Hubs'
#' target data.
#'
#' Excludes Virgin Islands (78), Northern Mariana
#' Islands (69), Guam (66), American Samoa (60), and Minor
#' Outlying Islands (74).
#'
#' @export
excluded_locations <- c("78", "74", "69", "66", "60")

#' Two digits FIPS codes for locations included in Hubs'
#' target data.
#'
#' Includes 50 states, US national, DC, and Puerto Rico
#' (PR). Excludes Virgin Islands (78), Northern Mariana
#' Islands (69), Guam (66), American Samoa (60), and Minor
#' Outlying Islands (74).
#'
#' @export
included_locations <- c(
"01",
"02",
"04",
"05",
"06",
"08",
"09",
"10",
"11",
"12",
"13",
"15",
"16",
"17",
"18",
"19",
"20",
"21",
"22",
"23",
"24",
"25",
"26",
"27",
"28",
"29",
"30",
"31",
"32",
"33",
"34",
"35",
"36",
"37",
"38",
"39",
"40",
"41",
"42",
"44",
"45",
"46",
"47",
"48",
"49",
"50",
"51",
"53",
"54",
"55",
"56",
"72",
"US"
)
# constants used across hubhelpr functions.
176 changes: 176 additions & 0 deletions R/location_exclusions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#' Normalize excluded locations to a named list.
#'
#' Converts a character vector or named list of excluded
#' locations into a consistent named list format.
#' Validates that all abbreviations are valid US
#' state/territory abbreviations.
#'
#' @param excluded_locations NULL, character vector, or
#' named list of character vectors.
#'
#' @return Named list of character vectors, or NULL if
#' input is NULL or zero-length.
#' @noRd
normalize_excluded_locations <- function(excluded_locations) {
if (is.null(excluded_locations) || length(excluded_locations) == 0) {
return(NULL)
}
if (is.character(excluded_locations)) {
assert_valid_location_abbrs(excluded_locations)
return(list("all" = excluded_locations))
}
if (is.list(excluded_locations)) {
purrr::walk(excluded_locations, function(x) {
checkmate::assert_character(
x,
.var.name = "excluded_locations list values"
)
assert_valid_location_abbrs(x)
})
return(excluded_locations)
}
cli::cli_abort(
"{.arg excluded_locations} must be NULL, a character vector, or a named list."
)
}


#' Assert that location abbreviations are valid.
#'
#' Checks that all provided abbreviations are present
#' in the US location table (from forecasttools).
#' Errors with a message listing any invalid
#' abbreviations.
#'
#' @param abbrs Character vector of abbreviations to
#' validate.
#'
#' @return Invisible NULL. Called for side effects.
#' @noRd
assert_valid_location_abbrs <- function(abbrs) {
valid_abbrs <- forecasttools::us_location_table$abbr
invalid <- setdiff(abbrs, valid_abbrs)
if (length(invalid) > 0) {
cli::cli_abort(
"{.arg excluded_locations} contains invalid abbreviation{?s}: {.val {invalid}}."
)
}
}


#' Get excluded abbreviations for a specific target.
#'
#' Extracts the abbreviations that should be excluded
#' for a given target from a normalized exclusion list,
#' combining global ("all") exclusions with any
#' target-specific ones.
#'
#' @param normalized Named list as returned by
#' `normalize_excluded_locations()`.
#' @param target Character, the target name.
#'
#' @return Character vector of unique abbreviations to
#' exclude for this target.
#' @noRd
get_target_exclusions <- function(normalized, target) {
unique(c(normalized[["all"]], normalized[[target]]))
}


#' Apply target-specific location exclusions to a data
#' frame.
#'
#' Removes rows from a data frame based on
#' target-specific excluded location abbreviations.
#' Supports uniform exclusions (character vector applied
#' to all targets) and target-specific exclusions (named
#' list with target names as keys). Validates target
#' names against the targets present in the data.
#' Filters on the "target" and "location" columns via
#' anti-join.
#'
#' @param data Data frame with "target" and "location"
#' columns.
#' @param excluded_locations NULL, character vector, or
#' named list of US state/territory abbreviations to
#' exclude. If a character vector, locations are
#' excluded across all targets. If a named list, names
#' should be target names (or "all" for global
#' exclusions) mapping to character vectors of
#' abbreviations.
#'
#' @return Data frame with excluded rows removed.
#' @export
apply_target_location_exclusions <- function(
data,
excluded_locations
) {
normalized <- normalize_excluded_locations(excluded_locations)
if (is.null(normalized)) {
return(data)
}

data_targets <- unique(data$target)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not hub supported targets here? as before?

named_targets <- setdiff(names(normalized), "all")
unmatched <- setdiff(named_targets, data_targets)
if (length(unmatched) > 0) {
cli::cli_warn(
"{.arg excluded_locations} contains target{?s} not in data: {.val {unmatched}}."
)
}

exclusion_df <- purrr::map_df(data_targets, \(tgt) {
excl_abbrs <- get_target_exclusions(normalized, tgt)
if (length(excl_abbrs) == 0) {
return(tibble::tibble(target = character(), location = character()))
}
tibble::tibble(
target = tgt,
location = forecasttools::us_location_recode(excl_abbrs, "abbr", "hub")
)
})
Comment on lines +122 to +131
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will implicitly handle empty rows

Suggested change
exclusion_df <- purrr::map_df(data_targets, \(tgt) {
excl_abbrs <- get_target_exclusions(normalized, tgt)
if (length(excl_abbrs) == 0) {
return(tibble::tibble(target = character(), location = character()))
}
tibble::tibble(
target = tgt,
location = forecasttools::us_location_recode(excl_abbrs, "abbr", "hub")
)
})
exclusion_df <- exclusion_df <- dplyr::tibble(target = data_targets) |>
dplyr::mutate(
location = purrr::map(
target,
\(tgt) forecasttools::us_location_recode(
get_target_exclusions(normalized, tgt),
"abbr",
"hub"
)
)
) |>
tidyr::unnest_longer(location)


dplyr::anti_join(
data,
exclusion_df,
by = c("target", "location")
)
}


#' Filter data to included locations only.
#'
#' Only keeps rows where location is in the set of
#' valid US locations minus any excluded locations for
#' that target.
#'
#' @param data Data frame with "target" and "location"
#' columns.
#' @param excluded_locations NULL, character vector, or
#' named list of US state/territory abbreviations to
#' exclude.
#'
#' @return Data frame filtered to included locations.
#' @noRd
filter_to_included_locations <- function(
data,
excluded_locations
) {
normalized <- normalize_excluded_locations(excluded_locations)
all_valid_codes <- forecasttools::us_location_table$code

purrr::map_df(unique(data$target), \(tgt) {
if (!is.null(normalized)) {
excl_abbrs <- get_target_exclusions(normalized, tgt)
excl_codes <- forecasttools::us_location_recode(
excl_abbrs,
"abbr",
"hub"
)
included_codes <- setdiff(all_valid_codes, excl_codes)
} else {
included_codes <- all_valid_codes
}
dplyr::filter(data, .data$target == tgt, .data$location %in% included_codes)
})
}
Comment on lines +155 to +176
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a mismatch here between function name and argument. Name of the function makes it easy to misinterpret the argument.
I think rename the function name to filter_to_expected_location that takes in expected locations (default: forecasttools::us_location_table$code) and excluded locations (default: NULL).

Then create expected_df and exclusion_df

  expected_df <- tidyr::crossing(
    target = get_hub_supported_targets(),
    location = forecasttools::us_location_table$code
  )

  exclusion_df <- exclusion_df <- dplyr::tibble(target = data_targets) |>
    dplyr::mutate(
      location = purrr::map(
        target,
        \(tgt) forecasttools::us_location_recode(
          get_target_exclusions(normalized, tgt),
          "abbr",
          "hub"
        )
      )
    ) |>
    tidyr::unnest_longer(location)

  expected_target_location_df <- dplyr::anti_join(
    expected_df, exclusion_df
  )
  filtered <- dplyr::inner_join(
    data,
    expected_target_location_df,
    by = c("target", "location")
  )

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this, but I think it's clearer than the current approach and has same approach as apply_exclusion (https://github.com/CDCgov/hubhelpr/pull/178/files#r2990093161)
Open to other ideas if you have any?

93 changes: 11 additions & 82 deletions R/summarize_ref_date_forecasts.R
Original file line number Diff line number Diff line change
@@ -1,72 +1,3 @@
#' Normalize excluded locations to a named list.
#'
#' Converts a character vector or named list of excluded
#' locations into a consistent named list format.
#'
#' @param excluded_locations NULL, character vector, or
#' named list of character vector.
#'
#' @return Named list of character vectors.
#' @noRd
normalize_excluded_locations <- function(excluded_locations) {
if (is.null(excluded_locations)) {
return(list())
}
if (is.character(excluded_locations)) {
return(list("all" = excluded_locations))
}
if (is.list(excluded_locations)) {
return(excluded_locations)
}
cli::cli_abort(
"{.arg excluded_locations} must be NULL, a character vector, or a named list."
)
}


#' Build a target-location exclusion data frame.
#'
#' Constructs a tibble of target/location pairs to
#' exclude. Entries keyed by "all" are expanded into
#' one row per supported target. Errors if any named
#' targets in the exclusion list are not in
#' `supported_targets`.
#'
#' @param excluded_locations Named list as returned by
#' `normalize_excluded_locations()`.
#' @param supported_targets character vector of targets
#' the hub accepts, as returned by
#' `get_hub_supported_targets()`.
#'
#' @return A tibble with columns "target" and "location"
#' (hub codes).
#' @noRd
build_exclusion_df <- function(excluded_locations, supported_targets) {
named_targets <- setdiff(names(excluded_locations), "all")
invalid_targets <- setdiff(named_targets, supported_targets)
if (length(invalid_targets) > 0) {
cli::cli_abort(
"{.arg excluded_locations} contains unknown target{?s}: {.val {invalid_targets}}."
)
}

merged <- purrr::map(
purrr::set_names(supported_targets),
\(tgt) unique(c(excluded_locations[["all"]], excluded_locations[[tgt]]))
)

tibble::enframe(merged, name = "target", value = "location") |>
tidyr::unnest(cols = "location") |>
dplyr::mutate(
location = forecasttools::us_location_recode(
.data$location,
"abbr",
"hub"
)
)
}


#' Summarize forecast hub data for a specific reference date.
#'
#' This function generates a tibble of forecast data
Expand All @@ -84,13 +15,13 @@ build_exclusion_df <- function(excluded_locations, supported_targets) {
#' and "population". Adds population-based calculations.
#' @param horizons_to_include integer vector, horizons to
#' include in the output. Default: c(0, 1, 2).
#' @param excluded_locations character vector or named list
#' specifying US state abbreviations to exclude. If a
#' character vector, locations are excluded across all
#' targets. If a named list, names should be target names
#' (or "all" for global exclusions) mapping to character
#' vectors of abbreviations. Converted to hub codes
#' internally. Default: NULL.
#' @param excluded_locations NULL, character vector, or
#' named list of US state/territory abbreviations to
#' exclude. If a character vector, locations are excluded
#' across all targets. If a named list, names should be
#' target names (or "all" for global exclusions) mapping
#' to character vectors of abbreviations. Converted to
#' hub codes internally. Default: NULL (no exclusions).
#' @param targets character vector, target name(s) to filter
#' forecasts. If NULL (default), does not filter by target.
#' @param model_ids character vector of model IDs to include.
Expand All @@ -110,7 +41,6 @@ summarize_ref_date_forecasts <- function(
model_ids = NULL
) {
reference_date <- lubridate::as_date(reference_date)
excluded_locations <- normalize_excluded_locations(excluded_locations)

model_metadata <- hubData::load_model_metadata(
base_hub_path,
Expand All @@ -130,11 +60,10 @@ summarize_ref_date_forecasts <- function(
forecasttools::nullable_comparison(.data$model_id, "%in%", !!model_ids)
)

supported_targets <- get_hub_supported_targets(base_hub_path)
exclusion_df <- build_exclusion_df(excluded_locations, supported_targets)

current_forecasts <- current_forecasts |>
dplyr::anti_join(exclusion_df, by = c("target", "location"))
current_forecasts <- apply_target_location_exclusions(
current_forecasts,
excluded_locations
)

if (nrow(current_forecasts) == 0) {
model_filter_msg <- if (!is.null(model_ids)) {
Expand Down
Loading
Loading