# =============================================================================
# Neighborhood Disadvantage Index: Construction and Analysis
# ACS 2015-2019 5-Year Estimates + CDC PLACES 2021
# =============================================================================
# File structure:
#   Section 1: Package installation and loading
#   Section 2: Data acquisition (CDC PLACES + ACS)
#   Section 3: Data cleaning and variable construction
#   Section 4: Index construction (PCA + FA)
#   Section 5: Regression analysis and cross-validation
#   Section 6: Robustness checks (z-score NDI + LASSO)
#   Section 7: Visualization
# =============================================================================


# =============================================================================
# SECTION 1: Package Installation and Loading
# =============================================================================

# Install packages only if not already installed.
# Every entry of `required_packages` that is absent from the local library is
# installed once; packages that are already present are left untouched.
required_packages <- c(
  "dplyr", "tidyr", "readr", "stringr", "ggplot2",
  "moments", "psych", "caret", "glmnet", "gridExtra",
  "tidycensus", "patchwork", "purrr", "broom"
)
new_packages <- setdiff(required_packages, installed.packages()[, "Package"])
if (length(new_packages) > 0) {
  install.packages(new_packages)
}

library(dplyr)
library(tidyr)
library(readr)
library(stringr)
library(ggplot2)
library(moments)
library(psych)
library(caret)
library(glmnet)
library(gridExtra)
library(tidycensus)
library(patchwork)
library(purrr)
library(broom)

# Create output directories used by saveRDS()/write.csv()/ggsave() below.
# Creating a directory that already exists is a silent no-op.
invisible(lapply(
  c("data/raw", "data/processed"),
  dir.create,
  recursive = TRUE, showWarnings = FALSE
))
dir.create("outputs", showWarnings = FALSE)


# =============================================================================
# SECTION 2: Data Acquisition
# =============================================================================

# ---- 2a. CDC PLACES 2021 (Census Tract Level) --------------------------------
# Source: https://data.cdc.gov/resource/cwsq-ngmh
# Outcomes: MHLTH (poor mental health days prevalence), OBESITY (obesity prevalence)
# Every column is read as character so tract FIPS codes keep their leading
# zeros; only the two outcome values are converted to numeric afterwards.
# NOTE(review): the rows.csv export returns whatever version the dataset
# currently holds, so confirm it still corresponds to the 2021 release.

url <- "https://data.cdc.gov/api/views/cwsq-ngmh/rows.csv?accessType=DOWNLOAD"
places_raw <- read_csv(url, col_types = cols(.default = col_character()))

# Keep the two outcomes and spread them to one row per tract.
places_wide <- places_raw %>%
  filter(MeasureId %in% c("MHLTH", "OBESITY")) %>%
  select(geoid = LocationName, MeasureId, Data_Value) %>%
  pivot_wider(names_from = MeasureId, values_from = Data_Value) %>%
  rename(mental_health = MHLTH, obesity = OBESITY) %>%
  mutate(across(c(mental_health, obesity), as.numeric))

# Inspection (commented out; run manually if needed):
# names(places_raw)
# unique(places_raw$MeasureId)
# colSums(is.na(places_wide))  #-> geoid: 0, obesity: 0, mental_health: 0
# dim(places_wide)             #-> 78,815 tracts x 3 columns

saveRDS(places_wide, "data/raw/places_clean.rds")

# ---- 2b. ACS 2015-2019 5-Year Estimates (Census Tract Level) -----------------
# Variables selected based on review of five NDI-related frameworks (Singh 2003, Kind 2014,
# Butler 2013, Sampson & Raudenbush 1999, Gladish et al. 2026) across six dimensions:
# Economic/Income, Employment, Education, Family Structure, Housing, Resources

# Census API key: taken from the CENSUS_API_KEY environment variable, which
# tidycensus::get_acs() reads by default. The key is never hard-coded here.
# A key committed to source is publicly exposed; treat any key that was
# previously committed as compromised and revoke it. The old call, with
# install = TRUE and overwrite = TRUE, also rewrote the user's ~/.Renviron as
# a side effect.
# One-time setup for reproducers: request a free key at
# https://api.census.gov/data/key_signup.html, then add
#   CENSUS_API_KEY=<your key>
# to ~/.Renviron (or run tidycensus::census_api_key("<key>", install = TRUE)).
if (file.exists("~/.Renviron")) {
  readRenviron("~/.Renviron")  # pick up a key added during this session
}
if (!nzchar(Sys.getenv("CENSUS_API_KEY"))) {
  stop("CENSUS_API_KEY is not set. Request a key at ",
       "https://api.census.gov/data/key_signup.html and add ",
       "CENSUS_API_KEY=<key> to ~/.Renviron.", call. = FALSE)
}

