Why you should consider using survival analysis for your churn prediction along with a practical implementation guide in Python.

Whether you’re in B2B or B2C, or really any business that deals with customers, you’ve likely asked yourself “Which customers are most likely to churn?” The exact question might be phrased slightly differently, but it’s some form of that.

No matter how much or how little churn you experience, this is an important question to ask. You don’t have infinite resources, so you need to prioritize where you focus your attention, and churn risk is certainly one important factor in that prioritization.

In this post, we’re going to talk about why you should consider doing churn analysis and prediction, cover some of the approaches to it, and take a deep dive into survival analysis.

This post is also accompanied by a short walkthrough video.

Churn prediction and how to do it

What is churn prediction and why do it?

Whether you’re a B2C streaming subscription platform that costs $20 a month or a B2B SaaS company that signs $100k ARR contracts with customers, churn is an incredibly important factor to analyze. The cost of acquiring a new customer is almost always much higher than the cost of retaining an existing one, and churn directly impacts growth and net dollar retention. Perhaps more importantly, churn is often a sign of a product or service issue (or both), and is worth understanding and getting ahead of. It’s much easier to work with an existing customer to understand why they’re not getting value and right the ship than it is to do a post-mortem.

Great, so you’ve now decided that you want to start predicting churn, what now?

Different approaches to churn prediction

You can roughly take three approaches to predict churn:

  1. Simple heuristics (Customer health score = monthly dollar spend + feature engagement x …)
  2. Supervised machine learning models (random forest, logistic regression, gradient boosting machines)
  3. Survival analysis

Of course, using simple heuristics is an easy way to get started. The problem is that you’re baking your own biases into the score, and the model isn’t dynamic: the weights stay fixed until someone manually revisits them. You also need to identify and account for all the relevant factors yourself, which can be really hard to do.
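To make that concrete, a hand-built health score usually looks something like the toy sketch below. The weights, thresholds, and column names here are made up purely for illustration - the point is that they stay fixed until someone remembers to revisit them:

import pandas as pd

# Toy example of a hand-built customer health score - weights and thresholds are arbitrary
def health_score(row: pd.Series) -> float:
    score = 0.0
    score += 0.5 * min(row['average_monthly_spend'] / 1000, 1)   # cap the spend contribution
    score += 0.4 * row['feature_engagement']                      # assumes a 0-1 engagement rate
    score += 0.1 * (1 if row['customer_tier'] == 'enterprise' else 0)
    return score

# customers_df is assumed to be a DataFrame with the columns used above
# customers_df['health_score'] = customers_df.apply(health_score, axis=1)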

Supervised machine learning models like random forests or logistic regression are very powerful tools. In this scenario, you pull a list of customers, label each one as “churned” or “not churned”, and ask the model to predict that label from all the factors at your disposal. It effectively does the same thing as the heuristics above, but in a much more robust and systematic way.
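As a rough sketch of what that looks like in practice - assuming a customers_df DataFrame with a few feature columns and a 0/1 “churned” label - a scikit-learn classifier takes only a few lines:

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# customers_df is assumed to hold one row per customer with a 0/1 'churned' label
features = ['average_monthly_spend', 'feature_engagement', 'customer_tier', 'industry']
X = pd.get_dummies(customers_df[features])  # one-hot encode the categorical columns
y = customers_df['churned']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit a random forest and score it on held-out customers
model = RandomForestClassifier(n_estimators=200, random_state=42)
model.fit(X_train, y_train)
print(model.score(X_test, y_test))

# Probability that each held-out customer churns
churn_probabilities = model.predict_proba(X_test)[:, 1]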

The last big approach is survival analysis. The idea of a survival analysis is simple. It asks the question: “As time goes on, what is the likelihood of an event happening?” This type of approach is particularly useful when it comes to churn prediction for two key reasons:

  1. It doesn’t simply ask whether a customer will churn, but when a customer will churn. This is an important distinction because, given enough time, most customers will eventually churn, so understanding what time span to focus on can make a big difference.
  2. It handles events that haven’t happened yet very well. In the case of churn, if a customer hasn’t churned yet, that doesn’t mean they won’t churn tomorrow. So you don’t want to flag that customer as “not churned” as we did in the supervised learning case, but you don’t want to throw away all that good data either. Survival analysis treats these customers as censored observations (see the short sketch after this list).
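To see censoring in action, here’s a minimal Kaplan-Meier sketch using the lifelines library (the same library we’ll use below). Customers who haven’t churned yet are passed in with event_observed=0, so they contribute to the curve for as long as we’ve observed them without being counted as churned:

import pandas as pd
from lifelines import KaplanMeierFitter

# Tiny illustrative dataset: tenure in days, and whether churn was observed (1) or not (0)
data = pd.DataFrame({
    'tenure_days': [30, 90, 120, 200, 365, 400],
    'churned':     [1,  1,   0,   1,   0,   0],   # 0 = still a customer (censored)
})

kmf = KaplanMeierFitter()
kmf.fit(durations=data['tenure_days'], event_observed=data['churned'])

# Survival probability over time, accounting for the customers we haven't lost yet
print(kmf.survival_function_)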

Now that we’ve established that survival analysis is a good approach for churn prediction, let’s dive into a practical implementation.

Detailed walkthrough of a survival analysis for churn prediction

If your data lives in a data warehouse, SQL and Python are particularly well suited to this kind of analysis. SQL lets you extract your customer data, while Python has all the necessary packages ready to go - we’ve written about how to use the two in tandem [https://www.fabi.ai/blog/why-and-how-to-use-sql-and-python-in-tandem]. In the following example, we’ll build the prediction in Python.

Step 1: Prep the data

First, make sure you have all your data and key factors that you want to consider for your prediction. Some examples might be:

  • Contract value or monthly spend
  • Feature or user engagement
  • Industry or geography

You don’t need your data to be impeccable, but take the opportunity to make sure you’ve dealt with missing values as best you can and that formatting is correct.

Let’s go back to our fictional company Superdope, and see how they would pull their customer data:

select 
    customer_id,
    sign_up_date,
    -- Calculate time_to_event
    case 
        when churn_date is not null then datediff(churn_date, sign_up_date)
        else datediff(current_date, sign_up_date)
    end as time_to_event,
    -- Indicate if the customer has churned
    case 
        when churn_date is not null then 1
        else 0
    end as churned,
    customer_tier,
    average_monthly_spend,
    customer_segment,
    industry,
    feature_engagement
from 
    customers
where 
    -- Filter for customers who churned in the last 90 days or have not churned
    (churn_date is not null and churn_date >= date_sub(current_date, interval 90 day))
    or 
    churn_date is null

This query pulls a list of customers who churned in the last 90 days or haven’t churned yet, and calculates “time_to_event” as either the time from sign-up to today for non-churned customers or the time between sign-up and churn for churned customers.
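If you’re running this analysis from Python, you can load the query result straight into a pandas DataFrame - we’ll call it sample_data, since that’s the name the modeling code below expects. This is a minimal sketch that assumes a SQLAlchemy connection string; swap in whatever connector your warehouse uses:

import pandas as pd
from sqlalchemy import create_engine

# Hypothetical connection string - replace with your own warehouse credentials
engine = create_engine("postgresql://user:password@warehouse-host:5432/analytics")

# Paste the SQL from Step 1 here
customer_query = """
select customer_id, sign_up_date, ...
"""

sample_data = pd.read_sql(customer_query, engine)

# Light cleanup: drop rows with a missing duration and fill gaps in engagement
sample_data = sample_data.dropna(subset=["time_to_event"])
sample_data["feature_engagement"] = sample_data["feature_engagement"].fillna(0)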

Now that we have our data, let’s build a model.

Step 2: Build the model

Building a survival analysis model is as simple as importing a library. If we use the Cox proportional hazards model as an example, it looks like this:

from lifelines import CoxPHFitter
import matplotlib.pyplot as plt

# Fit the Cox model
cph = CoxPHFitter()
cph.fit(sample_data, duration_col='time_to_event', event_col='churned',
        formula="average_monthly_spend + customer_tier + customer_segment + industry + feature_engagement")

# Display the summary of the model
cph.print_summary()

You can get more sophisticated by tuning the formula to your own needs if you’d like.
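If you want a quick read on how well the model fits and which factors matter most, lifelines exposes a few useful attributes on the fitted model - a short sketch:

# Concordance index: roughly, how often the model correctly ranks which of two customers churns first
print(cph.concordance_index_)

# Hazard ratios: values above 1 increase churn risk, values below 1 decrease it
print(cph.hazard_ratios_)

# Plot the coefficients with their confidence intervals
cph.plot()
plt.show()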

From there you can also plot the effect of various factors on the likelihood of churn over time:

# Function to plot the partial effects of a covariate on survival
def plot_partial_effects(covariate):
    unique_values = sample_data[covariate].unique()
    cph.plot_partial_effects_on_outcome(covariates=covariate, values=unique_values)
    plt.title(f'Partial Effects of {covariate} on Survival')
    plt.show()

# Example usage: Visualize the impact of 'customer_segment'
plot_partial_effects(covariate='customer_segment')

That’s it! You’ve created your first survival analysis. Next up: using this model to identify which customers are at highest risk.

Step 3: Apply predictions to existing customers

Now that you’ve built a model to predict whether or not a customer might churn within a given time period, you need to apply that model to your existing customers so that you know where to focus your attention. Let’s go through it step by step:

from datetime import datetime

# Step 1: Filter for customers who haven't churned yet
future_customers = sample_data[sample_data['churned'] == 0]

# Define the prediction horizon as a number of days from a reference date
observation_end_date = datetime(2024, 12, 31)
reference_date = datetime(2024, 1, 1)
observation_end_days = (observation_end_date - reference_date).days

# Step 2: Predict the survival function for these customers
survival_functions = cph.predict_survival_function(future_customers, times=[observation_end_days])

# Convert the survival functions to a DataFrame
survival_df = survival_functions.T
survival_df['customer_id'] = future_customers['customer_id'].values

Now that we have the survival function for each customer, you can either simply compute the probability of churning by a given date for each customer, or you can use some risk threshold to flag customers as “At risk of churning”. In this example we’ve opted for the latter with a churn probability of greater than 50% between now and the “observation_end_date”:

# Step 3: Identify customers with a high probability of churn
churn_threshold = 0.5

# The first column of survival_df holds the survival probability at the observation horizon;
# a survival probability below 0.5 means a churn probability above 50%
likely_to_churn = survival_df[survival_df.iloc[:, 0] < churn_threshold]

# Display the customer IDs likely to churn
likely_to_churn_ids = likely_to_churn['customer_id'].values
print("Customers likely to churn:", likely_to_churn_ids)

Now you have a shortlist of customers that are at highest risk of churning.
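If you’d rather rank customers than apply a hard cutoff, you can turn the survival probabilities into churn probabilities and sort - a small sketch building on the DataFrame above:

# Churn probability by the observation end date is 1 minus the survival probability
survival_df['churn_probability'] = 1 - survival_df.iloc[:, 0]

# Rank customers from highest to lowest risk
ranked = survival_df.sort_values('churn_probability', ascending=False)
print(ranked[['customer_id', 'churn_probability']].head(10))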

Another great aspect of this approach, particularly in contrast to a heuristics-based one, is that you can easily update the model on the fly as you consider new factors, and the model will adjust as various factors have more or less impact on customer churn over time. This keeps you from getting stuck in an old paradigm as your business evolves.
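For example, if you start tracking a new signal - say, support tickets opened in the last 90 days (a hypothetical column here) - adding it to the model is a one-line change to the formula:

# Refit the Cox model with a new (hypothetical) covariate added to the formula
cph.fit(sample_data, duration_col='time_to_event', event_col='churned',
        formula="average_monthly_spend + customer_tier + customer_segment + industry "
                "+ feature_engagement + support_tickets_last_90d")
cph.print_summary()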

Leverage survival analysis to predict when a customer might churn

Churn prediction is a critical component to managing your customer base and making sure that you’re focusing your resources in the right areas. At a high level you have three possible approaches:

  1. Heuristics-based scoring
  2. Supervised learning
  3. Survival analysis

Heuristics-based scoring can be a great way to start, but as we noted, the model can quickly fall out of sync with your business as various factors have more or less of an impact on customer churn over time. For example, if you were missing a key feature for retail customers but then addressed it, retail customers may suddenly become much less likely to churn, but a hand-built heuristics model wouldn’t pick up on that unless you explicitly changed it.

Supervised learning models such as logistic regression, random forests, gradient boosting machines or support vector machines (SVMs) are powerful tools, but they require you to clearly label customers as “Churned” or “Not Churned”. This can be difficult to do, because a customer that hasn’t churned today may churn tomorrow, so labeling them as “Not Churned” could negatively skew the model.

This leaves us with survival analysis, a way of measuring the probability of an event (churn) happening over a period of time. This approach allows you to keep data from customers that haven’t churned yet without skewing the model, and it also allows you to answer the question “when will a customer churn” instead of “will a customer churn”.

If you’re working on your own churn predictor and need any guidance on implementation, please feel free to reach out!

"I was able to get insights in 1/10th of the time it normally would have"

Don't take our word for it, give it a try!