
Sparkify churn prediction


Nowadays it’s not enough just to be good to survive on the market. Whatever area you work in and whatever service you provide, there is always someone else competing with you for customers.

Not only do you have to fight for new customers, you also have to keep the ones you already have. Being able to predict whether a customer intends to stop using your service can be a great help.

This is where machine learning algorithms come into play.

This blog post describes how I solved this type of problem as part of the Udacity Data Science Nanodegree project: predicting churn for a music-streaming service.

Project definition

Imagine you run a music streaming service that lets users perform certain types of actions: listening to music, adding songs to their playlists, giving a song a thumbs up and so on. The service can be used either for free or as a paid version.

A user can stop using the service completely or downgrade from the paid version to the free one. Both events are treated as “churn”. If we can tell which users are likely to do either, we can offer them a discount or a free trial of the paid version to keep them interested in the product.


Although the problem is addressed here using a small subset of the data, overall this is a big data problem. The dataset is loaded into Spark and analysed with Spark DataFrames through the Python API, and the machine learning models are built with the Spark ML library.


From a machine learning point of view, churn prediction is a binary classification task. Given historical data on user activity, we can mark each user with a binary label and build a classification model, so that each user has a true label and a prediction given by the model.


The resulting model is then evaluated on the test part of the dataset. It is reasonable to suppose that fewer users belong to the “churn” group than to the “non-churn” one, so accuracy alone is not enough. Recall is especially important here, as it shows the share of potentially churning users we were able to identify.

However, low precision would make us try to “hold” users who are not going to stop using the service, which is bad for the budget. So the F1 score, which combines precision and recall, is the metric to focus on during evaluation.
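The trade-off above can be made concrete: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 = 2 * precision * recall / (precision + recall), the harmonic mean of the two. Below is a minimal plain-Python sketch (the function name `classification_metrics` is hypothetical, not a Spark ML API), fed with the Random Forest test-set confusion matrix reported later in this post.

```python
def classification_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Compute precision, recall, F1 and accuracy from confusion-matrix counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1 is the harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


# Random Forest on the test part: 10 true positives, 1 false positive,
# 6 false negatives, 26 true negatives
metrics = classification_metrics(tp=10, fp=1, fn=6, tn=26)
print(metrics)  # precision about 0.909, recall 0.625, F1 about 0.741, accuracy about 0.837
```

Because F1 is a harmonic mean, it stays low unless both precision and recall are reasonably high, which is exactly the balance between "missing churners" and "wasting budget" described above.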

Data overview

Service usage data is stored as logs: each row corresponds to an action a user performs, along with some user- or action-related details.

The dataset has 286,500 rows and 18 columns. The data schema looks like this:

Data in this form is not suitable for a classification algorithm, as there are many rows for every user. After an initial data analysis, we will need to perform feature engineering and data transformation.

A quick look at the schema suggests that most of the features can, in theory, be null, except userId and sessionId. Checking the data shows that no rows have nulls in those two columns; however, some rows have an empty userId. Since we are going to analyse user behaviour, there is no point in keeping those rows.

Some of the features don’t seem very informative. For instance, the “auth” column has only 2 possible values:

The “Cancelled” value corresponds exactly to the “page” value “Cancellation Confirmation”. The “page” column, on the other hand, pretty much describes the type of action the user takes.

This is definitely going to be one of the biggest sources for feature engineering.

“Gender” is almost the only feature characterizing the user themselves, so it is definitely worth using. “firstName” and “location” also describe the user, but they can hardly affect the user’s behaviour.

“sessionId” and “itemInSession” can be used to track the length of user sessions, which can be useful.


The crucial part of solving this problem is working with the data: a dataset holding per-user information has to be generated from the raw, log-like data.

Starting with an analysis of aggregated information for the two user groups, I will then generate multiple features based on that analysis.

Churn definition

Now that we have some general knowledge about the dataset, it’s time to turn it into something more valuable for model building. Let me start by defining churn users: those who cancel the service or downgrade to the free tier.

First I mark all the events treated as churn, and then all the events of users who had a churn event.

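from pyspark.sql.functions import udf, col, max as smax
from pyspark.sql.types import IntegerType
from pyspark.sql.window import Window
# 1 for events treated as churn (cancellation or downgrade), otherwise 0
churn_event = udf(lambda x: 1 if x in ('Cancellation Confirmation', 'Submit Downgrade') else 0, IntegerType())
sparkify_data = sparkify_data.withColumn("churn", churn_event("page"))
windowval = Window.partitionBy("userId")  # window over all rows of one user
sparkify_data = sparkify_data.withColumn("churn_user", smax(col("churn")).over(windowval))  # 1 if the user ever churned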

Overall we have 92 “churn” and 133 “non-churn” users. The dataset is reasonably close to balanced.

Groups analysis

Now I am going to analyse the behaviour of users in the two groups and, hopefully, find something I can use as features.

I start by simply counting how many times every user visited each page. To make this analysis more descriptive, I convert the result to a pandas DataFrame and build some graphs.

All NaNs have been replaced with 0. I’ve also included a column identifying churn.
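One way to get these per-user page counts in Spark is `groupBy("userId").pivot("page").count()` followed by `fillna(0)`; the post does not show the exact code, so that is only an assumption. The plain-Python sketch below illustrates the same counting logic on a few hypothetical (userId, page) events, with pages a user never visited filled with 0:

```python
from collections import Counter, defaultdict

# Hypothetical sample of (userId, page) log events
events = [
    ("10", "NextSong"), ("10", "NextSong"), ("10", "Thumbs Up"),
    ("20", "NextSong"), ("20", "Add Friend"),
]


def page_counts(events, pages):
    """Count visits to each page per user; pages a user never visited get 0."""
    counts = defaultdict(Counter)
    for user_id, page in events:
        counts[user_id][page] += 1
    # Counter returns 0 for missing keys, which mirrors replacing NaNs with 0
    return {user: {page: c[page] for page in pages} for user, c in counts.items()}


pages = ["NextSong", "Thumbs Up", "Add Friend"]
table = page_counts(events, pages)
print(table["20"])  # {'NextSong': 1, 'Thumbs Up': 0, 'Add Friend': 1}
```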

A fairly significant difference between the groups can be observed for most pages. I will also build a scatter plot matrix for some of the produced features, using a different colour for each user group.

Values close to 0 for some actions make it a bit hard to analyse. Nevertheless, this data is definitely going to be used.

We’ve got the length of every song played, so why not use it? Maybe users who listen more tend to churn less?

To be honest, not quite what I was expecting: on average, “churn” users tend to listen more. Still, the difference seems to be pretty impressive.

Next, we can use sessionIds. For each user, I want to get the average and maximum session length.

Here is the comparison of maximum session length distributions for the two groups:

They definitely differ: not by much, but noticeably.
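A minimal sketch of the session statistics, assuming that session length means the number of logged events in a session (the exact definition used here isn't shown). Grouping by the (userId, sessionId) pair rather than by sessionId alone avoids mixing up sessions of different users that might share an id. Plain Python on hypothetical data:

```python
from collections import Counter, defaultdict

# Hypothetical (userId, sessionId) pairs, one per logged event
events = [("10", 1), ("10", 1), ("10", 1), ("10", 2), ("20", 5), ("20", 5)]


def session_stats(events):
    """Return {userId: (average, maximum) session length}; length = events in a session."""
    per_session = Counter(events)  # (userId, sessionId) -> number of events
    lengths = defaultdict(list)
    for (user_id, _session_id), n in per_session.items():
        lengths[user_id].append(n)
    return {user: (sum(ls) / len(ls), max(ls)) for user, ls in lengths.items()}


print(session_stats(events))  # {'10': (2.0, 3), '20': (2.0, 2)}
```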

The dataset has timestamps, another important source of information. The data only covers October and November, plus a couple of days in December. For example, here are the distributions of the average number of actions taken per day:

Again, surprisingly, “churn” users tend to be more active. The same can be computed for specific types of actions, like listening to a song or adding a friend.
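A sketch of the average-actions-per-day feature, assuming `ts` is a Unix timestamp in milliseconds and that the average is taken over the days on which the user was active (dividing by calendar days would give a different number). The function name and sample events are hypothetical:

```python
from collections import defaultdict
from datetime import datetime, timezone


def avg_actions_per_day(events):
    """events: iterable of (userId, ts in ms); returns {userId: mean actions per active day}."""
    per_day = defaultdict(lambda: defaultdict(int))
    for user_id, ts_ms in events:
        day = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc).date()
        per_day[user_id][day] += 1
    # average over the days on which the user performed at least one action
    return {user: sum(days.values()) / len(days) for user, days in per_day.items()}


# Hypothetical events: three actions on 1 Oct 2018 and one on 2 Oct 2018 (UTC)
oct1 = int(datetime(2018, 10, 1, 12, tzinfo=timezone.utc).timestamp() * 1000)
oct2 = int(datetime(2018, 10, 2, 12, tzinfo=timezone.utc).timestamp() * 1000)
events = [("10", oct1), ("10", oct1 + 1000), ("10", oct1 + 2000), ("10", oct2)]
print(avg_actions_per_day(events))  # {'10': 2.0}
```

Filtering the events to a single page type first (for example only song plays) gives the per-action variants such as average songs listened daily.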

Another feature that can be calculated from the timestamp is the number of active days per month. I will compute it separately for October and November.

The difference is definitely meaningful.
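Active days per month can be sketched in the same way: convert each timestamp to a calendar date and count the distinct dates that fall in the given month. Again plain Python, assuming millisecond timestamps and UTC dates; the helper names are hypothetical:

```python
from datetime import datetime, timezone


def active_days_in_month(ts_list_ms, year, month):
    """Number of distinct calendar days (UTC) in the given month with at least one action."""
    days = set()
    for ts_ms in ts_list_ms:
        d = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc).date()
        if d.year == year and d.month == month:
            days.add(d)
    return len(days)


def ms(y, m, d, h=12):
    """Hypothetical helper: millisecond timestamp for a UTC date and hour."""
    return int(datetime(y, m, d, h, tzinfo=timezone.utc).timestamp() * 1000)


# Hypothetical user activity: two actions on 3 Oct, one on 20 Oct, one on 5 Nov
ts = [ms(2018, 10, 3), ms(2018, 10, 3, 18), ms(2018, 10, 20), ms(2018, 11, 5)]
print(active_days_in_month(ts, 2018, 10), active_days_in_month(ts, 2018, 11))  # 2 1
```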

The last thing I want to investigate is the device the user runs the application on. The “userAgent” feature contains quite a lot of information, including browser and OS. A closer look at all possible values of the OS-related part makes it possible to reduce them to a short list of devices.
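The device reduction can be sketched with a keyword lookup on the platform part of the userAgent string (the text inside the first pair of parentheses). The exact short list of devices isn't given in the post, so the categories below are hypothetical:

```python
import re

# Hypothetical short list of devices; checked in order, so iPhone/iPad
# are matched before the more generic keywords
DEVICE_KEYWORDS = [
    ("iPhone", "iPhone"),
    ("iPad", "iPad"),
    ("Windows", "Windows"),
    ("Macintosh", "Mac"),
    ("Linux", "Linux"),
]


def device_from_user_agent(user_agent: str) -> str:
    """Return a short device label based on the platform part of a userAgent string."""
    match = re.search(r"\(([^)]*)\)", user_agent)  # e.g. "Macintosh; Intel Mac OS X 10_9_4"
    platform = match.group(1) if match else ""
    for keyword, label in DEVICE_KEYWORDS:
        if keyword in platform:
            return label
    return "Other"


ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko)"
print(device_from_user_agent(ua))  # Mac
```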

Comparing the distributions shows the following:

The distributions do not look significantly different for the two groups of users.

Feature engineering

To sum up the preceding analysis, here is the list of features I am going to stick to:

  • Gender
  • Total counts for 12 types of pages (stored in pages list)
  • Total length of listened songs
  • Average session length
  • Maximum session length
  • Average actions performed daily
  • Average songs listened daily
  • Average adds to playlist daily
  • Active days in October
  • Active days in November

I start by creating a new DataFrame with the most straightforward columns: userId itself, gender and the target variable churn_user. Then I calculate the aggregated features reviewed in the previous sections and join them one by one.

As some of the values may be undefined after the joins (for instance, when a user has never added a song to their playlist), all of them need to be replaced with 0.

A row in the resulting dataset looks like this:
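In Spark, each join would typically be a left join on userId followed by `fillna(0)`. A plain-Python sketch of that join-and-fill logic, with hypothetical feature names and values:

```python
def join_features(base, aggregates, default=0):
    """Left-join aggregated feature columns onto the per-user base table.

    base: {userId: {column: value}}; aggregates: {column: {userId: value}}.
    Users missing from an aggregate (e.g. never added a song) get `default`.
    """
    return {
        user_id: {**cols, **{name: values.get(user_id, default) for name, values in aggregates.items()}}
        for user_id, cols in base.items()
    }


base = {"10": {"gender": "F", "churn_user": 0}, "20": {"gender": "M", "churn_user": 1}}
aggregates = {
    "add_to_playlist": {"10": 7},  # user "20" never added a song
    "max_session_length": {"10": 112, "20": 40},
}
print(join_features(base, aggregates)["20"])
# {'gender': 'M', 'churn_user': 1, 'add_to_playlist': 0, 'max_session_length': 40}
```

Keeping the base table on the left side of every join guarantees that each user appears exactly once, even when they never triggered a particular action.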

Row(userId='100010', gender='F', churn_user=0, add_friend=4, add_to_playlist=7, error=0, help=2, home=11, next_song=275, roll_advert=52, save_settings=0, settings=0, submit_upgrade=0, thumbs_down=5, thumbs_up=17, total_songs_length=66940.89735000003, avg_session_length=54.42857142857143, max_session_length=112, avg_actions_daily=54.42857142857143, avg_songs_daily=39.285714285714285, avg_adds_daily=1.4, active_days_october=4, active_days_november=3)

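Next, all the features used for modelling need to be packed into a single vector column (gender is first encoded as the numeric gender_int column).

Some classification algorithms, for example logistic regression, tend to perform much better when all the features are on the same scale. I use StandardScaler here. With its default settings (withMean=False, withStd=True) it divides every feature by its standard deviation without centering it, so the features end up with unit variance but not with mean 0; that is why zero values stay 0 in the scaled vector shown below.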

from pyspark.ml.feature import StandardScaler
scaler = StandardScaler().setInputCol("features").setOutputCol("scaled_features")
scaler_model = scaler.fit(sparkify_df)
sparkify_df = scaler_model.transform(sparkify_df)
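To make the difference between Spark's default scaling and full standardization concrete, here is a plain-Python sketch on a hypothetical column. It uses the sample standard deviation, which, to my understanding, is what Spark's StandardScaler computes; the function names are hypothetical:

```python
from statistics import mean, stdev


def scale_std_only(values):
    """Divide by the sample standard deviation without centering (StandardScaler defaults)."""
    s = stdev(values)
    return [v / s for v in values]


def standardize(values):
    """Subtract the mean, then divide by the sample std (withMean=True, withStd=True)."""
    m, s = mean(values), stdev(values)
    return [(v - m) / s for v in values]


column = [0.0, 2.0, 4.0]  # hypothetical feature column
print(scale_std_only(column))  # [0.0, 1.0, 2.0]
print(standardize(column))     # [-1.0, 0.0, 1.0]
```

The std-only version keeps zeros at zero (and keeps sparse vectors sparse), while full standardization shifts every value by the mean.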

Finally, I rename the “churn_user” column to “label” and cast it to float.

The resulting dataset, ready for modelling, looks like this:

Row(userId='100010', gender='F', label=0.0, add_friend=4, add_to_playlist=7, error=0, help=2, home=11, next_song=275, roll_advert=52, save_settings=0, settings=0, submit_upgrade=0, thumbs_down=5, thumbs_up=17, total_songs_length=66940.89735000003, avg_session_length=54.42857142857143, max_session_length=112, avg_actions_daily=54.42857142857143, avg_songs_daily=39.285714285714285, avg_adds_daily=1.4, active_days_october=4, active_days_november=3, gender_int=1.0, features=DenseVector([1.0, 4.0, 7.0, 0.0, 2.0, 11.0, 275.0, 52.0, 0.0, 0.0, 0.0, 5.0, 17.0, 66940.8974, 54.4286, 112.0, 54.4286, 39.2857, 1.4, 4.0, 3.0]), scaled_features=DenseVector([2.0013, 0.1943, 0.214, 0.0, 0.2761, 0.2326, 0.2489, 2.413, 0.0, 0.0, 0.0, 0.3823, 0.2596, 0.2431, 1.1092, 0.5558, 1.4756, 1.2249, 1.09, 0.6639, 0.4996]))


Among the classification algorithms supported by Spark, I am going to try three of the most popular: Logistic Regression, Random Forest and Gradient Boosting.

I start by splitting the data into 3 parts: train, test and validation. The test part is used to select the best model, and the validation part shows the final performance of the selected one.

Given the moderate class imbalance, the F1 score seems to be the best metric to rely on. Still, I will also calculate accuracy, as the imbalance is not severe.

For Logistic Regression parameter tuning I use the following ranges:
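The post does not state the split proportions, so the 60/20/20 weights and the seed below are hypothetical. In Spark this is usually done with `randomSplit`, which assigns each row independently and therefore yields only approximately the requested sizes; the plain-Python sketch below instead shuffles user ids with a fixed seed and cuts them exactly:

```python
import random


def three_way_split(ids, weights=(0.6, 0.2, 0.2), seed=42):
    """Shuffle ids reproducibly and cut them into train/test/validation parts."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * weights[0])
    n_test = round(len(ids) * weights[1])
    # whatever remains (roughly weights[2]) becomes the validation part
    return ids[:n_train], ids[n_train:n_train + n_test], ids[n_train + n_test:]


train, test, validation = three_way_split(range(225))  # 225 users, one row each
print(len(train), len(test), len(validation))  # 135 45 45
```

Splitting on users (one row per user after feature engineering) ensures that no user's behaviour leaks from the training part into the evaluation parts.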

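from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder
lr = LogisticRegression(featuresCol="scaled_features", labelCol="label")
paramGrid = ParamGridBuilder() \
    .addGrid(lr.maxIter, [100, 50, 10]) \
    .addGrid(lr.regParam, [0.0, 0.1, 0.2]) \
    .addGrid(lr.elasticNetParam, [0.8, 0.9]) \
    .build()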
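ParamGridBuilder expands the listed values into their Cartesian product, so this grid describes 3 × 3 × 2 = 18 candidate models. A plain-Python illustration with `itertools.product` (not Spark code):

```python
from itertools import product

# Same value lists as in the Logistic Regression grid above
grid = {
    "maxIter": [100, 50, 10],
    "regParam": [0.0, 0.1, 0.2],
    "elasticNetParam": [0.8, 0.9],
}

# Every combination of one value per parameter, similar to ParamGridBuilder().build()
combinations = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(combinations))  # 18
print(combinations[0])    # {'maxIter': 100, 'regParam': 0.0, 'elasticNetParam': 0.8}
```

Since elasticNetParam only mixes the L1 and L2 penalties, it has no effect when regParam is 0.0; the 6 combinations with regParam = 0.0 therefore produce only 3 distinct models.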

The best model found shows the following performance:

Precision: 0.7142857142857143, Recall: 0.625, F1-score: 0.6666666666666666, Accuracy: 0.7674418604651163

The best model uses only 10 iterations and no regularization, which is explainable given how small the dataset is.

For the Random Forest model I tune the following parameters:

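from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder
rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=10, maxDepth=7)
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [50, 100]) \
    .addGrid(rf.maxDepth, [5, 7]) \
    .build()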

The best model shows the following performance:

Precision: 0.9090909090909091, Recall: 0.625, F1-score: 0.7407407407407406, Accuracy: 0.8372093023255814

Much better here: the F1 score is 0.74. The final model has 100 trees with depth 5. Recall stays at 0.625, the same as for Logistic Regression, so the gain comes from higher precision. Given the nature of our problem, recall remains especially important: it shows how well we are able to find potentially churning users.

Next is Gradient Boosting. It takes much longer to train, so here I tune the parameters manually rather than with a grid search.

Tuned model:

from pyspark.ml.classification import GBTClassifier
classifier_gb = GBTClassifier(featuresCol="features", labelCol="label", maxIter=50, maxDepth=5, subsamplingRate=1.0)

Precision: 0.5882352941176471, Recall: 0.625, F1-score: 0.6060606060606061, Accuracy: 0.6976744186046512

Random Forest definitely beats it.

Going back to the slight class imbalance, let me try introducing class weights: rows belonging to class 1 get a higher weight, as that class is smaller. I train Logistic Regression on the weighted dataset.

from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.ml.classification import LogisticRegression
# Share of positive (churn) rows; each row is weighted by the share of the other class
pos_w = sparkify_df.filter(sparkify_df.label == 1.0).count() / sparkify_df.count()
get_weights = udf(lambda x: pos_w if x == 0 else (1 - pos_w), DoubleType())
sparkify_df_weighted = sparkify_df.withColumn("class_weight", get_weights("label"))
classifier_lr = LogisticRegression(featuresCol="scaled_features", labelCol="label", weightCol="class_weight", maxIter=100, regParam=0.0, elasticNetParam=0.8)

Precision: 0.5714285714285714, Recall: 0.75, F1-score: 0.6486486486486486, Accuracy: 0.6976744186046512
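The weighting rule in the snippet above gives every row the share of the other class, so the minority class gets the larger weight. A plain-Python sketch of the same rule (the function name is hypothetical), applied to a toy label list:

```python
def class_weights(labels):
    """Weight each row by the share of the other class, as in the Spark snippet above."""
    pos_w = sum(1 for y in labels if y == 1.0) / len(labels)
    return [pos_w if y == 0.0 else 1 - pos_w for y in labels]


labels = [1.0, 0.0, 0.0, 0.0]  # hypothetical labels: 1 churn user, 3 non-churn
weights = class_weights(labels)
print(weights)  # [0.75, 0.25, 0.25, 0.25]
```

With n_pos positive and n_neg negative rows, each class ends up with the same total weight, n_pos * n_neg / n, which is what makes this scheme balance the two classes.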

Summary of metrics for all 3 models:

Here are the confusion matrices for all the models. Logistic Regression:

              label pos   label neg
pred pos             10           4
pred neg              6          23

Random Forest:

              label pos   label neg
pred pos             10           1
pred neg              6          26

Gradient Boosting:

              label pos   label neg
pred pos             10           7
pred neg              6          20


Based on the test results, the winning model is Random Forest with 100 trees of maximum depth 5.

Its final performance is evaluated on the validation dataset:

Precision: 0.9230769230769231, Recall: 0.6666666666666666, F1-score: 0.7741935483870968, Accuracy: 0.825

Confusion matrix:

              label pos   label neg
pred pos             12           1
pred neg              6          21

Further steps

Further investigation should focus mainly on building more useful features and using statistical analysis to determine their significance.

Applying the same feature engineering and model building to the full dataset should by itself improve performance, simply because there is more data to train on. Regularization will most likely need to be introduced there, while on the small subset it did not help.

Also, with a larger dataset the class imbalance is expected to grow as well, which needs to be addressed with class weights.


During the work on this project, the following main stages were performed:

  • Initial data analysis
  • Detailed analysis of user behaviour in the two groups
  • Feature engineering based on aggregated user data, including total and average counts of certain actions performed, and date-based information
  • Training and evaluation of three types of machine learning algorithms
  • Selection of the best model, Random Forest, as the final model

It’s been a great experience working with such an interesting type of dataset, which required a lot of feature engineering, and using Spark and its ML library.

Can I say the results are super impressive? No. There is definitely more room for investigation and optimization, but isn't that great? The modelling dataset only consists of a bit more than 200 rows (one per user); in real life this would be thousands, or even millions, of rows. Even with such a small dataset, however, it is possible to build an adequately performing model.