acs_vars_to_download <- c(
  # Economic/Income
  "B17001_001", "B17001_002",                          # poverty: denominator, below poverty
  "B19013_001",                                         # median household income
  "B22003_001", "B22003_002",                          # SNAP: denominator, received
  "B09010_001", "B09010_002",                          # public assistance: denominator, received
  # NOTE(review): B09010 appears to be the under-18 population in households
  # (receipt of SSI, cash public assistance, or SNAP). That would explain the
  # zero-denominator NAs filled in 3c. Confirm it is the intended measure.
  # Employment
  "B23025_002", "B23025_005",                          # labor force, unemployed
  # NOTE(review): B23025_002 is the total labor force (civilian + Armed
  # Forces). The conventional unemployment rate divides by the civilian labor
  # force (B23025_003). It is left unchanged because switching can create new
  # zero-denominator NAs that 3c does not fill.
  # Education (B15003: educational attainment, population 25+)
  "B15003_001",                                         # denominator
  "B15003_002", "B15003_003", "B15003_004", "B15003_005",
  "B15003_006", "B15003_007", "B15003_008", "B15003_009",
  "B15003_010", "B15003_011", "B15003_012", "B15003_013",
  "B15003_014", "B15003_015", "B15003_016",             # no HS diploma (codes 002-016)
  "B15003_017", "B15003_018", "B15003_019",
  "B15003_020", "B15003_021",                           # no bachelor's degree (codes 002-021)
  # Family Structure
  # total families; 010 = male householder (no spouse) with own children <18
  # (single father); 016 = female householder (no spouse) with own children <18
  # (single mother)
  "B11003_001", "B11003_010", "B11003_016",
  # Housing
  "B25003_001", "B25003_003",                          # tenure: total, renter-occupied
  "B25014_001",                                         # occupants per room: total occupied units
  "B25014_005", "B25014_006", "B25014_007",             # owner-occupied: 1.01-1.50, 1.51-2.00, 2.01+ per room
  "B25014_011", "B25014_012", "B25014_013",             # renter-occupied: 1.01-1.50, 1.51-2.00, 2.01+ per room
  "B25052_001", "B25052_003",                          # kitchen facilities: total, lacking
  "B25071_001",                                         # median gross rent as % of household income
  # Resources
  "B08201_001", "B08201_002",                          # vehicles available: total, none
  "B28003_001", "B28003_005", "B28003_006"             # internet: total, has computer but no internet subscription, no computer
)

# Download for all 50 states and DC; combine into single data frame
acs_raw <- map_dfr(
  c(state.abb, "DC"),
  ~ get_acs(
    geography = "tract",
    variables = acs_vars_to_download,
    state     = .x,
    year      = 2019,
    survey    = "acs5",
    output    = "wide"
  )
)

# Inspection (commented out; run manually if needed):
# dim(acs_raw)             -> 73,056 tracts x 98 columns (recorded before
#                             B25014_007/_013 were added; expect 102 columns now)
# colSums(is.na(acs_raw))  -> B19013: 1,024 NAs; B25071: 1,683 NAs
# na_b19013 <- acs_raw %>% filter(is.na(B19013_001E)) %>% pull(GEOID)
# na_b25071 <- acs_raw %>% filter(is.na(B25071_001E)) %>% pull(GEOID)
# acs_raw %>% filter(is.na(B19013_001E)) %>% 
#   count(str_sub(GEOID, 1, 2)) %>% 
#   arrange(desc(n))       -> NAs spread across 48 states. Geographic spread
#                             alone does not show the suppression is random, so
#                             treat the exclusion as a potential source of bias.
# length(intersect(na_b19013, na_b25071)) # -> 974 (95% of B19013 NAs; largely the same tracts)

saveRDS(acs_raw, "data/raw/acs_raw.rds")

# =============================================================================
# SECTION 3: Data Cleaning and Variable Construction
# =============================================================================

acs_raw    <- readRDS("data/raw/acs_raw.rds")
places_wide <- readRDS("data/raw/places_clean.rds")

# ---- 3a. Retain estimate columns and exclude suppressed tracts ---------------
# Exclude tracts with suppressed median income (B19013) or rent burden (B25071)
# nrow after exclusion: 71,323 = 73,056 - 1,024 - 1,683 + 974 (overlap)
# Keep exactly the downloaded estimate ("...E") columns. tidyselect's
# ends_with("E") is case-insensitive and would also keep NAME and any column
# ending in "e".

