Predicting Customer Churn with PySpark for a Music Streaming Service

Udacity Data Scientist Nanodegree Capstone Project

Sophie Ferrlein
Nov 12, 2019

Background and Objective

Business Need of the Music Streaming Service Sparkify

For businesses like the fictional music streaming service Sparkify, whose business model is (at least partly) based on subscriptions, it is important to make sure customers stick with the service in the long term. Preventing customer churn, e.g. by offering special incentives to users at risk of churning or by optimizing the product, therefore plays a crucial role in the company’s success.

Solution Approach

Accordingly, the project presented in this post aims to predict customer churn for Sparkify based on its user tracking data. More specifically, the goal is to identify users with a high churn risk by finding patterns in the behaviour of users who have already churned.

For this project, a small subset of the user tracking data has been provided to build a machine learning model locally using PySpark (the Python API of Spark, a framework for large-scale data processing). This technology choice is motivated by the goal of scaling the data processing and modeling to the much bigger original dataset on a distributed cluster in the cloud.

To follow the full process of this project, feel free to refer to the corresponding Jupyter Notebook on GitHub.

Exploratory Data Analysis

Provided Dataset

The provided dataset is a small subset (125 MB) of the full dataset (12 GB) and contains logging information about the users’ activities on Sparkify. Our first step is to load the data into a Spark DataFrame.

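from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Sparkify").getOrCreate()
# Load the dataset (one JSON object per line) into a Spark DataFrame
events = spark.read.json("mini_sparkify_event_data.json")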

The schema of the data tells us that entries provide attributes like userId, timestamp (ts) and page. These allow us to examine the users’ activities (e.g. if the page is ‘NextSong’, the user just played the referenced song by the referenced artist) as well as meta information like gender, location, registration (timestamp of the user’s registration) or level (‘paid’ or ‘free’ usage of the service).

root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

In total, the dataset contains 286,500 log entries and 225 unique userIds.

Data Cleaning

In our project, we’re concerned with the usage of Sparkify on a user level. Consequently, we want to remove all entries that cannot be tied to a single user because their userId is missing.

We can identify 8,346 entries without a userId, all of which describe activities of either logged-out or guest users. We can therefore remove the entries with missing userIds by filtering out all entries with these two authentication statuses.

# Keep only events of logged-in users (drops all entries without a userId)
events = events.filter("auth not in ('Logged Out', 'Guest')")

Churn Definition

As you may have noticed, the dataset does not explicitly state whether a userId belongs to a churned user. To be able to explore and predict the differences between churned and non-churned users, we’re going to add this flag ourselves.

Looking at the page types provided, “Cancellation Confirmation” can be used as an indicator that the user has just churned.

Figure: Page occurrences (excluding “NextSong”, which would distort the view with a count above 200,000)

One could also consider defining the “Submit Downgrade” page calls as churn events. This option was ultimately discarded for the following reasons:

  • The users are not “gone” but keep using the service, so they continue to produce events after the potential churn event, which could mislead later models.
  • Some of the users also cancelled their service after they downgraded. Ignoring all user events after a downgrade (to solve the issue above) would not be a good option either, as we would lose valuable information about users who actually cancelled.

Finally, we can add a binary column ‘churn’ to our dataset and set it to 1 for all events of users who eventually cancelled.

from pyspark.sql.functions import col, lit, when

# Add column churn with 0 as default value
events = events.withColumn("churn", lit(0))
# Identify all userIds with a churn event (collected to the driver as a Python list)
churn_userIds = [row.userId for row in events.filter("page == 'Cancellation Confirmation'").select('userId').distinct().collect()]
# Set churn to 1 for all events of a churned user
events = events.withColumn("churn", when(col("userId").isin(churn_userIds), 1).otherwise(col("churn")))
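The two-pass logic above first collects the churned userIds and then flags every event of those users. It can be sketched in plain Python on a few made-up events. This is not Spark code and the toy records are hypothetical. The sketch shows that events *before* the cancellation are labeled as well, which is what makes the flag usable as a per-user label later on.

```python
def flag_churned_users(events):
    """Return a copy of the events with churn=1 for every event of a user
    who has at least one 'Cancellation Confirmation' event, else churn=0."""
    # First pass: collect the ids of all users with a cancellation event
    churned_ids = {
        e["userId"] for e in events if e["page"] == "Cancellation Confirmation"
    }
    # Second pass: label *all* events of those users, including earlier ones
    return [{**e, "churn": int(e["userId"] in churned_ids)} for e in events]


# Hypothetical toy events
events = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Cancellation Confirmation"},
    {"userId": "20", "page": "NextSong"},
    {"userId": "20", "page": "Thumbs Up"},
]
labeled = flag_churned_users(events)
print([e["churn"] for e in labeled])  # [1, 1, 0, 0]
```

On the full dataset, collecting the ids to the driver could become a bottleneck. A window function such as `max(...).over(Window.partitionBy("userId"))` would keep the computation distributed. That is a design alternative, not what the original notebook does.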

Churned vs. Not Churned Users

The first thing we want to check is how many churned vs. non-churned users exist in the data, as this might have a strong effect on our later model performance. Only 20% of the users in the data actually churned, and they account for an even smaller share of the events because they have a smaller average event count per user.

Looking at usage behaviour and settings, even a shallow comparison of churned vs. non-churned users reveals some interesting (albeit partly unsurprising) differences.

Among all events a user experienced (except playing songs), “Roll Advert” and “Thumbs Down” have a higher share for churned users than for non-churned users. At the same time, “Thumbs Up” and “Add to Playlist” events make up a smaller proportion of the log history of churned users.

Figure: Share of each event (page) among all events of a user, again excluding “NextSong” to prevent distortion
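The per-user event shares described above can be computed as the count of a page divided by the user's total number of non-NextSong events. Below is a plain-Python sketch with hypothetical toy events; it is not the notebook's implementation.

```python
from collections import Counter, defaultdict


def page_shares(events, exclude=("NextSong",)):
    """Per user, the share of each page among that user's events,
    ignoring the excluded pages (NextSong dominates the counts)."""
    counts = defaultdict(Counter)
    for e in events:
        if e["page"] not in exclude:
            counts[e["userId"]][e["page"]] += 1
    shares = {}
    for user, page_counts in counts.items():
        total = sum(page_counts.values())
        shares[user] = {page: n / total for page, n in page_counts.items()}
    return shares


# Hypothetical toy events of a single user
events = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Thumbs Up"},
    {"userId": "10", "page": "Thumbs Up"},
    {"userId": "10", "page": "Roll Advert"},
    {"userId": "10", "page": "Thumbs Down"},
]
print(page_shares(events)["10"])
```

Users whose only events are NextSong do not appear in the result. They would need to be handled separately, e.g. with all shares set to 0.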

The users’ metadata also reveal some insights that might drive our model. Looking at the ‘userAgent’ information, for example, we can see that the ratio of churned to non-churned users differs quite a lot between Chrome and Firefox users.

The operating system and device, on the other hand, do not seem to be as useful. Either the churn/no-churn ratio is similar across their values, or almost all observations fall into just one category (like ‘device’: ‘Other’).

Figure: User churn by user agent information
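The post does not show how the ‘browser’ column was derived from ‘userAgent’. One possible approach is a simple substring heuristic, sketched below under that assumption. The function name, category labels and rule order are hypothetical and may differ from the notebook. The order of the checks matters because Chrome user agents also contain “Safari” and Edge user agents also contain “Chrome”.

```python
def browser_from_user_agent(user_agent):
    """Hypothetical heuristic: map a raw userAgent string to a browser name."""
    if not user_agent:
        return "Unknown"
    # Check the most specific tokens first: Edge UAs contain "Chrome",
    # and Chrome UAs contain "Safari"
    if "Firefox" in user_agent:
        return "Firefox"
    if "Edge" in user_agent:
        return "Edge"
    if "Trident" in user_agent or "MSIE" in user_agent:
        return "Internet Explorer"
    if "Chrome" in user_agent:
        return "Chrome"
    if "Safari" in user_agent:
        return "Safari"
    return "Other"


print(browser_from_user_agent(
    "Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0"
))  # Firefox
```

In PySpark, a function like this could be wrapped with `pyspark.sql.functions.udf` to add the ‘browser’ column to the events DataFrame.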

To follow the full exploration of the differences between churned and non-churned users, please refer to the original notebook.

Feature Engineering

Based on our findings in the exploratory analysis, we’re now ready to create the features to train our model with.

Ultimately, we want to be able to identify whether a certain user is about to churn. We therefore need to aggregate the log data by userId to train the algorithm on a per-user basis.

Categorical Features

First, we extract the categorical features from the original dataset (after creating the ‘browser’ column from the ‘userAgent’ during the data exploration).

To aggregate the categorical values per user, we use the max function. This causes no issues with the ‘gender’ and ‘churn’ information, since they only hold one distinct value per user anyway. However, we potentially throw away information about the user’s browser and level history, because users could have used different browsers and changed their level.

For this first modeling attempt, we’re still going to continue with this approach for the sake of simplicity. However, we should keep this fuzziness in mind and consider other solutions in case the model performance or other factors suggest it.

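from pyspark.sql.functions import max as Fmax

# One row per user; max() picks a single value for each categorical column
cat_features = events.groupby('userId') \
    .agg(
        Fmax('gender').alias('gender'),
        Fmax('level').alias('level'),
        Fmax('browser').alias('browser'),
        Fmax('churn').alias('churn')
    )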
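To see why max() can throw away the level history mentioned above, here is a plain-Python sketch of the same per-user aggregation on hypothetical rows. String columns are compared lexicographically, so ‘paid’ wins over ‘free’ regardless of which level the user had last.

```python
def aggregate_with_max(rows, columns):
    """Group rows by userId and keep the max value per column, mirroring
    groupby('userId').agg(max(...)) on string columns."""
    result = {}
    for row in rows:
        agg = result.setdefault(row["userId"], {})
        for c in columns:
            agg[c] = row[c] if c not in agg else max(agg[c], row[c])
    return result


# Hypothetical user who upgraded and later downgraded again
rows = [
    {"userId": "10", "level": "paid", "gender": "F"},
    {"userId": "10", "level": "free", "gender": "F"},  # most recent level
]
print(aggregate_with_max(rows, ["level", "gender"]))
# {'10': {'level': 'paid', 'gender': 'F'}}
```

If the level history turns out to matter, alternatives such as the user's most recent level (by ts) or a count of level changes could replace max(). These are suggestions, not part of the original analysis.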

Calculated Features

Next, we calculate several features based on the usage history. For example, we determine the average count of distinct artists a user has listened to per day:

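import datetime
from pyspark.sql.functions import avg, countDistinct, udf, round as Fround

# Assumed definition of the notebook's date helper: ms timestamp -> 'YYYY-MM-DD' string
datetime_func = udf(lambda ts: datetime.datetime.fromtimestamp(ts / 1000.0).strftime("%Y-%m-%d"))

# Calculate the average daily count of distinct artists listened to by user
artist_per_day = events.filter('page == "NextSong"') \
    .withColumn("date", datetime_func(events.ts)) \
    .groupby(['userId', 'date']) \
    .agg(countDistinct('artist').alias('artists')) \
    .groupby(['userId']) \
    .agg(Fround(avg('artists')).alias('daily_artists'))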
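The same calculation can be sketched in plain Python on hypothetical toy events. Songs are grouped by user and calendar day, distinct artists are counted per day, and the daily counts are averaged per user. The sketch makes two assumptions. First, ts is in milliseconds, as in the Sparkify logs. Second, days are cut in UTC, whereas a Spark UDF based on `datetime.fromtimestamp` uses the local time zone of the worker. The sketch also skips the final rounding.

```python
from collections import defaultdict
from datetime import datetime, timezone


def avg_daily_artists(events):
    """Average number of distinct artists per active day, per user.
    Only 'NextSong' events count; ts is a Unix timestamp in milliseconds."""
    artists_by_day = defaultdict(set)  # (userId, date) -> distinct artists
    for e in events:
        if e["page"] != "NextSong":
            continue
        day = datetime.fromtimestamp(e["ts"] / 1000, tz=timezone.utc).date()
        artists_by_day[(e["userId"], day)].add(e["artist"])
    daily_counts = defaultdict(list)  # userId -> [distinct artists per day]
    for (user, _day), artists in artists_by_day.items():
        daily_counts[user].append(len(artists))
    return {user: sum(n) / len(n) for user, n in daily_counts.items()}


DAY = 86_400_000             # one day in milliseconds
T0 = 1_538_352_000_000       # 2018-10-01 00:00:00 UTC
events = [
    {"userId": "10", "page": "NextSong", "ts": T0, "artist": "A"},
    {"userId": "10", "page": "NextSong", "ts": T0 + 1000, "artist": "B"},
    {"userId": "10", "page": "NextSong", "ts": T0 + 2000, "artist": "A"},
    {"userId": "10", "page": "NextSong", "ts": T0 + 3000, "artist": "C"},
    {"userId": "10", "page": "Thumbs Up", "ts": T0 + 4000, "artist": None},
    {"userId": "10", "page": "NextSong", "ts": T0 + DAY, "artist": "A"},
    {"userId": "20", "page": "NextSong", "ts": T0, "artist": "D"},
]
print(avg_daily_artists(events))  # {'10': 2.0, '20': 1.0}
```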

As we created a separate DataFrame for every calculated feature, we need to join them by ‘userId’ into a single DataFrame. We then join this DataFrame with the categorical features to get our final features DataFrame.

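# Join categorical and calculated features.
# create_calculated_features() is the notebook's helper that presumably builds all
# calculated feature DataFrames (like artist_per_day) and joins them on 'userId'.
feature_df = cat_features.join(
    create_calculated_features(events),
    on='userId',
    how='inner'
)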

Feature Transformation

Now we’re almost ready to dive into the modeling process. We just have to apply some transformations to make our data digestible for PySpark classifiers. Unlike sklearn classifiers, PySpark classifiers expect all features in a single vector column and the label in a separate numeric column. Furthermore, as in sklearn, we want to prepare our categorical and numerical features properly using common methods like one-hot encoding and scaling.

String Indexing

All categorical values as well as our label column (‘churn’) need to be indexed (using Spark’s StringIndexer), so that the columns no longer contain a string representing the category value but a numeric index instead.
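StringIndexer's default ordering (stringOrderType='frequencyDesc') gives index 0.0 to the most frequent value. Below is a plain-Python sketch of that mapping; breaking ties alphabetically is an assumption of this sketch. One consequence for the label: about 80% of users did not churn, so indexing ‘churn’ maps 0 to 0.0 and 1 to 1.0, and the churned class keeps the label 1.0.

```python
from collections import Counter


def string_index(values):
    """Map each distinct value to a float index ordered by descending
    frequency (most frequent -> 0.0); ties are broken alphabetically here."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    mapping = {v: float(i) for i, v in enumerate(ordered)}
    return [mapping[v] for v in values], mapping


indices, mapping = string_index(["Chrome", "Firefox", "Chrome", "Safari", "Chrome", "Firefox"])
print(mapping)  # {'Chrome': 0.0, 'Firefox': 1.0, 'Safari': 2.0}
```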

One Hot Encoding

As our categorical features are not ordinal, we also want to perform one-hot encoding (using Spark’s OneHotEncoder). This encodes each category index as a binary vector of dummy variables, one per categorical value. Note: this time we don’t include our label column!
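Spark's OneHotEncoder stores its result as a (sparse) vector. By default (dropLast=True) it omits the last category, which is then represented by all zeros. Below is a dense plain-Python sketch of that encoding for a single category index.

```python
def one_hot(index, num_categories, drop_last=True):
    """Encode a category index as a binary list. With drop_last=True (the
    default of Spark's OneHotEncoder) the last category becomes all zeros,
    which avoids perfectly collinear dummy columns."""
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:
        vector[int(index)] = 1.0
    return vector


# Indices as produced by StringIndexer for three browsers
print([one_hot(i, 3) for i in (0.0, 1.0, 2.0)])
# [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
```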

Vectorization

As Spark classifiers expect the features in one single vector column, we now apply the VectorAssembler.
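Conceptually, VectorAssembler concatenates the listed input columns, scalars and vectors alike, into one flat feature vector in the given order. Below is a plain-Python sketch. The column names are hypothetical, and the label column is deliberately left out of the inputs.

```python
def assemble(row, input_cols):
    """Concatenate numeric values and vector (list) values from the given
    columns into one flat feature list, in the order of input_cols."""
    features = []
    for c in input_cols:
        value = row[c]
        if isinstance(value, (list, tuple)):
            features.extend(float(v) for v in value)
        else:
            features.append(float(value))
    return features


# Hypothetical per-user row after one-hot encoding
row = {"gender_vec": [1.0], "browser_vec": [0.0, 1.0], "daily_artists": 12, "churn": 1}
print(assemble(row, ["gender_vec", "browser_vec", "daily_artists"]))
# [1.0, 0.0, 1.0, 12.0]
```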

Scaling

Last but not least, we scale the feature values using Spark’s StandardScaler. This makes sense when the ranges of numeric variables differ strongly. As mentioned above, the “NextSong” event occurs far more often than other page events. The count variables therefore differ a lot in their range, and scaling seems like a good fit.
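With its defaults (withStd=True, withMean=False), StandardScaler divides each feature by its corrected sample standard deviation without centering it. A column with zero variance is mapped to 0.0. Below is a plain-Python sketch of that behaviour on a small hypothetical feature matrix.

```python
from statistics import stdev


def scale_columns(rows):
    """Divide every column by its sample standard deviation (n - 1);
    values are not centred, and zero-variance columns become 0.0."""
    stds = [stdev(column) for column in zip(*rows)]
    return [[v / s if s != 0 else 0.0 for v, s in zip(row, stds)] for row in rows]


# Hypothetical features: a small count, a large count, a constant column
rows = [[1.0, 10.0, 7.0], [3.0, 30.0, 7.0], [5.0, 50.0, 7.0]]
print(scale_columns(rows))
# [[0.5, 0.5, 0.0], [1.5, 1.5, 0.0], [2.5, 2.5, 0.0]]
```

After scaling, both count columns live on the same scale even though their raw ranges differ by a factor of 10.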

Modeling

Splitting Data for Training and Testing

For developing our prediction model, we split the data into two subsets for training (70%) and testing (30%).

Best practice in machine learning suggests also creating a validation dataset, which we won’t do for the following reasons:

The validation dataset is usually used during the model optimization process (when tuning hyperparameters). Only after all optimization iterations is the test dataset used to finally evaluate the model.

As we’re going to use PySpark’s CrossValidator for the model optimization, this is not necessary for now. The CrossValidator only gets the training dataset as input and splits it into several folds for the optimization itself. Thus, we only need one further subset of the data (the test data) to evaluate the model’s performance.

The advantage of not creating a validation set is that the training set can be bigger (we can make a 70/30 split instead of a 60/20/20 split), which might be crucial for the final model performance, as we only work on a quite small dataset for now and are dealing with class imbalance.
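To get a feeling for the data volumes involved, here is a plain-Python sketch of k-fold splitting as the CrossValidator performs it conceptually (numFolds defaults to 3). The numbers assume roughly 157 users in the 70% training split of 225 users. Spark's randomSplit only approximates the requested proportions, so actual sizes vary, and Spark assigns rows to folds randomly rather than by the round-robin used here.

```python
import random


def k_fold_splits(n_rows, k=3, seed=42):
    """Return k (train_indices, validation_indices) pairs; every row is
    used for validation exactly once and for training k - 1 times."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]  # round-robin fold assignment
    splits = []
    for i, validation in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        splits.append((train, validation))
    return splits


n_train_users = 157  # assumption: about 70% of the 225 users
for train, validation in k_fold_splits(n_train_users):
    print(len(train), len(validation))
```

Each model trained during the search therefore sees only about 105 users. With a churn rate of about 20%, roughly 21 of them are churners.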

Defining Performance Metrics

Given the class imbalance in our dataset, we don’t want to look only at the model’s accuracy (share of correct predictions). To figure out whether our model does a good job at identifying the small number of churned users, we can look at the recall (share of churned users that are correctly identified).

On the other hand, if we flag too many users as potential churners, the recall might be high but the precision (share of predicted churners who actually churned) would be low. Therefore, we ultimately look at the F1 score, the harmonic mean of precision and recall.
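Below is a small plain-Python sketch of these metrics for the churn class, using a hypothetical extreme case. The model flags every one of 10 users as a churner, and 4 of them actually churned. The recall is perfect, but the precision, and therefore the F1 score, suffers.

```python
def churn_metrics(labels, predictions, positive=1):
    """Precision, recall and F1 score for the positive (churn) class."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == positive and p == positive)
    fp = sum(1 for y, p in pairs if y != positive and p == positive)
    fn = sum(1 for y, p in pairs if y == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


labels = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]  # 4 of 10 users churned
predictions = [1] * 10                   # every user flagged as a churner
print(churn_metrics(labels, predictions))  # precision 0.4, recall 1.0, F1 about 0.57
```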

As Spark’s BinaryClassificationEvaluator does not provide the F1 score as a metric, we use the MulticlassClassificationEvaluator instead. We should keep in mind that this evaluator computes its metrics as weighted averages over both classes (‘churn’/‘no churn’). For example, ‘f1’ is the class-frequency-weighted mean of the per-class F1 scores, rather than the F1 score of the target class (‘churn’) alone.
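The difference can be substantial under class imbalance. The plain-Python sketch below mimics the weighted F1 (per-class F1 scores weighted by class frequency). It compares this with the F1 of the churn class for a hypothetical model that never predicts churn.

In Spark 3.0 and later, MulticlassClassificationEvaluator can also report class-specific metrics via metricName='fMeasureByLabel' or 'recallByLabel' together with metricLabel=1.0; check the documentation of your Spark version. Also note that a weighted recall is mathematically identical to accuracy.

```python
def f1_for_class(labels, predictions, cls):
    """F1 score of a single class, written as 2TP / (2TP + FP + FN)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == cls and p == cls)
    fp = sum(1 for y, p in pairs if y != cls and p == cls)
    fn = sum(1 for y, p in pairs if y == cls and p != cls)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def weighted_f1(labels, predictions):
    """Per-class F1 scores averaged with class frequencies as weights."""
    n = len(labels)
    return sum(labels.count(c) / n * f1_for_class(labels, predictions, c)
               for c in set(labels))


labels = [1, 1] + [0] * 8   # 2 of 10 users churned
predictions = [0] * 10      # a model that never predicts churn
print(weighted_f1(labels, predictions))      # about 0.71
print(f1_for_class(labels, predictions, 1))  # 0.0
```

A weighted F1 of about 0.71 can therefore coexist with a model that does not find a single churner.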

Comparing Classifiers with Default Hyperparameters

We’re finally ready to train our first machine learning model. But which one should we choose? We don’t have to decide right away. Instead, we can compare several widely used classifiers that are known to be solid choices for the problem at hand:

  • LogisticRegression: A simple linear algorithm that can be used for binary as well as multi-class classification (more on DataCamp).
  • RandomForestClassifier: An ensemble algorithm that trains many decision trees on random samples of the data and random feature subsets, and combines their votes into one well-performing forest (more on DataCamp).
  • GBTClassifier: Also an ensemble of decision trees, but built sequentially, so that each new tree is trained to correct the errors of the trees before it (more on Medium).

After fitting these classifiers with default hyperparameters, we can see that the LogisticRegression model performs best, with a quite good F1 score of 0.84. The F1 score of the RandomForestClassifier (0.74) also looks decent for a first try without any tuning.

========================================
Fitting LogisticRegression
========================================
F1 Score: 0.8377792898504437
Recall: 0.8548387096774194
----------------------------------------
========================================
Fitting RandomForestClassifier
========================================
F1 Score: 0.7406095884873475
Recall: 0.7903225806451613
----------------------------------------
========================================
Fitting GBTClassifier
========================================
F1 Score: 0.6824405196321705
Recall: 0.6612903225806451
----------------------------------------

Before making a final decision, we check whether we can improve the results even further, leaving the weakest performer (GBTClassifier) out of the equation.

Optimization by Hyperparameter Tuning

To decide between the LogisticRegression and the RandomForestClassifier, we’re going to unlock the potential of hyperparameter tuning (hopefully).

To be concrete: we will build a CrossValidator that finds the best hyperparameter settings by training several model versions based on the provided hyperparameter grid.
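Conceptually, the CrossValidator does three things. It expands the grid into all combinations (the Cartesian product that ParamGridBuilder builds). It scores each combination by its average metric over the folds. Finally, it refits the best combination on the whole training set. Below is a plain-Python sketch of the selection step; the scoring function is a hypothetical stand-in for fitting and evaluating a model.

```python
from itertools import product


def build_param_grid(grid):
    """Expand {name: [values]} into one dict per combination (Cartesian product)."""
    names = list(grid)
    return [dict(zip(names, values)) for values in product(*grid.values())]


def select_best(param_maps, folds, fit_and_score):
    """Return (params, mean score) of the combination with the highest mean
    validation score across folds; larger scores are assumed to be better."""
    best_params, best_score = None, float("-inf")
    for params in param_maps:
        scores = [fit_and_score(params, train, validation) for train, validation in folds]
        mean_score = sum(scores) / len(scores)
        if mean_score > best_score:
            best_params, best_score = params, mean_score
    return best_params, best_score


grid = build_param_grid({"regParam": [1.0, 2.0], "maxIter": [1, 5]})
print(len(grid))  # 4


def toy_score(params, train, validation):
    # Hypothetical stand-in for "fit on train, compute F1 on validation"
    return params["maxIter"] / 10 - params["regParam"]


folds = [([0, 1], [2]), ([0, 2], [1]), ([1, 2], [0])]
print(select_best(grid, folds, toy_score))
```

For the LogisticRegression grid used in this post (two regParam values times two maxIter values), this amounts to 4 combinations × 3 folds = 12 model fits plus one final refit. The RandomForestClassifier grid, with three parameters of two values each, yields 8 combinations and 24 fits.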

Hyperparameter Tuning of LogisticRegression

Let’s start with our top performing classifier and try to optimize its performance with a set of different hyperparameters.

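from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Use grid search to find the best hyperparameters
lr = LogisticRegression()
# Set different hyperparameter values
pg = ParamGridBuilder() \
    .addGrid(lr.regParam, [1.0, 2.0]) \
    .addGrid(lr.maxIter, [1, 5]) \
    .build()
# Build cross validation classifier (numFolds defaults to 3)
lr_cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=pg,
    evaluator=MulticlassClassificationEvaluator(metricName='f1')
)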

After fitting the CrossValidator, we use its best model to predict test_data again. Annoyingly, its performance has gotten worse. Much worse.

F1 Score: 0.6537437111571471 
Recall: 0.7580645161290323

We’re coming back to this in a second. Let’s first look at the RandomForestClassifier results.

Hyperparameter Tuning of RandomForestClassifier

Using the same approach as before, we now try to improve the RandomForestClassifier’s performance.

from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Use grid search to find the best hyperparameters
rf = RandomForestClassifier()
# Set different hyperparameter values
rf_pg = ParamGridBuilder() \
    .addGrid(rf.minInfoGain, [0.0, 1.0]) \
    .addGrid(rf.numTrees, [10, 20]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()
# Build cross validation classifier
rf_cv = CrossValidator(
    estimator=rf,
    estimatorParamMaps=rf_pg,
    evaluator=MulticlassClassificationEvaluator(metricName='f1')
)

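Unfortunately, the results are not what we hoped for either. The F1 score has not decreased, but it has not improved either; it stayed exactly the same. A likely explanation is that the grid contains Spark’s default values (minInfoGain=0, numTrees=20, maxDepth=5), so the cross validation may simply have selected the default configuration again.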

F1 Score: 0.7406095884873475
Recall: 0.7903225806451613

Hyperparameter Tuning Assessment

After applying cross validation to the two best performing classifiers from round one, the result in terms of which model to choose is inconclusive.

LogisticRegression worked better with default settings but produced worse results after testing different hyperparameter combinations. This leaves the RandomForestClassifier, with its unchanged performance, as the better-performing model.

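It can be assumed that this development is mainly caused by the small size of the data. When using cross validation, the training data is split into several folds, which are used to train and evaluate models with different hyperparameter combinations. As the engineered dataset is already pretty small, with only 225 rows (one per user), these subsets are likely too small to produce a solidly performing classification model. In addition, the LogisticRegression grid does not contain Spark’s default values (regParam=0.0, maxIter=100). Every tested combination applies strong regularization and stops after at most 5 iterations, so the cross validation could not select the configuration that performed best in round one.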

Thus, it is recommended to repeat the hyperparameter tuning on a larger dataset in the cloud before finally selecting a model for predicting user churn.

Conclusion

To predict user churn based on user tracking data, we can easily start working on a smaller subset of the data using Spark in our local environment.

However, especially after aggregating it into an even smaller feature set, the data might not be big enough to train and evaluate a final solution for the business.

We should now ship our code to a Spark cluster in the cloud and run it on the original dataset. There we can also consider (down)sampling the data to address the class imbalance already during the modeling step (not only within the evaluation).
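Downsampling could look like the plain-Python sketch below: randomly drop rows of the majority class (non-churned users) until both classes are equally large. It should be applied to the training split only, so the test split keeps the real class ratio. In PySpark, `DataFrame.sampleBy` offers approximate stratified sampling by class. The function and its parameters are illustrative assumptions, not the author's implementation.

```python
import random


def downsample(rows, label_key="churn", seed=42):
    """Randomly drop rows of larger classes until every class is as small
    as the smallest one."""
    by_class = {}
    for row in rows:
        by_class.setdefault(row[label_key], []).append(row)
    n_min = min(len(class_rows) for class_rows in by_class.values())
    rng = random.Random(seed)
    sampled = []
    for class_rows in by_class.values():
        sampled.extend(rng.sample(class_rows, n_min))
    rng.shuffle(sampled)
    return sampled


# Hypothetical per-user rows: 2 of 10 users churned
users = [{"userId": str(i), "churn": int(i < 2)} for i in range(10)]
balanced = downsample(users)
print(sorted(u["churn"] for u in balanced))  # [0, 0, 1, 1]
```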

Once the model is finalized, we can verify its business impact by running an A/B test (e.g. giving incentives to 50% of the potential churners to see whether the churn rate within this group stays lower). Last but not least, such a verification is also relevant to justify the effort and cost of constantly maintaining and improving the model.
