Source code for PyBH.SurvivalAnalysis.pymc_models

from abc import ABC, abstractmethod

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm


class PyMCModel(ABC):
    """
    Abstract base class for Bayesian survival models built with PyMC.

    Designed to mimic the Lifelines API. Concrete subclasses must implement
    :meth:`build_model` and :meth:`predict_survival_function`.

    Attributes
    ----------
    model
        Object returned by ``build_model`` during the last ``fit`` call
        (``None`` until then).
    idata
        Result of ``pm.sample`` from the last ``fit`` call (``None`` until
        then).
    duration_col, event_col
        The duration / event arguments passed to the last ``fit`` call.
    """

    def __init__(self):
        self.model = None
        self.idata = None  # Stores the InferenceData after fitting
        self.duration_col = None
        self.event_col = None
        self._feature_names = None

    @abstractmethod
    def build_model(self, data, duration_col, event_col, coords=None, **kwargs):
        """
        Define the PyMC model structure (priors and likelihood).

        Must return a ``pm.Model()`` object; ``fit`` uses the returned object
        as a context manager around ``pm.sample``.
        """

    def fit(
        self,
        data,
        duration_col,
        event_col,
        coords=None,
        draws=2000,
        tune=1000,
        chains=4,
        **kwargs,
    ):
        """
        Fit the model to the data using MCMC sampling.

        Parameters
        ----------
        data, duration_col, event_col
            Forwarded unchanged to ``build_model``.
        coords : optional
            Forwarded to ``build_model`` as ``coords``. Defaults to ``None``,
            matching the default declared by ``build_model``.
        draws, tune, chains : int
            Forwarded to ``pm.sample``.
        **kwargs
            Extra keyword arguments forwarded to ``pm.sample`` (they are not
            passed to ``build_model``).

        Returns
        -------
        PyMCModel
            ``self``, so calls can be chained.
        """
        self.duration_col = duration_col
        self.event_col = event_col

        # 1. Initialize the PyMC model
        self.model = self.build_model(data, duration_col, event_col, coords=coords)

        # 2. Run the MCMC sampler inside the model context
        with self.model:
            self.idata = pm.sample(draws=draws, tune=tune, chains=chains, **kwargs)
        return self

    @abstractmethod
    def predict_survival_function(self, times, X_new):
        """
        Calculate the survival probability S(t) for given time points.

        Returns a posterior distribution of survival curves.
        """

    def print_summary(self):
        """
        Return the ArviZ statistical summary of the posterior distributions.

        Despite its name, this method does not print anything itself: it
        returns the result of ``az.summary``.

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        """
        if self.idata is None:
            raise ValueError("Model must be fitted before calling summary().")
        return az.summary(self.idata)

    def plot_traces(self):
        """
        Plot MCMC trace diagnostics and show the figure.

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        """
        if self.idata is None:
            raise ValueError("Model must be fitted before plotting.")
        az.plot_trace(self.idata)
        plt.tight_layout()
        plt.show()

    def score(self, data, duration_col, event_col):
        """
        Calculate the Concordance Index (C-index).

        Higher is better (0.5 is random, 1.0 is perfect).

        Not implemented yet: the method only prints a notice and returns
        ``None``; the arguments are currently ignored.
        """
        print("Method not yet implemented.")