acs <- acs_raw %>%
  select(GEOID, all_of(paste0(acs_vars_to_download, "E"))) %>%
  filter(!is.na(B19013_001E) & !is.na(B25071_001E))


# ---- 3b. Compute indicators --------------------------------------------------
# Rates are count / universe, so a zero universe gives NaN (handled in 3c).

acs <- acs %>%
  mutate(
    # Economic/Income
    poverty_rate     = B17001_002E / B17001_001E,
    median_income    = B19013_001E,
    snap_rate        = B22003_002E / B22003_001E,
    pub_assist_rate  = B09010_002E / B09010_001E,

    # Employment
    unemployment_rate = B23025_005E / B23025_002E,

    # Education
    less_than_hs_rate = (B15003_002E + B15003_003E + B15003_004E + B15003_005E +
                           B15003_006E + B15003_007E + B15003_008E + B15003_009E +
                           B15003_010E + B15003_011E + B15003_012E + B15003_013E +
                           B15003_014E + B15003_015E + B15003_016E) / B15003_001E,
    less_than_ba_rate = (B15003_002E + B15003_003E + B15003_004E + B15003_005E +
                           B15003_006E + B15003_007E + B15003_008E + B15003_009E +
                           B15003_010E + B15003_011E + B15003_012E + B15003_013E +
                           B15003_014E + B15003_015E + B15003_016E + B15003_017E +
                           B15003_018E + B15003_019E + B15003_020E + B15003_021E) / B15003_001E,

    # Family Structure
    single_parent_rate = (B11003_010E + B11003_016E) / B11003_001E,

    # Housing
    renter_rate     = B25003_003E / B25003_001E,
    # Crowding = share of occupied units with more than 1 occupant per room,
    # counting all three ">1" bands (including 2.01+) for owners and renters.
    # NOTE: cached crowding statistics in later comments (skewness, loadings,
    # eigenvalues) were recorded before B25014_007/_013 were included. Re-run
    # to refresh them.
    crowding_rate   = (B25014_005E + B25014_006E + B25014_007E +
                         B25014_011E + B25014_012E + B25014_013E) / B25014_001E,
    no_kitchen_rate = B25052_003E / B25052_001E,
    rent_burden     = B25071_001E,

    # Resources
    no_vehicle_rate  = B08201_002E / B08201_001E,
    no_internet_rate = (B28003_005E + B28003_006E) / B28003_001E
  )


# ---- 3c. Handle structural zeros (zero-denominator NAs) ----------------------
# Inspection:
# colSums(is.na(acs %>% select(poverty_rate:no_internet_rate)))
# -> pub_assist_rate: 127, less_than_hs_rate: 1, less_than_ba_rate: 1,
#    single_parent_rate: 8, all others: 0
# These NAs come from tracts whose denominator population is zero (0/0 gives
# NaN, which is.na() also catches). Following the ReADI convention for
# structural zeros, they are set to 0. Every other value is left unchanged.

acs <- acs %>%
  mutate(
    across(
      c(pub_assist_rate, less_than_hs_rate, less_than_ba_rate,
        single_parent_rate),
      ~ replace(.x, is.na(.x), 0)
    )
  )

# Inspection:
# colSums(is.na(acs %>% select(poverty_rate:no_internet_rate))) -> all 0

# ---- 3c.5. Correlation checks before transformation -------------------------

# Economic dimension: poverty_rate, snap_rate, pub_assist_rate, median_income
# acs %>% select(poverty_rate, snap_rate, pub_assist_rate, median_income) %>%
#   cor(use = "complete.obs") %>% round(2)
# -> poverty_rate ~ snap_rate: 0.78
#    poverty_rate ~ pub_assist_rate: 0.74
#    snap_rate ~ pub_assist_rate: 0.85
#    median_income ~ others: -0.64 to -0.67

# Education dimension: less_than_hs_rate, less_than_ba_rate
# acs %>% select(less_than_hs_rate, less_than_ba_rate) %>%
#   cor(use = "complete.obs") %>% round(2)
# -> less_than_hs_rate ~ less_than_ba_rate: 0.64

# Housing dimension: renter_rate, crowding_rate, no_kitchen_rate, rent_burden
# acs %>% select(renter_rate, crowding_rate, no_kitchen_rate, rent_burden) %>%
#   cor(use = "complete.obs") %>% round(2)
# -> max correlation: 0.41 (renter_rate ~ crowding_rate); all retained


