{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example Notebook: Weibull\n", "In this notebook, we use the **mastectomy dataset**, which records survival times for breast cancer patients. We fit the **Bayesian Weibull model** (implemented in `PyBH`) to analyze the survival data and compare patients with and without metastasis." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "import pandas as pd\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pymc as pm\n", "import arviz as az\n", "\n", "# Add the directory four levels up to sys.path so that the PyBH package can be imported\n", "\n", "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../..')))\n", "\n", "from PyBH.SurvivalAnalysis.pymc_models import Weibull\n", "from PyBH.SurvivalAnalysis.SurvivalAnalysis import SurvivalAnalysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load Data\n", "\n", "We load the `mastectomy` dataset using PyMC's data utility." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset loaded! Shape: (44, 3)\n", " time event metastasized\n", "0 23 True no\n", "1 47 True no\n", "2 69 True no\n", "3 70 False no\n", "4 100 False no\n" ] } ], "source": [ "try:\n", "    # Load the dataset through PyMC's data utility\n", "    data = pd.read_csv(pm.get_data(\"mastectomy.csv\"))\n", "except Exception:\n", "    # No fallback is configured: report the failure and re-raise.\n", "    # If a backup source is needed, the file could be read directly, e.g.:\n", "    # data = pd.read_csv(\"https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/mastectomy.csv\")\n", "    print(\"Could not load the dataset via pm.get_data.\")\n", "    raise\n", "\n", "print(f\"Dataset loaded! Shape: {data.shape}\")\n", "print(data.head())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Initializing NUTS using jitter+adapt_diag...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " -> Mode: Bayesian (PyMC)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Multiprocess sampling (2 chains in 2 jobs)\n", "NUTS: [alpha, beta]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c7e9b6b8566949bba2da1259574fd3fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 1 seconds.\n", "We recommend running at least 4 chains for robust computation of convergence diagnostics\n" ] } ], "source": [ "wbll = Weibull()\n", "\n", "# Use PyBH's SurvivalAnalysis class to fit the model\n", "# This automatically handles validation and preprocessing\n", "survival_analysis = SurvivalAnalysis(\n", " model=wbll,\n", " data=data,\n", " time_col='time',\n", " event_col='event',\n", " draws=1000, tune=1000, chains=2\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Preprocessing\n", "\n", "The dataset has:\n", "- `time`: survival time\n", "- `event`: 1 if event observed (death), 0 if censored\n", "- `metastasized`: 'yes' or 'no'" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "metastasized\n", "yes 32\n", "no 12\n", "Name: count, dtype: int64\n" ] } ], "source": [ "# Ensure event column is numeric\n", "data['event'] = data['event'].astype(int)\n", "print(data['metastasized'].value_counts())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Fit Weibull Model (All Patients)\n", "\n", "First, we fit the model on the entire population to get a global survival curve." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "| \n", " | mean | \n", "sd | \n", "hdi_3% | \n", "hdi_97% | \n", "mcse_mean | \n", "mcse_sd | \n", "ess_bulk | \n", "ess_tail | \n", "r_hat | \n", "
|---|---|---|---|---|---|---|---|---|---|
| alpha | \n", "0.884 | \n", "0.153 | \n", "0.573 | \n", "1.159 | \n", "0.004 | \n", "0.004 | \n", "1156.0 | \n", "959.0 | \n", "1.0 | \n", "
| beta | \n", "192.051 | \n", "54.335 | \n", "115.229 | \n", "298.528 | \n", "1.708 | \n", "2.451 | \n", "1343.0 | \n", "974.0 | \n", "1.0 | \n", "