class Cox(PyMCModel):
    r"""
    Bayesian Cox Proportional Hazard model using the Poisson equivalence
    (Piecewise Exponential Model).

    It models the survival process as a set of Poisson distributions where the
    expected number of events :math:`\mu_{ij}` for patient i in interval j is:

    .. math::

        \mu_{ij} = \Delta t_{ij} \cdot \lambda_j \cdot \exp(X_i \beta)

    where:

    - :math:`\Delta t_{ij}` is the time (exposure) patient i spent in
      interval j.
    - :math:`\lambda_j` is the baseline hazard for the interval j.
    - :math:`X_i` is the vector of covariates for patient i.
    - :math:`\beta` is the vector of coefficients (log-hazard ratios)
      associated with the covariates.

    Parameters
    ----------
    cutpoints : list or np.array
        Timepoints defining the intervals for the piecewise constant
        baseline hazard. A leading ``0`` is added, and the values are sorted
        and de-duplicated; the last interval is open-ended (up to infinity).
    priors : dict, optional
        Overrides for the default prior hyper-parameters
        ``{"beta_sigma": 1.0, "lambda_alpha": 0.01, "lambda_beta": 0.01}``.

    Examples
    --------
    >>> import pymc
    >>> import pandas
    >>> from PyBH.SurvivalAnalysis.SurvivalAnalysis import SurvivalAnalysis
    >>> from PyBH.SurvivalAnalysis.pymc_models import Cox
    >>> # Typical dataset for survival analysis
    >>> data = pandas.read_csv(pymc.get_data("mastectomy.csv"))

    # Define intervals: 0-10, 10-20, 20+
    model = Cox(cutpoints=[10, 20])
    # Launch analysis
    analysis = SurvivalAnalysis(model=model, data=data, time_col="time",
                                event_col="event",)
    # Plot obtained survival function
    analysis.plot_survival_function()
    """

    def __init__(self, cutpoints, priors=None):
        super().__init__()
        # np.unique already returns its result sorted.
        self.cutpoints = np.unique(np.concatenate(([0], cutpoints)))
        # Interval j spans [interval_bounds_[j], interval_bounds_[j + 1]).
        self.interval_bounds_ = np.concatenate((self.cutpoints, [np.inf]))
        self.priors = {"beta_sigma": 1.0, "lambda_alpha": 0.01, "lambda_beta": 0.01}
        if priors:
            self.priors.update(priors)

    def _transform_to_long_format(self, X, times, events):
        """
        Convert survival data to long format for piecewise constant hazard
        modeling.

        Each subject is expanded into one row per time interval it entered,
        tracking its exposure time in that interval and whether the event
        occurred there. Rows are ordered by subject, then by interval.

        Parameters
        ----------
        X : array-like
            Covariates, one row per subject (NumPy array, nested list or
            DataFrame).
        times : array-like
            Observed time (event or censoring) per subject.
        events : array-like
            ``1`` if the event was observed, anything else if censored.

        Returns
        -------
        tuple of np.ndarray
            ``(interval_index, exposure, is_event, X_long)`` with one entry
            per (subject, interval) pair.

        Raises
        ------
        ValueError
            If the inputs have different lengths or ``times`` contains NaN.
        """
        # Coerce to NumPy so indexing is positional; ``X[i]`` on a DataFrame
        # or ``times[i]`` on a re-indexed Series would be a label lookup.
        X = np.asarray(X, dtype=float)
        times = np.asarray(times, dtype=float).ravel()
        events = np.asarray(events).ravel()

        if not (len(X) == len(times) == len(events)):
            raise ValueError("X, times and events must have the same length.")
        if np.isnan(times).any():
            raise ValueError("times must not contain NaN values.")

        starts = self.interval_bounds_[:-1]
        ends = self.interval_bounds_[1:]

        # A subject contributes to every interval that starts strictly before
        # its observed time. np.nonzero walks the mask row by row, so the
        # output is ordered by subject and then by interval.
        entered = times[:, None] > starts[None, :]
        subject_idx, interval_idx = np.nonzero(entered)

        t_obs = times[subject_idx]
        t_start = starts[interval_idx]
        t_end = ends[interval_idx]

        # Time spent at risk within the interval.
        exposure = np.minimum(t_obs, t_end) - t_start
        # The event is attributed to the interval containing the observed time.
        is_event = (t_obs <= t_end) & (events[subject_idx] == 1)

        return (
            interval_idx.astype(int),
            exposure.astype(float),
            is_event.astype(float),
            X[subject_idx],
        )

    def build_model(self, interval_indices, exposures, events, X_long, coords):
        """
        Construct the Bayesian Piecewise Exponential Model using PyMC.

        Note that this signature takes the long-format arrays produced by
        ``_transform_to_long_format`` rather than the
        ``(data, duration_col, event_col)`` arguments declared on
        ``PyMCModel.build_model``.

        Parameters
        ----------
        interval_indices, exposures, events, X_long : np.ndarray
            Long-format arrays, one entry per (subject, interval) pair.
        coords : dict
            Model coordinates; must define the ``"coeffs"`` and
            ``"intervals"`` dimensions used by the priors.

        Returns
        -------
        pm.Model
        """
        with pm.Model(coords=coords) as model:
            # Priors for the regression coefficients (log-hazard ratios)
            beta = pm.Normal(
                "beta", mu=0, sigma=self.priors["beta_sigma"], dims="coeffs"
            )

            # Baseline hazard for each discrete time interval
            lambda0 = pm.Gamma(
                "lambda0",
                alpha=self.priors["lambda_alpha"],
                beta=self.priors["lambda_beta"],
                dims="intervals",
            )

            # Compute log-risk for each observation
            log_risk = (X_long * beta[None, :]).sum(axis=-1)

            # Expected value for the Poisson likelihood
            mu = exposures * lambda0[interval_indices] * pm.math.exp(log_risk)

            pm.Poisson("obs", mu=mu, observed=events)
        return model

    def fit(
        self, X, time, event, coords=None, draws=2000, tune=1000, chains=2, **kwargs
    ):
        """
        Fit the Bayesian Piecewise Exponential Model to survival data.

        Parameters
        ----------
        X : array-like
            Covariates, one row per subject and one column per feature.
        time : array-like
            Observed time per subject.
        event : array-like
            Event indicator per subject (``1`` = event observed).
        coords : dict, optional
            Only the ``"coeffs"`` entry is used, as the coefficient names.
            When missing, names default to ``v0, v1, ...``.
        draws, tune, chains : int
            Forwarded to ``pm.sample``.
        **kwargs
            Extra keyword arguments for ``pm.sample``. ``cores`` defaults to
            ``1`` but may be overridden here.

        Returns
        -------
        Cox
            ``self``.
        """
        X = np.asarray(X, dtype=float)

        # ``coords`` defaults to None, so guard before calling ``.get`` on it.
        user_coords = coords if coords is not None else {}
        default_names = [f"v{i}" for i in range(X.shape[1])]
        self._feature_names = user_coords.get("coeffs", default_names)

        # Convert from wide (1 row/subject) to long (N rows/subject). This is
        # required to model the survival process as a Poisson counting process.
        idx, exp, evt, X_long = self._transform_to_long_format(X, time, event)

        # Define model dimensions
        n_intervals = len(self.interval_bounds_) - 1
        model_coords = {
            "coeffs": self._feature_names,
            "intervals": [f"Int_{i}" for i in range(n_intervals)],
        }

        self.model = self.build_model(idx, exp, evt, X_long, model_coords)

        # Merge instead of passing ``cores=1`` explicitly, so a caller-supplied
        # ``cores`` does not raise "got multiple values for keyword argument".
        sample_kwargs = {"cores": 1, **kwargs}
        with self.model:
            self.idata = pm.sample(
                draws=draws, tune=tune, chains=chains, **sample_kwargs
            )
        return self

    def predict_survival_function(self, times, X_new):
        """
        Predict the survival function for new samples at given time points.

        Calculates S(t) = exp(-H(t)), where H(t) is the cumulative hazard.

        Parameters
        ----------
        times : float or array-like
            Time points at which to evaluate the survival function.
        X_new : array-like
            Covariates of the new samples, one row per sample. A single 1-D
            covariate vector is treated as one sample.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_posterior_samples, n_new_samples, n_times)``.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.idata is None:
            raise ValueError("Model not fitted.")

        # Extract posterior samples for baseline hazards and coefficients;
        # after stacking and transposing, the first axis is the sample axis.
        post = self.idata.posterior
        lambdas = post["lambda0"].stack(sample=("chain", "draw")).values.T
        betas = post["beta"].stack(sample=("chain", "draw")).values.T

        X_arr = np.atleast_2d(np.asarray(X_new, dtype=float))
        times = np.atleast_1d(np.asarray(times, dtype=float))

        # Relative risk score of each new sample under each posterior draw.
        risk_scores = np.exp(np.dot(betas, X_arr.T))

        # Time spent in every interval up to each target time. Intervals that
        # start at or after the target time yield a non-positive difference
        # and are clipped to zero exposure.
        starts = self.interval_bounds_[:-1]
        ends = self.interval_bounds_[1:]
        time_in_interval = np.clip(
            np.minimum(times[:, None], ends[None, :]) - starts[None, :], 0.0, None
        )

        # Cumulative baseline hazard: sum over intervals of rate * time spent.
        cum_h0 = lambdas @ time_in_interval.T

        # Final survival probability
        return np.exp(-risk_scores[:, :, np.newaxis] * cum_h0[:, np.newaxis, :])
class Weibull(PyMCModel):
    """
    Bayesian Weibull Survival Model implementation.

    Parameters: alpha (shape k), beta (scale eta). The model has no
    covariates.
    """

    @staticmethod
    def _resolve_column(data, column):
        """Return ``column`` as an array, looking it up in ``data`` if it is a name."""
        if isinstance(column, str):
            column = data[column]
        return np.asarray(column)

    def build_model(self, data, duration_col, event_col, coords=None, **kwargs):
        """
        Build the Weibull model with right-censoring.

        Parameters
        ----------
        data : mapping or DataFrame, optional
            Only consulted when ``duration_col`` / ``event_col`` are given as
            column names.
        duration_col : array-like or str
            Observed durations, or the name of the column in ``data`` holding
            them.
        event_col : array-like or str
            Event indicators (``1`` = observed, ``0`` = censored), or the name
            of the column in ``data`` holding them.
        coords : dict, optional
            Coordinates passed to ``pm.Model``.
        **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        pm.Model
        """
        # Coerce to arrays so the boolean masks below select elementwise; with
        # plain lists ``event_col == 1`` is a single ``False`` and the
        # indexing would silently pick element 0.
        durations = self._resolve_column(data, duration_col).astype(float)
        events = self._resolve_column(data, event_col)

        # Data split: Censored (0) vs Observed (1)
        observed = durations[events == 1]
        censored = durations[events == 0]

        # Prior for beta (scale) based on average survival time
        mean_time = durations.mean()

        with pm.Model(coords=coords) as model:
            # --- Priors ---
            # alpha (k): shape parameter. Controls if risk is increasing (>1)
            # or decreasing (<1)
            alpha = pm.HalfNormal("alpha", sigma=2.0)

            # beta (eta): scale parameter. Characteristic time of failure.
            beta = pm.HalfNormal("beta", sigma=mean_time * 5)

            # --- Likelihood ---
            # 1. Observed events: PDF f(t)
            if len(observed) > 0:
                pm.Weibull("obs_likelihood", alpha=alpha, beta=beta, observed=observed)

            # 2. Censored events: Survival function S(t)
            # Log(S(t)) = -(t/beta)^alpha
            if len(censored) > 0:
                log_surv_censored = -((censored / beta) ** alpha)
                pm.Potential("cens_likelihood", log_surv_censored)

        return model

    def predict_survival_function(self, times, X_new, credible_interval=0.95):
        """
        Predict S(t) = exp(-(t/beta)^alpha).

        Parameters
        ----------
        times : float or array-like
            Time points at which to evaluate the survival function.
        X_new
            Unused (the model has no covariates); kept for interface
            compatibility with ``PyMCModel``.
        credible_interval : float, default 0.95
            Mass of the equal-tailed posterior interval to report.

        Returns
        -------
        pd.DataFrame
            Indexed by ``time`` with columns ``mean_survival``,
            ``lower_<credible_interval>`` and ``upper_<credible_interval>``.
            The bounds are posterior quantiles (an equal-tailed interval),
            not a highest-density interval.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.idata is None:
            raise ValueError("Fit the model first.")

        # Extract posterior draws
        stacked = self.idata.posterior.stack(sample=("chain", "draw"))
        alpha_samples = stacked["alpha"].values
        beta_samples = stacked["beta"].values

        times = np.atleast_1d(times)

        # Compute survival curves: S(t) = exp(-(t/beta)^alpha)
        # Result shape: (num_samples, num_time_points)
        scaled_times = times[np.newaxis, :] / beta_samples[:, np.newaxis]
        surv_curves = np.exp(-(scaled_times ** alpha_samples[:, np.newaxis]))

        # Posterior mean and equal-tailed quantile bounds per time point
        mean_surv = surv_curves.mean(axis=0)
        lower_q = (1 - credible_interval) / 2
        upper_q = 1 - lower_q
        lower, upper = np.quantile(surv_curves, [lower_q, upper_q], axis=0)

        return pd.DataFrame(
            {
                "time": times,
                "mean_survival": mean_surv,
                f"lower_{credible_interval}": lower,
                f"upper_{credible_interval}": upper,
            }
        ).set_index("time")