# ---- 3d. Log-transformation of right-skewed variables -----------------------
# Variables with skewness > 1.4 are log-transformed:
#   log(x + 1) for rates that can be zero; log(x) for median_income (no zeros).
# Pre-transformation skewness:
# no_kitchen_rate: 6.78, crowding_rate: 3.03, no_vehicle_rate: 3.03,
# unemployment_rate: 2.23, less_than_hs_rate: 1.56, median_income: 1.53,
# poverty_rate: 1.53, snap_rate: 1.52, single_parent_rate: 1.49

# Distribution and skewness checks:
# acs %>% select(poverty_rate:no_internet_rate) %>%
#   tidyr::pivot_longer(everything()) %>%
#   ggplot(aes(x = value)) + geom_histogram(bins = 50) +
#   facet_wrap(~name, scales = "free") + theme_minimal()
# acs %>% select(poverty_rate:no_internet_rate) %>%
#   summarise(across(everything(), skewness)) %>%
#   tidyr::pivot_longer(everything(), names_to = "variable", values_to = "skewness") %>%
#   arrange(desc(skewness))

# New columns are named log_<original>. log_median_income is moved right after
# log_poverty_rate so the column order matches the rest of the pipeline.
acs <- acs %>%
  mutate(
    across(
      c(poverty_rate, snap_rate, unemployment_rate, less_than_hs_rate,
        single_parent_rate, no_kitchen_rate, crowding_rate, no_vehicle_rate),
      ~ log(.x + 1),
      .names = "log_{.col}"
    ),
    log_median_income = log(median_income)
  ) %>%
  relocate(log_median_income, .after = log_poverty_rate)

# Post-transformation skewness check:
# acs %>% select(starts_with("log_")) %>%
#   summarise(across(everything(), skewness)) %>%
#   tidyr::pivot_longer(everything(), names_to = "variable", values_to = "skewness") %>%
#   arrange(desc(skewness))
# -> log_no_kitchen_rate: 5.59, log_crowding_rate: 2.71, log_no_vehicle_rate: 2.51, log_unemployment_rate: 1.95
# Residual skewness in no_kitchen_rate, crowding_rate and no_vehicle_rate is
# driven by true zeros and is acceptable per Butler et al. (2013). The residual
# skewness in log_unemployment_rate reflects a genuine right tail of
# high-unemployment tracts.


# ---- 3e. Standardization: z-score and centile ranking -----------------------
# Z-score: applied to log-transformed and untransformed variables.
# Centile ranking (following Butler et al. 2013): rank(x) / n * 100, where
# tied values receive their average rank.
# NOTE(review): both are computed on every ACS tract retained in 3a, before
# the PLACES merge in 3f drops tracts. In the merged sample the z-scores are
# therefore not exactly mean 0 / SD 1, and centiles are not exactly uniform.

acs <- acs %>%
  mutate(
    # scale() returns a 1-column matrix carrying scaled:center/scaled:scale
    # attributes. as.vector() keeps z_* as plain numeric columns.
    across(
      c(log_poverty_rate, log_median_income, log_snap_rate, pub_assist_rate,
        log_unemployment_rate, less_than_hs_rate, less_than_ba_rate,
        log_single_parent_rate, log_no_kitchen_rate, renter_rate,
        log_crowding_rate, rent_burden, log_no_vehicle_rate, no_internet_rate),
      ~ as.vector(scale(.x)), .names = "z_{.col}"
    ),
    # na.last = "keep" leaves a missing value missing. Otherwise it would be
    # ranked above every observed value, i.e. as the most extreme centile.
    # n counts observed values only. Both choices change nothing when a
    # column has no NAs.
    across(
      c(poverty_rate, median_income, snap_rate, pub_assist_rate,
        unemployment_rate, less_than_hs_rate, less_than_ba_rate,
        single_parent_rate, no_kitchen_rate, renter_rate,
        crowding_rate, rent_burden, no_vehicle_rate, no_internet_rate),
      ~ rank(.x, na.last = "keep") / sum(!is.na(.x)) * 100,
      .names = "centile_{.col}"
    )
  )


