Example Notebook: Cox

This notebook demonstrates the Bayesian Cox model using the SurvivalAnalysis workflow.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import pymc as pm
import sys
import os


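# Make the PyBH package importable: add the directory four levels above the
# current working directory (assumed to be the repository root) to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../../../.."))
if project_root not in sys.path:
    sys.path.append(project_root)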


from PyBH.SurvivalAnalysis.SurvivalAnalysis import SurvivalAnalysis
from PyBH.SurvivalAnalysis.pymc_models import Cox

1. Load Data

We use the mastectomy dataset from the HSAUR R package. It records survival times (in months) after mastectomy for women with breast cancer, together with whether the cancer had metastasized (the column is spelled metastized in the data).

url = "https://vincentarelbundock.github.io/Rdatasets/csv/HSAUR/mastectomy.csv"

df_raw = pd.read_csv(url, index_col=0)

# Inspect the column names and the first rows
print("Columns:", df_raw.columns.tolist())
print(df_raw.head())
Columns: ['time', 'event', 'metastized']
          time  event metastized
rownames                        
1           23   True         no
2           47   True         no
3           69   True         no
4           70  False         no
5          100  False         no
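
The workflow later reports a single feature named metastized_yes, which is what a drop-first one-hot encoding of the metastized column would produce. The sketch below illustrates that encoding with pandas.get_dummies; this is an assumption about the preprocessing, and the encoding actually performed inside SurvivalAnalysis may differ. The first three rows mirror the output above, while the last row is made up so that both levels occur.

```python
import pandas as pd

# Toy frame: the first three rows mirror the printed output above;
# the last row is made up so that both levels of `metastized` occur.
toy = pd.DataFrame(
    {
        "time": [23, 47, 69, 40],
        "event": [True, True, True, True],
        "metastized": ["no", "no", "no", "yes"],
    }
)

# Drop-first one-hot encoding: "no" becomes the reference level,
# leaving a single 0/1 indicator column for "yes".
X = pd.get_dummies(toy[["metastized"]], drop_first=True).astype(int)

print(X.columns.tolist())
print(X["metastized_yes"].tolist())
```

With "no" as the reference level, the coefficient of metastized_yes is the log hazard ratio of metastasized versus non-metastasized patients.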

2. Define Cutpoints

The Bayesian Cox model uses a piecewise constant baseline hazard, so we need to define the time intervals through their cutpoints. A good strategy is to place the cutpoints at percentiles of the observed event times, so that each interval contains a similar number of events.

n_intervals = 5
observed_times = df_raw['time'].values
cutpoints = np.percentile(observed_times[df_raw['event'] == 1], 
                          np.linspace(0, 100, n_intervals + 1)[1:-1])

print(f"Cutpoints: {cutpoints}")
Cutpoints: [23. 35. 50. 71.]
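
To see why percentile-based cutpoints keep the intervals balanced, the sketch below applies the same recipe to a small set of made-up event times (not the mastectomy data) and counts the events that fall into each interval. The variable names are illustrative only.

```python
import numpy as np

# Made-up event times for illustration only (not the mastectomy data).
event_times = np.array([5, 8, 12, 18, 23, 26, 31, 40, 47, 55, 68, 70, 90, 120, 150])

n_intervals = 5
# Same recipe as above: interior percentiles (20, 40, 60, 80) of the event times.
cutpoints = np.percentile(event_times, np.linspace(0, 100, n_intervals + 1)[1:-1])

# Interval index of every event: 0 below the first cutpoint,
# n_intervals - 1 at or above the last one.
interval_idx = np.digitize(event_times, cutpoints)
events_per_interval = np.bincount(interval_idx, minlength=n_intervals)

print("Cutpoints:", cutpoints)
print("Events per interval:", events_per_interval)
```

Each interval receives the same number of events here, which is the property that keeps every piece of the baseline hazard informed by data. Equally spaced cutpoints would not guarantee this when events cluster early in follow-up.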

3. Initialize Model and Workflow

We create an instance of our PyMC Cox model and pass it to the SurvivalAnalysis manager.

cox_model = Cox(cutpoints=cutpoints, priors={"beta_sigma": 1.0})

analysis = SurvivalAnalysis(
    model=cox_model,
    data=df_raw,
    time_col="time",
    event_col="event",
    draws=2000,  
    tune=1000,
    chains=2,
    target_accept=0.9
)
Initializing NUTS using jitter+adapt_diag...
   -> Mode: Bayesian (PyMC)
Sequential sampling (2 chains in 1 job)
NUTS: [beta, lambda0]

Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

4. Diagnostics

The SurvivalAnalysis constructor has already called fit(), so sampling is complete. We can now inspect the trace plots and the posterior summary.

analysis.model.plot_traces()
print(analysis.model.print_summary())
[Figure: trace plots for beta and lambda0]
                       mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  \
beta[metastized_yes]  0.692  0.438  -0.092    1.528      0.013    0.008   
lambda0[Int_0]        0.004  0.002   0.001    0.008      0.000    0.000   
lambda0[Int_1]        0.007  0.004   0.001    0.015      0.000    0.000   
lambda0[Int_2]        0.007  0.004   0.001    0.014      0.000    0.000   
lambda0[Int_3]        0.006  0.003   0.001    0.011      0.000    0.000   
lambda0[Int_4]        0.002  0.001   0.000    0.004      0.000    0.000   

                      ess_bulk  ess_tail  r_hat  
beta[metastized_yes]    1166.0    1796.0    1.0  
lambda0[Int_0]          1718.0    2330.0    1.0  
lambda0[Int_1]          1890.0    2660.0    1.0  
lambda0[Int_2]          1728.0    2390.0    1.0  
lambda0[Int_3]          1777.0    2361.0    1.0  
lambda0[Int_4]          1823.0    2427.0    1.0  
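
The coefficient beta[metastized_yes] is on the log-hazard scale, so exponentiating it gives a hazard ratio. The short sketch below does this for the values copied from the summary table. It is only a back-of-the-envelope reading: exponentiating the interval endpoints gives a valid 94% credible interval for the hazard ratio, but not necessarily its HDI, and the exponential of the posterior mean is not the posterior mean of the hazard ratio. For exact values, transform the posterior draws instead.

```python
import numpy as np

# Posterior summary of beta[metastized_yes], copied from the table above
# (mean, hdi_3%, hdi_97%).
beta_mean, beta_lo, beta_hi = 0.692, -0.092, 1.528

# Map the log-hazard coefficient onto the hazard-ratio scale.
hr_at_mean = np.exp(beta_mean)
hr_lo, hr_hi = np.exp(beta_lo), np.exp(beta_hi)

print(f"Hazard ratio at the posterior mean of beta: {hr_at_mean:.2f}")
print(f"Transformed 94% interval: [{hr_lo:.2f}, {hr_hi:.2f}]")
```

The hazard ratio is about 2, but the interval includes 1, so the data are also compatible with no difference between the two groups.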

5. Survival Curves

Let’s plot the survival curves. SurvivalAnalysis.plot_survival_function currently draws a forest plot of the coefficients, so we generate the survival curves manually with the predict_survival_function method of our Cox class.

eval_times = np.linspace(0, df_raw.time.max(), 50)

# Create two hypothetical patients to compare:
# Patient A: metastized = no (0)
# Patient B: metastized = yes (1)
# Note: after one-hot encoding, check the column names in analysis.model._feature_names
print("Features used:", analysis.model._feature_names)

# Construct the X matrix for prediction.
# 'metastized_yes' is the only feature column (index 0) after encoding
X_pred = np.array([
    [0], # Patient A (No metastasis)
    [1]  # Patient B (Metastasis)
])

# The result is indexed below as (posterior draw, patient, evaluation time)
surv_probs = analysis.model.predict_survival_function(eval_times, X_pred)

plt.figure(figsize=(10, 6))

# Plot Mean and 95% HDI for Patient A
mu_A = surv_probs[:, 0, :].mean(axis=0)
hdi_A = az.hdi(surv_probs[:, 0, :], hdi_prob=0.95)
plt.plot(eval_times, mu_A, label="No Metastasis")
plt.fill_between(eval_times, hdi_A[:,0], hdi_A[:,1], alpha=0.3)

# Plot Mean and 95% HDI for Patient B
mu_B = surv_probs[:, 1, :].mean(axis=0)
hdi_B = az.hdi(surv_probs[:, 1, :], hdi_prob=0.95)
plt.plot(eval_times, mu_B, label="Metastasis")
plt.fill_between(eval_times, hdi_B[:,0], hdi_B[:,1], alpha=0.3)

plt.title("Bayesian Survival Curves: Mastectomy Data")
plt.xlabel("Months")
plt.ylabel("Survival Probability")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Features used: ['metastized_yes']
/var/folders/sy/c1_rq0_x6dgdg_c_vymm1lnw0000gn/T/ipykernel_29870/2631143360.py:22: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_A = az.hdi(surv_probs[:, 0, :], hdi_prob=0.95)
/var/folders/sy/c1_rq0_x6dgdg_c_vymm1lnw0000gn/T/ipykernel_29870/2631143360.py:28: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_B = az.hdi(surv_probs[:, 1, :], hdi_prob=0.95)
[Figure: Bayesian Survival Curves: Mastectomy Data. Posterior mean survival probability against months, with 95% HDI bands, for patients with and without metastasis]
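
For intuition about how a survival curve follows from a piecewise constant baseline hazard under proportional hazards, here is a minimal sketch. It plugs in the posterior means from the summary table as point values and evaluates S(t | x) = exp(-exp(x * beta) * H0(t)), where H0(t) is the cumulative baseline hazard. The function survival is a hypothetical helper, not part of PyBH, and this is not necessarily how Cox.predict_survival_function is implemented. Plugging in posterior means also ignores the posterior uncertainty that the plotted curves average over.

```python
import numpy as np

cutpoints = np.array([23.0, 35.0, 50.0, 71.0])           # printed in section 2
lambda0 = np.array([0.004, 0.007, 0.007, 0.006, 0.002])  # posterior means, Int_0..Int_4
beta = 0.692                                             # posterior mean of beta[metastized_yes]


def survival(t, x, lambda0, cutpoints, beta):
    """Survival probability at times t for covariate value x (hypothetical helper)."""
    t = np.atleast_1d(np.asarray(t, dtype=float))
    edges = np.concatenate(([0.0], cutpoints, [np.inf]))
    lower, upper = edges[:-1], edges[1:]
    # Time spent at risk in each interval up to t: shape (len(t), n_intervals).
    exposure = np.clip(t[:, None], lower, upper) - lower
    # Cumulative baseline hazard H0(t) = sum_j lambda0_j * exposure_j(t).
    cum_baseline = exposure @ lambda0
    # Proportional hazards: scale the cumulative hazard by exp(x * beta).
    return np.exp(-np.exp(beta * x) * cum_baseline)


eval_times = np.array([0.0, 23.0, 100.0])
s_no = survival(eval_times, 0, lambda0, cutpoints, beta)
s_yes = survival(eval_times, 1, lambda0, cutpoints, beta)

print("No metastasis:", np.round(s_no, 3))
print("Metastasis:   ", np.round(s_yes, 3))
```

Because beta is positive, the curve for the metastasis group lies below the curve for the other group at every time after zero, matching the ordering in the figure.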