37 commits
f641698
Updated partition_tracker to track auxiliary data for CLogLog Ordinal…
Entejar Sep 17, 2025
e99791b
Added leaf model for CLogLog Ordinal BART
Entejar Sep 17, 2025
8f77e15
Added ordinal_sampler
Entejar Sep 17, 2025
8547425
Updated tree_sampler.h
Entejar Sep 17, 2025
6c1d3ce
Updated sampler.cpp
Entejar Sep 17, 2025
955a211
Merge branch 'StochasticTree:main' into main
Entejar Sep 28, 2025
084be88
Added cloglog_ordinal_bart.R function
Entejar Sep 28, 2025
c8492fb
Tested CLogLog Ordinal BART — running successfully!
Entejar Sep 28, 2025
444c067
Added vignette for CLogLog Ordinal Bart
Entejar Sep 29, 2025
132071e
Update leaf_model.h
Entejar Sep 30, 2025
74cff51
Merge branch 'StochasticTree:main' into main
Entejar Oct 8, 2025
7bc3eb3
Merge branch 'main' into pr/196
andrewherren Oct 24, 2025
18c9e15
Migrated auxiliary data to ForestDataset from ForestTracker
andrewherren Oct 27, 2025
36b6a98
Removed call to deprectated cpp function
andrewherren Oct 27, 2025
2de707c
Fixed indexing bug
andrewherren Oct 27, 2025
2d19399
Refactored and fixed bugs
andrewherren Oct 27, 2025
f9a0b5a
Updated multinomial cloglog vignette
andrewherren Oct 27, 2025
66164f5
Added binary outcome cloglog model demo
andrewherren Oct 27, 2025
d5de763
Reworking sampler implementation to match current stochtree::main API
andrewherren Oct 27, 2025
0302459
Reflecting num_threads further through the interface
andrewherren Oct 27, 2025
a7c79d4
Refactoring out unused slice sampler for leaf scale parameter
andrewherren Oct 27, 2025
6ffdef7
Adding num_threads (back) to GFR interface
andrewherren Oct 27, 2025
853b129
Continue building in multithreading support to cloglog branch
andrewherren Oct 27, 2025
9edad36
Update tree_sampler.h
andrewherren Oct 27, 2025
04de102
Updating GFR to reflect multithreading capabilities in the main branch
andrewherren Oct 27, 2025
cdca915
Reflecting num_threads through the MCMC and GFR interface
andrewherren Oct 27, 2025
bf21447
Set up cloglog to work with GFR and updated examples
andrewherren Oct 27, 2025
a5cee2b
Updating vignettes
andrewherren Oct 28, 2025
815c538
WIP fix for data augmentation in the binary case
andrewherren Oct 28, 2025
2106f32
Updating sampler
andrewherren Oct 28, 2025
02dc2ac
Remove unused slice sampler code
andrewherren Oct 28, 2025
8f92425
Cleaned up PR
andrewherren Oct 28, 2025
95a0ce9
Including variant in leaf model header file
andrewherren Oct 28, 2025
42f9ac4
Updated vignettes and function defaults
andrewherren Oct 28, 2025
4f576a6
Added a release candidate readme
andrewherren Oct 28, 2025
dc25a1d
Updated demo scripts
andrewherren Oct 28, 2025
e0ccb02
WIP python frontend for cloglog ordinal BART
andrewherren Oct 29, 2025
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -139,6 +139,7 @@ file(
src/io.cpp
src/json11.cpp
src/leaf_model.cpp
src/ordinal_sampler.cpp
src/partition_tracker.cpp
src/random_effects.cpp
src/tree.cpp
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -29,7 +29,7 @@ Description: Flexible stochastic tree ensemble software.
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
LinkingTo:
cpp11, BH
Suggests:
1 change: 1 addition & 0 deletions NAMESPACE
@@ -7,6 +7,7 @@ S3method(predict,bcfmodel)
export(bart)
export(bcf)
export(calibrateInverseGammaErrorVariance)
export(cloglog_ordinal_bart)
export(computeForestLeafIndices)
export(computeForestLeafVariances)
export(computeForestMaxLeafIndex)
218 changes: 218 additions & 0 deletions R/cloglog_ordinal_bart.R
@@ -0,0 +1,218 @@
#' Run the BART algorithm for ordinal outcomes using a complementary log-log link
#'
#' @param X A numeric matrix of predictors (training data).
#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1).
#' @param X_test An optional numeric matrix of predictors (test data).
#' @param n_trees Number of trees in the BART ensemble. Default: `50`.
#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `0`.
#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `1000`.
#' @param num_mcmc Total number of MCMC samples to draw. Default: `500`.
#' @param n_thin Thinning interval for MCMC samples. Default: `1`.
#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param variable_weights (Optional) vector of variable weights for splitting (default: equal weights).
#' @param feature_types (Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).
#' @param seed (Optional) random seed for reproducibility.
#' @param num_threads (Optional) Number of threads to use in split evaluations and other compute-intensive operations. Default: `1`.
#' @export
cloglog_ordinal_bart <- function(X, y, X_test = NULL,
n_trees = 50,
num_gfr = 0,
num_burnin = 1000,
num_mcmc = 500,
n_thin = 1,
alpha_gamma = 2.0,
beta_gamma = 2.0,
variable_weights = NULL,
feature_types = NULL,
seed = NULL,
num_threads = 1) {
# BART parameters
alpha_bart <- 0.95
beta_bart <- 2
min_samples_in_leaf <- 5
max_depth <- 10
scale_leaf <- 2 / sqrt(n_trees)
cutpoint_grid_size <- 100 # Needed for stochtree::sample_gfr_one_iteration_cpp, not used in MCMC BART

# Fixed for identifiability (can be passed as an argument later if desired)
gamma_0 <- 0.0 # First gamma cutpoint fixed at gamma_0 = 0

# Determine whether a test dataset is provided
has_test <- !is.null(X_test)

# Data checks
if (!is.matrix(X)) X <- as.matrix(X)
if (!is.numeric(y)) y <- as.numeric(y)
if (has_test && !is.matrix(X_test)) X_test <- as.matrix(X_test)

n_samples <- nrow(X)
n_features <- ncol(X)

if (any(y < 1) || any(y != round(y))) {
stop("Ordinal outcome y must contain positive integers starting from 1")
}

# Convert from 1-based (R) to 0-based (C++) indexing
ordinal_outcome <- as.integer(y - 1)
n_levels <- max(y) # Number of ordinal categories

if (n_levels < 2) {
stop("Ordinal outcome must have at least 2 categories")
}

if (is.null(variable_weights)) {
variable_weights <- rep(1.0, n_features)
}

if (is.null(feature_types)) {
feature_types <- rep(0L, n_features)
}

if (!is.null(seed)) {
set.seed(seed)
}

# Indices of MCMC samples to keep after GFR, burn-in, and thinning
keep_idx <- seq(num_gfr + num_burnin + 1, num_gfr + num_burnin + num_mcmc, by = n_thin)
n_keep <- length(keep_idx)

# Storage for MCMC samples
forest_pred_train <- matrix(0, n_samples, n_keep)
if (has_test) {
n_samples_test <- nrow(X_test)
forest_pred_test <- matrix(0, n_samples_test, n_keep)
}
gamma_samples <- matrix(0, n_levels - 1, n_keep)
latent_samples <- matrix(0, n_samples, n_keep)

# Initialize samplers
ordinal_sampler <- stochtree:::ordinal_sampler_cpp()
rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed)

# Initialize the remaining model structures (datasets, outcome, forests, tree prior, tracker)
dataX <- stochtree::createForestDataset(X)
if (has_test) {
dataXtest <- stochtree::createForestDataset(X_test)
}
outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome))
active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
active_forest$set_root_leaves(0.0)
split_prior <- stochtree:::tree_prior_cpp(alpha_bart, beta_bart, min_samples_in_leaf, max_depth)
forest_samples <- stochtree::createForestSamples(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
forest_tracker <- stochtree:::forest_tracker_cpp(
dataX$data_ptr,
as.integer(feature_types),
as.integer(n_trees),
as.integer(n_samples)
)

# Latent variable (Z in Alam et al (2025) notation)
dataX$add_auxiliary_dimension(nrow(X))
# Forest predictions (eta in Alam et al (2025) notation)
dataX$add_auxiliary_dimension(nrow(X))
# Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation)
dataX$add_auxiliary_dimension(n_levels - 1)
# Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation)
# This auxiliary series is designed so that the element stored at position `i`
# corresponds to the sum of all exponentiated gamma_j values for j < i.
# It has n_levels elements instead of n_levels - 1 because even the largest
# categorical index has a valid value of sum_{j < i} exp(gamma_j)
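# Worked example (illustrative, assuming the indexing described above): with
# n_levels = 3 and gamma = c(0, log(2)), this series would hold
# c(0, exp(0), exp(0) + exp(log(2))) = c(0, 1, 3).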
dataX$add_auxiliary_dimension(n_levels)

# Initialize gamma parameters to zero (3rd auxiliary data series, mapped to `dim_idx = 2` with 0-indexing)
initial_gamma <- rep(0.0, n_levels - 1)
for (i in seq_along(initial_gamma)) {
dataX$set_auxiliary_data_value(2, i - 1, initial_gamma[i])
}

# Convert the log-scale parameters into cumulative exponentiated parameters.
# This is done under the hood in a C++ function for efficiency.
stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr)

# Initialize forest predictions to zero (slot 1)
for (i in 1:n_samples) {
dataX$set_auxiliary_data_value(1, i - 1, 0.0)
}

# Initialize latent variables to zero (slot 0)
for (i in 1:n_samples) {
dataX$set_auxiliary_data_value(0, i - 1, 0.0)
}

# Set up sweep indices for tree updates (sample all trees each iteration)
sweep_indices <- as.integer(seq(0, n_trees - 1))

sample_counter <- 0
for (i in 1:(num_mcmc + num_burnin + num_gfr)) {
keep_sample <- i %in% keep_idx
if (keep_sample) {
sample_counter <- sample_counter + 1
}

# 1. Sample forest using MCMC
if (i > num_gfr) {
stochtree:::sample_mcmc_one_iteration_cpp(
dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr,
active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr,
sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size),
scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample,
num_threads
)
} else {
stochtree:::sample_gfr_one_iteration_cpp(
dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr,
active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr,
sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size),
scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample,
ncol(X), num_threads
)
}

# Set auxiliary data slot 1 to the current forest predictions (lambda_hat, the sum of all tree predictions).
# This is needed for updating the gamma parameters and the latent z_i's.
# The inner loop uses `j` so that it does not shadow the outer iteration index `i`.
forest_pred_current <- active_forest$predict(dataX)
for (j in 1:n_samples) {
dataX$set_auxiliary_data_value(1, j - 1, forest_pred_current[j])
}

# 2. Sample latent z_i's using truncated exponential
stochtree:::ordinal_sampler_update_latent_variables_cpp(
ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, rng$rng_ptr
)

# 3. Sample gamma parameters
stochtree:::ordinal_sampler_update_gamma_params_cpp(
ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr,
alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr
)

# 4. Update cumulative sum of exp(gamma) values
stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr)

if (keep_sample) {
forest_pred_train[, sample_counter] <- active_forest$predict(dataX)
if (has_test) {
forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest)
}
gamma_current <- dataX$get_auxiliary_data_vector(2)
gamma_samples[, sample_counter] <- gamma_current
latent_current <- dataX$get_auxiliary_data_vector(0)
latent_samples[, sample_counter] <- latent_current
}
}

result <- list(
forest_predictions_train = forest_pred_train,
forest_predictions_test = if (has_test) forest_pred_test else NULL,
gamma_samples = gamma_samples,
latent_samples = latent_samples,
scale_leaf = scale_leaf,
ordinal_outcome = ordinal_outcome,
n_trees = n_trees,
n_keep = n_keep
)

class(result) <- "cloglog_ordinal_bart"
return(result)
}
40 changes: 40 additions & 0 deletions R/cpp11.R
@@ -56,6 +56,30 @@ forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
.Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr)
}

forest_dataset_has_auxiliary_dimension_cpp <- function(dataset_ptr, dim_idx) {
.Call(`_stochtree_forest_dataset_has_auxiliary_dimension_cpp`, dataset_ptr, dim_idx)
}

forest_dataset_add_auxiliary_dimension_cpp <- function(dataset_ptr, dim_size) {
invisible(.Call(`_stochtree_forest_dataset_add_auxiliary_dimension_cpp`, dataset_ptr, dim_size))
}

forest_dataset_get_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx) {
.Call(`_stochtree_forest_dataset_get_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx)
}

forest_dataset_set_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx, value) {
invisible(.Call(`_stochtree_forest_dataset_set_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx, value))
}

forest_dataset_get_auxiliary_data_vector_cpp <- function(dataset_ptr, dim_idx) {
.Call(`_stochtree_forest_dataset_get_auxiliary_data_vector_cpp`, dataset_ptr, dim_idx)
}

forest_dataset_store_auxiliary_data_vector_as_column_cpp <- function(dataset_ptr, output_matrix, dim_idx, matrix_col_idx) {
.Call(`_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`, dataset_ptr, output_matrix, dim_idx, matrix_col_idx)
}

create_column_vector_cpp <- function(outcome) {
.Call(`_stochtree_create_column_vector_cpp`, outcome)
}
@@ -692,6 +716,22 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
}

ordinal_sampler_cpp <- function() {
.Call(`_stochtree_ordinal_sampler_cpp`)
}

ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, rng_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, rng_ptr))
}

ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr))
}

ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, data_ptr) {
invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, data_ptr))
}

init_json_cpp <- function() {
.Call(`_stochtree_init_json_cpp`)
}
53 changes: 53 additions & 0 deletions R/data.R
@@ -108,6 +108,59 @@ ForestDataset <- R6::R6Class(
#' @return True if variance weights are loaded, false otherwise
has_variance_weights = function() {
return(dataset_has_variance_weights_cpp(self$data_ptr))
},

#' @description
#' Whether or not a dataset has auxiliary data stored at the dimension indicated
#' @param dim_idx Dimension of auxiliary data
#' @return True if auxiliary data has been allocated for `dim_idx`, false otherwise
has_auxiliary_dimension = function(dim_idx) {
return(forest_dataset_has_auxiliary_dimension_cpp(self$data_ptr, dim_idx))
},

#' @description
#' Initialize a new dimension / lane of auxiliary data and allocate data in its place
#' @param dim_size Size of the new vector of data to allocate
#' @return None
add_auxiliary_dimension = function(dim_size) {
return(forest_dataset_add_auxiliary_dimension_cpp(self$data_ptr, dim_size))
},

#' @description
#' Retrieve auxiliary data value
#' @param dim_idx Dimension from which the data value is to be retrieved
#' @param element_idx Element to retrieve from dimension `dim_idx`
#' @return Floating point value stored in the requested auxiliary data space
get_auxiliary_data_value = function(dim_idx, element_idx) {
return(forest_dataset_get_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx))
},

#' @description
#' Set auxiliary data value
#' @param dim_idx Dimension in which the data value is to be set
#' @param element_idx Element to set within dimension `dim_idx`
#' @param value Data value to set at auxiliary data dimension `dim_idx` and element `element_idx`
#' @return None
set_auxiliary_data_value = function(dim_idx, element_idx, value) {
return(forest_dataset_set_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx, value))
},

#' @description
#' Retrieve entire auxiliary data vector
#' @param dim_idx Dimension to retrieve
#' @return Vector of all of the auxiliary data stored at dimension `dim_idx`
get_auxiliary_data_vector = function(dim_idx) {
return(forest_dataset_get_auxiliary_data_vector_cpp(self$data_ptr, dim_idx))
},

#' @description
#' Retrieve auxiliary data vector and place it into a column of the supplied matrix
#' @param output_matrix Matrix to be overwritten
#' @param dim_idx Auxiliary data dimension to retrieve
#' @param matrix_col_idx Matrix column in which to copy auxiliary data
#' @return Vector of all of the auxiliary data stored at dimension `dim_idx`
store_auxiliary_data_vector_matrix = function(output_matrix, dim_idx, matrix_col_idx) {
return(forest_dataset_store_auxiliary_data_vector_as_column_cpp(self$data_ptr, output_matrix, dim_idx, matrix_col_idx))
}
)
)
1 change: 0 additions & 1 deletion R/model.R
@@ -378,7 +378,6 @@ createForestModel <- function(
))
}


#' Draw `sample_size` samples from `population_vector` without replacement, weighted by `sampling_probabilities`
#'
#' @param population_vector Vector from which to draw samples.
16 changes: 16 additions & 0 deletions RC_README.md
@@ -0,0 +1,16 @@
# Release Candidate for StochTree Cloglog BART

This branch serves as a staging / testing zone for the planned incorporation of BART / BCF with a complementary log-log link function into `stochtree`.

## Installation

The cloglog release candidate version of `stochtree` can be installed from GitHub via

```
remotes::install_github("StochasticTree/stochtree", ref="cloglog-bart-rc")
```

## Vignettes and Demos

Before incorporating this functionality into `stochtree`, we intend to develop a rich set of vignettes.
We have included demo scripts for the cloglog model on synthetic ordinal data with 2, 3 and 4 categories in the `tools` subfolder of this branch.
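
As a rough illustration of what such a demo might look like, the sketch below simulates ordinal data with 3 categories and then calls `cloglog_ordinal_bart` if a build exporting it is installed. It is not one of the scripts in `tools`. The latent function `lambda_x`, the increments `gamma_incr`, and the cumulative form `P(Y <= k | x) = 1 - exp(-exp(lambda(x)) * sum_{j <= k} exp(gamma_j))` are assumptions made for this example; the parameterization and sign convention used inside the package may differ.

```r
# Sketch only: simulate 3-category ordinal data under an assumed cumulative
# cloglog model, then fit the release-candidate function if it is installed.
set.seed(2025)
n <- 500
p <- 3
X <- matrix(runif(n * p), nrow = n, ncol = p)

# Hypothetical latent function and log-scale cutpoint increments (assumptions)
lambda_x <- 1.5 * X[, 1] - 1.0 * X[, 2]
gamma_incr <- c(0.0, 0.5)            # first increment fixed at 0 for identifiability
cum_exp <- cumsum(exp(gamma_incr))   # cumulative sums of exp(gamma_j)

# Assumed form: P(Y <= k | x) = 1 - exp(-exp(lambda(x)) * cum_exp[k]), k = 1, 2
cdf <- 1 - exp(-outer(exp(lambda_x), cum_exp))
u <- runif(n)
y <- 1 + rowSums(u > cdf)            # ordinal categories coded 1, 2, 3

# Fit only when a stochtree build exporting cloglog_ordinal_bart is available
if (requireNamespace("stochtree", quietly = TRUE) &&
    "cloglog_ordinal_bart" %in% getNamespaceExports("stochtree")) {
  fit <- stochtree::cloglog_ordinal_bart(
    X, y, num_gfr = 0, num_burnin = 100, num_mcmc = 100, seed = 1
  )
  print(dim(fit$gamma_samples))      # (n_levels - 1) rows, one column per kept draw
}
```

The short burn-in and sample counts are chosen only to keep the sketch quick to run; the function defaults (`num_burnin = 1000`, `num_mcmc = 500`) are likely a better starting point for real analyses.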