# ---- 3f. Merge ACS with PLACES ----------------------------------------------
# Inner join on 11-digit FIPS (GEOID); retains only tracts present in both datasets
# nrow(places_wide): 78,815; nrow(acs_final): 71,323; nrow(df): 55,534
# Unmatched tracts are spread across 49 states. Geographic spread alone does
# not rule out systematic differences between matched and unmatched tracts.
# NOTE(review): roughly 15,800 ACS tracts find no PLACES match. ACS 2015-2019
# uses 2010 tract boundaries. If the PLACES download uses a different tract
# vintage, GEOIDs of redrawn tracts cannot match. Worth confirming.

# GEOID format check (commented out):
# head(acs$GEOID); head(places_wide$geoid) -> both 11-digit FIPS, formats identical

acs_final <- acs %>%
  select(GEOID, starts_with("z_"), starts_with("centile_"))

# Both sides must be unique on the join key. A duplicated GEOID would
# silently duplicate tracts in the analysis sample.
stopifnot(
  !anyDuplicated(acs_final$GEOID),
  !anyDuplicated(places_wide$geoid)
)

df <- acs_final %>%
  inner_join(places_wide, by = c("GEOID" = "geoid"))

# Geographic distribution of unmatched tracts:
# places_wide %>% filter(!geoid %in% acs_final$GEOID) %>%
#   count(str_sub(geoid, 1, 2)) %>% arrange(desc(n))

saveRDS(df, "data/processed/df_merged.rds")


# =============================================================================
# SECTION 4: Index Construction
# =============================================================================

df <- readRDS("data/processed/df_merged.rds")

# ---- 4a. PCA to assess single-factor structure ------------------------------
# Compare z-score vs centile-ranked transformations
# z-score:  PC1 eigenvalue = 6.95 (49.6% variance); 4 eigenvalues >= 1 (4th = 1.00 after rounding)
# centile:  PC1 eigenvalue = 7.26 (51.9% variance); 3 components > 1
# Centile ranking shows stronger single-factor structure -> selected for final FA

z_vars       <- df %>% select(starts_with("z_"))
centile_vars <- df %>% select(starts_with("centile_"))

# Extract one component per indicator so the full eigenvalue spectrum is
# available (scree plot, variance shares). The count comes from the data
# rather than a hard-coded 14.
pca_z <- principal(z_vars,       nfactors = ncol(z_vars),       rotate = "none")
pca_c <- principal(centile_vars, nfactors = ncol(centile_vars), rotate = "none")

# pca_z$values -> 6.95, 1.36, 1.13, 1.00, ...
# pca_c$values -> 7.26, 1.28, 1.03, ...
# pca_z$values[1] / sum(pca_z$values) -> 0.496 (49.6% variance explained by PC1)
# pca_c$values[1] / sum(pca_c$values) -> 0.519 (51.9% variance explained by PC1)


# ---- 4b. Single-factor FA (maximum likelihood) on centile variables ----------
# Initial FA with all 14 variables to check loadings
fa_result <- fa(centile_vars, nfactors = 1, rotate = "none", fm = "ml")
# fa_result$loadings -> centile_no_kitchen_rate = 0.158 (below 0.3 threshold)
# 0.3 loading threshold follows standard FA practice (Hair et al., 2010)

# Remove centile_no_kitchen_rate and rerun FA
centile_vars_final <- df %>%
  select(starts_with("centile_"), -centile_no_kitchen_rate)

fa_final <- fa(centile_vars_final, nfactors = 1, rotate = "none", fm = "ml")

# The sign of a single extracted factor is arbitrary. Orient it so a higher
# score means more disadvantage, i.e. poverty loads positively. Loadings and
# scores are flipped together so the loadings plot stays consistent with
# the index.
if (fa_final$loadings["centile_poverty_rate", 1] < 0) {
  message("Reversing factor sign so that higher NDI = more disadvantage")
  fa_final$loadings <- -fa_final$loadings
  fa_final$scores   <- -fa_final$scores
}
# fa_final$loadings -> poverty: +0.886, median_income: -0.896
# Negative loading for median_income confirms factor captures disadvantage
# range(abs(fa_final$loadings)) -> 0.365 to 0.923

# fa_result$Vaccounted -> proportion of variance = 0.490
# fa_final$Vaccounted  -> proportion of variance = 0.525
# Removing centile_no_kitchen_rate improved variance explained from 49.0% to 52.5%

# Extract factor scores directly from FA output (following Gladish et al. 2026)
# and min-max rescale to 0-100 for interpretability
df$ndi        <- fa_final$scores[, 1]
df$ndi_scaled <- (df$ndi - min(df$ndi)) / (max(df$ndi) - min(df$ndi)) * 100

