This project focuses on predicting customer churn using the Telco Customer Churn dataset. Comprehensive exploratory data analysis (EDA) was conducted, leveraging various statistical and visualization techniques. These included histograms to assess feature distributions, heatmaps to understand correlations, boxplots to detect outliers, and pivot tables to summarize churn patterns across categorical variables. Key customer behaviors and subscription patterns were explored to uncover significant trends and drivers of churn.
The project further emphasizes the implementation of federated learning as a privacy-preserving alternative to centralized models. Federated learning allows models to be trained across decentralized datasets without transferring sensitive customer information to a central repository. This not only ensures data privacy but also demonstrates comparable performance to traditional centralized approaches. By combining local computation with global model aggregation, federated learning highlights its potential for use cases where data-sharing restrictions or regulations apply, such as telecom, healthcare, and finance.
Key comparisons between centralized and federated setups provide valuable insights into balancing performance and privacy, establishing federated learning as a cutting-edge solution for churn prediction and beyond.
- Project Overview
- Project Objectives
- Methodology
- Findings and Recommendations
- Federated Learning Setup
- Methodology for Federated Learning
- Results and Key Findings
- Next Steps
- Repository Structure
- How to Run the Project
- Conclusion
This project explores customer churn prediction through various machine learning models, including Logistic Regression, Random Forest, Support Vector Machine (SVM), K-Nearest Neighbors (KNN), and Decision Tree. These models were evaluated and optimized to achieve the best predictive performance for churn prediction. The project also highlights the benefits of federated learning, which preserves data privacy while achieving effective performance.
- Perform exploratory data analysis (EDA) to understand key drivers of churn.
- Build and optimize the models using centralized data.
- Implement a federated learning framework with two clients and a central server.
- Compare federated learning outcomes with centralized model performance.
- Handled missing values and irrelevant features (e.g., dropped `customerID`).
- Encoded categorical features and scaled numerical features for model training.
- Used descriptive statistics and visualizations (e.g., histograms, boxplots, heatmaps) to explore relationships between features and churn.
- Examined correlations between numerical variables and churn.
- Established a baseline model.
- Trained and tuned various models, including Logistic Regression, Random Forest, SVM, and others.
- Evaluated models using metrics like accuracy, precision, recall, F1-score, and ROC-AUC.
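The preprocessing, training, and evaluation steps above can be sketched as follows. This is a minimal illustration, not the project's actual code: the tiny synthetic DataFrame stands in for the Telco dataset, and the feature names are assumptions.

```python
# Hypothetical sketch of the centralized pipeline; the synthetic data
# and column names are illustrative stand-ins for the Telco dataset.
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score

df = pd.DataFrame({
    "tenure": [1, 34, 2, 45, 8, 22, 10, 28],
    "MonthlyCharges": [29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.75],
    "Churn": [1, 0, 1, 0, 1, 0, 1, 0],
})
X, y = df[["tenure", "MonthlyCharges"]], df["Churn"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Scale numerical features (fit only on the training split)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Baseline Logistic Regression, evaluated with the metrics named above
model = LogisticRegression().fit(X_train, y_train)
proba = model.predict_proba(X_test)[:, 1]
print("accuracy:", accuracy_score(y_test, model.predict(X_test)))
print("roc_auc:", roc_auc_score(y_test, proba))
```

The same pattern extends to Random Forest, SVM, KNN, and Decision Tree by swapping the estimator.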
- Key Drivers of Churn:
- Customers with month-to-month contracts and electronic check payments are more likely to churn.
- Short tenure and higher monthly charges are associated with churn.
- Model Performance:
- Logistic Regression had the highest ROC-AUC (0.84), showcasing strong class discrimination.
- Random Forest achieved the best accuracy (0.79), providing balanced predictions.
- Customer Retention Strategies:
- Target customers on month-to-month contracts with incentives to switch to long-term contracts.
- Identify customers using electronic check payments and encourage them to switch to automated payment methods.
- Focus on High-Churn Risk Groups:
- Implement loyalty programs or discounts for customers with higher monthly charges.
- Offer personalized retention campaigns for customers with shorter tenures.
- Data Partitioning:
- The dataset was divided into two subsets, each representing a unique client.
- Server Configuration:
- A central server was set up to coordinate learning and aggregate model updates.
- Federated Framework:
- Each client trains a local Logistic Regression model and sends its coefficients to the server.
- The server aggregates coefficients through weighted averaging and redistributes the updated model to clients.
- Clients train models locally on their respective datasets and share coefficients with the server.
- The server computes the weighted average of the coefficients and updates the global model.
- This process is repeated for 10 rounds, with metrics monitored at each step.
- The federated model's performance was compared against the centralized baseline.
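The server-side aggregation step described above can be sketched as a weighted average of client coefficients, in the spirit of FedAvg. The function and variable names here are assumptions for illustration, not the project's exact implementation.

```python
import numpy as np

def aggregate_coefficients(client_coefs, client_sizes):
    """Weighted average of client model coefficients (FedAvg-style).

    client_coefs: list of 1-D coefficient arrays, one per client.
    client_sizes: training-sample count at each client, used as weights.
    """
    weights = np.asarray(client_sizes, dtype=float)
    weights /= weights.sum()
    stacked = np.stack(client_coefs)  # shape: (n_clients, n_features)
    return (weights[:, None] * stacked).sum(axis=0)

# Two hypothetical clients with unequal data sizes
coef_a = np.array([0.2, -1.0, 0.5])
coef_b = np.array([0.4, -0.8, 0.1])
global_coef = aggregate_coefficients([coef_a, coef_b], client_sizes=[3000, 1000])
print(global_coef)  # pulled toward client A, which holds more data
```

Weighting by sample count lets clients with more data contribute proportionally more to the global model, which matters when the partitions are unequal.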
- The global model generated by the federated learning process is saved as `federated_learning/server/models/global_model.pkl`.
- A dedicated notebook, `federated_learning/server/notebook/testing_global_model.ipynb`, was created to validate the global model using test data.
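Validating the saved global model can be sketched like this. To keep the example self-contained, it pickles a freshly fitted model in place of the real `global_model.pkl`; the synthetic data and file handling are assumptions.

```python
# Hypothetical sketch of validating a pickled global model;
# synthetic data stands in for the real validation set.
import pickle
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_score, recall_score, f1_score

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Stand-in for the artifact written by the federated server
with open("global_model.pkl", "wb") as f:
    pickle.dump(LogisticRegression().fit(X, y), f)

# Load the serialized global model and score it on held-out data
with open("global_model.pkl", "rb") as f:
    global_model = pickle.load(f)

pred = global_model.predict(X)
print("precision:", precision_score(y, pred))
print("recall:", recall_score(y, pred))
print("f1:", f1_score(y, pred))
```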
- The federated model improved over 10 rounds, achieving stable metrics by round 7.
- The federated model's performance was slightly below the centralized model:

  | Metric    | Federated | Centralized |
  |-----------|-----------|-------------|
  | Precision | 0.82      | 0.84        |
  | Recall    | 0.79      | 0.81        |
  | F1-Score  | 0.80      | 0.82        |
  | ROC-AUC   | 0.84      | 0.86        |
- Federated learning preserved privacy while maintaining strong performance.
- Data partitioning introduced variability, slightly impacting metrics.
- Weighted aggregation effectively balanced contributions from clients.
- Hyperparameter Tuning:
- Further optimize the best-performing models to improve predictive performance.
- Feature Engineering:
- Explore interaction terms and non-linear transformations for better model accuracy.
- Decentralized Hyperparameter Optimization:
- Implement decentralized hyperparameter tuning across clients to enhance model performance while maintaining privacy.
- Advanced Aggregation Techniques:
- Explore more sophisticated methods to better handle imbalanced or non-iid data.
- Scaling Federated Learning:
- Expand the setup to include additional clients and study its impact on performance and convergence.
- Incorporating Deep Neural Networks (DNNs):
- Train and evaluate Deep Neural Network models to explore their potential for churn prediction in a federated learning setup.
- Use architectures such as fully connected feedforward networks with multiple hidden layers to capture complex relationships in the data.
- Implement techniques like dropout and batch normalization to improve model generalization and convergence in federated settings.
- Leverage frameworks like TensorFlow Federated (TFF) or PyTorch with federated extensions to support the training of deep learning models.
- Evaluate DNN performance on centralized and federated setups to compare results with traditional machine learning models.
- data/: Contains the Telco Customer Churn dataset.
- models/: Stores trained models and result files (e.g., `Random Forest.joblib`).
- notebook/: Includes the Jupyter Notebook with all EDA and model preparation steps.
- `README`: Project documentation.
- `requirements.txt`: List of required Python libraries.
- client/:
  - `data_splitter.py`: Splits data into client-specific subsets.
  - `federated_client_sklearn_lr.py`: Script to run the federated client with Logistic Regression.
  - data/: Contains `client_0_data.csv` and `client_1_data.csv`.
  - metadata/: Contains `encoder.pkl` and `scaler.pkl` for preprocessing.
- server/:
  - `federated_server.py`: Runs the federated learning server.
  - `preprocess_server.py`: Preprocesses data and generates metadata.
  - metrics/: Contains `metrics.json`, which logs the loss, accuracy, precision, recall, F1-score, and other relevant metrics for each round of federated learning. This file is essential for tracking the progress and convergence of the global model across training rounds.
  - models/: Stores `global_model.pkl`, the serialized global model generated after federated learning. This file contains the trained Logistic Regression coefficients aggregated across all federated learning rounds. The global model can be used for further validation or deployment on unseen data to evaluate its generalization.
  - notebook/: Contains `testing_global_model.ipynb`, a Jupyter Notebook used to validate the global model's performance. This notebook loads `global_model.pkl` and applies it to a validation dataset to assess accuracy, precision, recall, F1-score, and other key metrics. It ensures that the federated training process was successful and the model is ready for practical applications.
- `eda_and_preparation.ipynb`: Notebook for EDA and model preparation.
- `.gitignore`: Specifies files to ignore in version control.
- `.python-version`: Python version used in the project.
- `README.md`: Project documentation.
- `requirements.txt`: Required Python libraries.
- Clone the repository: `git clone https://github.com/stirelli/federated-churn-ai.git`
- Install the necessary packages: `pip install -r requirements.txt`
- Run the Jupyter Notebook: `jupyter notebook ./notebook/eda_and_preparation.ipynb`
- Divide the dataset into client-specific datasets: `python federated_learning/client/data_splitter.py`
- Generate server preprocessing metadata: `python federated_learning/server/preprocess_server.py`
- Start the Federated Learning server: `python federated_learning/server/federated_server.py`
- Connect the clients to the server (run each in a separate terminal):
  - `python federated_learning/client/federated_client_sklearn_lr.py --client_id 0`
  - `python federated_learning/client/federated_client_sklearn_lr.py --client_id 1`
This project underscores the critical role of data-driven decision-making in improving customer retention. By identifying key drivers of churn and utilizing predictive models, the telecom company can proactively mitigate churn risks, enhance customer satisfaction, and boost revenue retention.
Additionally, the project highlights the potential of federated learning as a privacy-preserving approach to machine learning. Federated models demonstrated performance comparable to centralized methods, providing a practical solution for decentralized data environments.
Future improvements, such as decentralized hyperparameter optimization and advanced aggregation strategies, have the potential to further enhance the scalability and performance of the framework.