# Correlation check:
# cor(df$ndi_scaled, df$obesity)       -> 0.661 (positive, as expected)
# cor(df$ndi_scaled, df$mental_health) -> 0.747 (positive, as expected)

saveRDS(df, "data/processed/df_final.rds")


# =============================================================================
# SECTION 5: Regression Analysis and Cross-Validation
# =============================================================================

# Reload the merged tract file with the NDI attached (written in Section 4).
df <- readRDS("data/processed/df_final.rds")

# ---- 5a. Primary OLS regressions --------------------------------------------
# Bivariate models: each PLACES outcome regressed on the 0-100 NDI.
model_obesity <- lm(obesity ~ ndi_scaled, data = df)
model_mental <- lm(mental_health ~ ndi_scaled, data = df)

summary(model_obesity)
summary(model_mental)


# ---- 5b. 10-fold cross-validation -------------------------------------------
# Cross-validated performance of the same two specifications. The seed is set
# once, before both train() calls, so the run is reproducible.
set.seed(42)
ctrl <- trainControl(method = "cv", number = 10)

cv_obesity <- train(
  obesity ~ ndi_scaled,
  data      = df,
  method    = "lm",
  trControl = ctrl
)
cv_mental <- train(
  mental_health ~ ndi_scaled,
  data      = df,
  method    = "lm",
  trControl = ctrl
)

cv_obesity$results
cv_mental$results

# Coefficient tables for the primary models, one CSV per outcome.
results_table <- list(
  obesity = tidy(model_obesity),
  mental  = tidy(model_mental)
)
invisible(Map(
  function(tbl, path) write.csv(tbl, path, row.names = FALSE),
  results_table,
  c("outputs/obesity_results.csv", "outputs/mental_results.csv")
))


# =============================================================================
# SECTION 6: Robustness Checks
# =============================================================================

# ---- 6a. Robustness Check 1: Z-score NDI ------------------------------------
# Alternative standardization method to assess sensitivity of primary results.
# Same one-factor ML model as the primary NDI, fitted on the z-scored
# indicators with the kitchen variable dropped.

z_vars_final <- df %>%
  select(starts_with("z_"), -z_log_no_kitchen_rate)

fa_z <- fa(z_vars_final, nfactors = 1, rotate = "none", fm = "ml")

# Factor sign is arbitrary. Orient it like the primary NDI (positive poverty
# loading means higher score = more disadvantage), so the regression slopes
# below are directly comparable in sign to the primary models.
if (fa_z$loadings["z_log_poverty_rate", 1] < 0) {
  fa_z$loadings <- -fa_z$loadings
  fa_z$scores   <- -fa_z$scores
}
df$ndi_z <- fa_z$scores[, 1]

model_obesity_z <- lm(obesity       ~ ndi_z, data = df)
model_mental_z  <- lm(mental_health ~ ndi_z, data = df)

summary(model_obesity_z)
summary(model_mental_z)


# ---- 6b. Robustness Check 2: LASSO ------------------------------------------
# Direct variable selection with the 13 retained centile indicators as
# separate predictors. This tests whether individual ACS variables predict
# better than the composite NDI.

X <- df %>%
  select(starts_with("centile_"), -centile_no_kitchen_rate) %>%
  as.matrix()

set.seed(42)
lasso_obesity <- cv.glmnet(X, df$obesity, alpha = 1, nfolds = 10)
lasso_mental <- cv.glmnet(X, df$mental_health, alpha = 1, nfolds = 10)

coef(lasso_obesity, s = "lambda.min")
coef(lasso_mental, s = "lambda.min")

# Cross-validated R-squared at lambda.min. lambda.min minimises the mean CV
# error (cvm), so 1 - min(cvm) / var(y) is the best CV R-squared on the path.
lasso_obesity_r2 <- 1 - min(lasso_obesity$cvm) / var(df$obesity)
lasso_mental_r2 <- 1 - min(lasso_mental$cvm) / var(df$mental_health)

lasso_obesity_r2
lasso_mental_r2


# =============================================================================
# SECTION 7: Visualization
# =============================================================================

# ---- Figure 1. Scree Plot ---------------------------------------------------
# Eigenvalues of the centile-indicator PCA (pca_c). The dashed line marks the
# Kaiser eigenvalue = 1 cut-off. The component count comes from the PCA
# output rather than a hard-coded 14.

eigenvalues <- data.frame(
  component  = seq_along(pca_c$values),
  eigenvalue = pca_c$values
)

p_scree <- ggplot(eigenvalues, aes(x = component, y = eigenvalue)) +
  geom_line(color = "steelblue") +
  geom_point(color = "steelblue", size = 2) +
  geom_hline(yintercept = 1, linetype = "dashed", color = "gray50") +
  scale_x_continuous(breaks = eigenvalues$component) +
  labs(
    title = "Figure 1. Scree Plot of Principal Components",
    x     = "Component",
    y     = "Eigenvalue"
  ) +
  theme_minimal()

print(p_scree)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot() default.
ggsave("outputs/scree_plot.png", plot = p_scree, width = 7, height = 4, dpi = 300)


# ---- Figures 2a & 2b. FA Loadings ------------------------------------------
# Loadings of the final one-factor ML model (fa_final). Each indicator gets
# its conceptual dimension by name, not by position, so the labels stay
# correct if the column order of centile_vars_final ever changes.

dimension_lookup <- c(
  centile_poverty_rate       = "Economic",
  centile_median_income      = "Economic",
  centile_snap_rate          = "Economic",
  centile_pub_assist_rate    = "Economic",
  centile_unemployment_rate  = "Employment",
  centile_less_than_hs_rate  = "Education",
  centile_less_than_ba_rate  = "Education",
  centile_single_parent_rate = "Family",
  centile_renter_rate        = "Housing",
  centile_crowding_rate      = "Housing",
  centile_rent_burden        = "Housing",
  centile_no_vehicle_rate    = "Resources",
  centile_no_internet_rate   = "Resources"
)

loadings_df <- data.frame(
  variable = rownames(fa_final$loadings),
  loading  = as.numeric(fa_final$loadings)
)
loadings_df$dimension <- unname(dimension_lookup[loadings_df$variable])
loadings_df$direction <- ifelse(loadings_df$loading > 0, "Positive", "Negative")
# Every retained indicator must have a dimension label.
stopifnot(!anyNA(loadings_df$dimension))

# Figure 2a: Bar chart (ordered by signed loading)
# Fill uses an explicit direction label (same scheme as Figure 4), so colours
# and legend text stay matched even if every loading has the same sign.
p2a <- ggplot(loadings_df, aes(x = reorder(variable, loading), y = loading,
                               fill = direction)) +
  geom_col() +
  coord_flip() +
  scale_fill_manual(values = c("Positive" = "#2166ac", "Negative" = "#d73027"),
                    name   = "Direction") +
  labs(
    title = "Figure 2a. Factor Loadings:\nNeighborhood Disadvantage Index",
    x     = "",
    y     = "Loading"
  ) +
  theme_minimal()


# Figure 2b: Dot plot grouped by dimension
p2b <- ggplot(loadings_df, aes(x = loading, y = reorder(variable, loading), color = dimension)) +
  geom_point(size = 4) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "gray50") +
  labs(
    title = "Figure 2b. Factor Loadings by Dimension",
    x     = "Loading",
    y     = "",
    color = "Dimension"
  ) +
  theme_minimal()

fig2 <- p2a + p2b
print(fig2)
ggsave("outputs/fa_loadings_combined.png", plot = fig2, width = 14, height = 5, dpi = 300)


# ---- Figure 3. NDI Distribution ---------------------------------------------
# Histogram of the 0-100 rescaled NDI. The dashed red line and label mark
# its mean. The mean and SD are computed once and reused for the line,
# label and subtitle.

ndi_mean <- mean(df$ndi_scaled)
ndi_sd   <- sd(df$ndi_scaled)

p_ndi <- ggplot(df, aes(x = ndi_scaled)) +
  geom_histogram(bins = 100, fill = "steelblue", color = "white", alpha = 0.8) +
  geom_vline(xintercept = ndi_mean,
             linetype = "dashed", color = "red", linewidth = 0.8) +
  annotate("text",
           x = ndi_mean + 3, y = Inf, vjust = 2,
           label = paste0("Mean = ", round(ndi_mean, 1)),
           color = "red", size = 3.5) +
  labs(
    title    = "Figure 3. Distribution of Neighborhood Disadvantage Index (NDI)",
    subtitle = paste0("N = ", format(nrow(df), big.mark = ","),
                      " census tracts | Mean = ", round(ndi_mean, 1),
                      ", SD = ", round(ndi_sd, 1)),
    x = "NDI Score (0-100)",
    y = "Number of Census Tracts"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(size = 11))

print(p_ndi)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot() default.
ggsave("outputs/ndi_distribution.png", plot = p_ndi, width = 7, height = 4, dpi = 300)


# ---- Figures 4a & 4b. LASSO Coefficients (Obesity & Mental Health) ---------

# Display labels for the centile indicators that can enter the LASSO models.
lasso_var_labels <- c(
  "centile_poverty_rate"       = "Poverty rate",
  "centile_median_income"      = "Median household income",
  "centile_snap_rate"          = "SNAP receipt rate",
  "centile_pub_assist_rate"    = "Public assistance rate",
  "centile_unemployment_rate"  = "Unemployment rate",
  "centile_less_than_hs_rate"  = "Less than HS diploma rate",
  "centile_less_than_ba_rate"  = "Less than bachelor's degree rate",
  "centile_single_parent_rate" = "Single-parent household rate",
  "centile_renter_rate"        = "Renter-occupied rate",
  "centile_crowding_rate"      = "Crowding rate",
  "centile_rent_burden"        = "Rent burden",
  "centile_no_vehicle_rate"    = "No-vehicle rate",
  "centile_no_internet_rate"   = "No internet access rate"
)

# Tidy the non-zero, non-intercept coefficients of a cv.glmnet fit at `s`.
# Returns a data.frame (variable, coefficient, direction) sorted by |coefficient|.
# The coefficient column is taken by position because its column name
# depends on the glmnet version (e.g. "1" vs "lambda.min").
tidy_lasso_coefs <- function(cv_fit, s = "lambda.min", labels = lasso_var_labels) {
  coef_mat <- as.matrix(coef(cv_fit, s = s))
  data.frame(
    variable    = rownames(coef_mat),
    coefficient = coef_mat[, 1],
    row.names   = NULL
  ) %>%
    filter(variable != "(Intercept)", coefficient != 0) %>%
    mutate(
      variable  = recode(variable, !!!labels),
      direction = ifelse(coefficient > 0, "Positive", "Negative")
    ) %>%
    arrange(desc(abs(coefficient)))
}

# Horizontal bar chart of tidied LASSO coefficients, bars ordered by magnitude.
plot_lasso_coefs <- function(coef_df, title) {
  ggplot(coef_df, aes(x = reorder(variable, abs(coefficient)),
                      y = coefficient, fill = direction)) +
    geom_col() +
    coord_flip() +
    scale_fill_manual(values = c("Positive" = "#2166ac", "Negative" = "#d73027")) +
    labs(
      title    = title,
      subtitle = "Variables selected at lambda.min (10-fold CV)",
      x        = "",
      y        = "LASSO Coefficient",
      fill     = "Direction"
    ) +
    theme_minimal()
}

# Figure 4a: LASSO Coefficients (Obesity)
lasso_coef_df <- tidy_lasso_coefs(lasso_obesity)
p4a <- plot_lasso_coefs(lasso_coef_df,
                        "Figure 4a. LASSO Coefficients:\nObesity Outcome")

# Figure 4b: LASSO Coefficients (Mental Health)
lasso_coef_mental_df <- tidy_lasso_coefs(lasso_mental)
p4b <- plot_lasso_coefs(lasso_coef_mental_df,
                        "Figure 4b. LASSO Coefficients:\nMental Health Outcome")

fig4 <- p4a + p4b
print(fig4)
ggsave("outputs/lasso_combined.png", plot = fig4, width = 14, height = 5, dpi = 300)


# ---- Figure 5. Predicted vs Actual ------------------------------------------
# In-sample predictions of the primary OLS models plotted against observed
# prevalence. The red 45-degree line marks perfect prediction.

df$pred_obesity <- predict(model_obesity)
df$pred_mental  <- predict(model_mental)

p1 <- ggplot(df, aes(x = pred_obesity, y = obesity)) +
  geom_point(alpha = 0.1, size = 0.5) +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  labs(
    title = "Obesity: Predicted vs Actual",
    x     = "Predicted",
    y     = "Actual"
  ) +
  theme_minimal()

p2 <- ggplot(df, aes(x = pred_mental, y = mental_health)) +
  geom_point(alpha = 0.1, size = 0.5) +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  labs(
    title = "Mental Health: Predicted vs Actual",
    x     = "Predicted",
    y     = "Actual"
  ) +
  theme_minimal()

# grid.arrange() draws the two panels and invisibly returns the arranged grob.
# That same grob is written to disk.
fig5 <- grid.arrange(
  p1, p2,
  ncol = 2,
  top  = "Figure 5. Predicted vs Actual Values: Primary OLS Models"
)
ggsave("outputs/predicted_vs_actual.png",
       plot = fig5, width = 10, height = 5, dpi = 300)

