Andrzej Prałat

How to create a web application firewall using machine learning - part I

This series of articles is for anyone who would like to explore the idea of a web application firewall that does not depend on a set of manually defined rules. We will show you how to train a machine learning model to automatically detect malicious web traffic. In the first part of the series we describe how to explore our dataset, create the simplest predictive model and properly measure its effectiveness. The model will be created in Python with popular libraries such as pandas, scikit-learn and XGBoost. Basic Python knowledge is required.

What is a Web Application Firewall?

A web application firewall (or WAF) is a network security system that examines web traffic between a client and a web application in order to find and block suspicious activity. By applying a set of rules created and tuned by experts, it can block specific HTTP requests and prevent attacks such as SQL Injection, cross-site scripting (XSS) or local file inclusion (LFI).

Why Should We Use Machine Learning?

Machine learning often works well for problems that are usually solved by creating long lists of decision rules. Such complex lists require a lot of maintenance. In the case of a WAF, the rules must be updated and tuned on a regular basis to avoid false detections and to detect new types of attacks.

Machine learning algorithms can be automated to constantly learn from the newest data without human intervention. Moreover, predictive models can be more robust to attack obfuscation techniques, since they are good at detecting similarities between patterns. Finally, thanks to their ability to discover unknown patterns in the data, machine learning algorithms can help humans learn about new types of attacks.

Creating a Predictive Model

The full code presented below is available online in Jupyter notebook at


We are going to create a binary classification model that will mark each request as normal or malicious. To train such a model we need a dataset of labeled HTTP requests. We are going to use our own dataset of 5 million records, some of which were labeled by our security experts as malicious. To build your own dataset you can use the HTTP logs of your web application. Malicious requests can be obtained, for example, by running popular security scanners that test your web app for many different vulnerabilities and generate the corresponding requests.

Explore the Data

Before creating a model it’s important to explore the data to better grasp the problem we are trying to solve. Let’s read the data from a CSV file using the pandas library. The data will be loaded into a pandas DataFrame. The info() method of DataFrame gives us basic information about the data:
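A minimal sketch of this loading step. The real dataset is not public, so the example below reads a few made-up rows from an in-memory string; the rows and the file name mentioned in the comment are assumptions, only the column names come from the article.

```python
import io

import pandas as pd

# Made-up rows with the six columns described in this article.
# With a real file you would call pd.read_csv("requests.csv") instead
# (the file name is hypothetical).
csv_text = """uri,is_static,http_version,has_referer,method,label
/index.html,True,1.1,True,GET,0
/img/logo.png,True,2.0,True,GET,0
/search?q=shoes,False,1.1,False,GET,0
/item?id=1' OR '1'='1,False,1.0,False,GET,1
"""

# Read http_version as text so that "1.0" and "2.0" are not parsed as floats.
df = pd.read_csv(io.StringIO(csv_text), dtype={"http_version": str})

# info() prints column names, dtypes, non-null counts and memory usage.
df.info()
```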

As previously mentioned, we have 5 million records. Each record has 6 columns: 5 attributes (uri, is_static, http_version, has_referer, method) and a label - the value that indicates whether the record is malicious or not. Let’s check the distribution of labels in the whole dataset:
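One way to check the label distribution is pandas' value_counts(). The sketch below runs on a made-up label column, not on the real dataset.

```python
import pandas as pd

# Made-up labels: 1 malicious request among 1000.
labels = pd.Series([0] * 999 + [1], name="label")

counts = labels.value_counts()                # absolute count per class
shares = labels.value_counts(normalize=True)  # fraction of the dataset per class

print(counts)
print(shares)
```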

It looks like our dataset is highly imbalanced - only around 0.1% of examples have a label equal to 1. This was of course expected, since only a very small minority of requests are malicious. Now let’s look at some of the requests. The head() method prints the first rows of the DataFrame.

Our dataset contains a simple set of attributes:

uri - in our case, the path component of the URL
is_static - precomputed boolean value; equal to True when the path component points to a static file
http_version - HTTP version of the request
has_referer - precomputed boolean value; equal to False when the Referer header was not set
method - HTTP method of the request

Since in the picture above we can see only examples of correct requests (each of these requests has a label equal to 0), let’s take a look at examples from the malicious class:
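Selecting the malicious class comes down to a boolean mask on the label column. A small sketch with made-up URIs:

```python
import pandas as pd

# Made-up requests; only the column names follow the article.
df = pd.DataFrame({
    "uri": ["/index.html", "/item?id=1' OR '1'='1", "/../../etc/passwd", "/about"],
    "label": [0, 1, 1, 0],
})

# The boolean mask keeps only the rows labeled as malicious.
malicious = df[df["label"] == 1]
print(malicious.head())
```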

We can see attempts at different types of attacks (e.g. SQL Injection or path traversal). It looks like the malicious requests in this picture can be recognized by analyzing the uri. Let’s check whether other attributes can also help our classifier make the correct decision. We are going to write a helper function that plots the distribution of a selected attribute for each class, using the matplotlib and seaborn libraries:
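The original helper is not reproduced here, so the following is only one possible sketch. The function name plot_attribute_by_class and the toy data are hypothetical; the function normalizes the counts within each class so that the rare malicious class stays visible next to the dominant one.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_attribute_by_class(df, attribute, label_col="label"):
    """Plot the share of each attribute value within each class."""
    shares = (
        df.groupby(label_col)[attribute]
        .value_counts(normalize=True)  # fractions sum to 1 inside each class
        .rename("share")
        .reset_index()
    )
    # Treat the label as a category so that each class gets its own color.
    shares[label_col] = shares[label_col].astype(str)
    fig, ax = plt.subplots()
    sns.barplot(data=shares, x=attribute, y="share", hue=label_col, ax=ax)
    ax.set_title(f"Distribution of {attribute} per class")
    return ax


# Made-up data for illustration only.
toy = pd.DataFrame({
    "http_version": ["1.1", "1.1", "2.0", "1.1", "1.0", "1.1"],
    "label": [0, 0, 0, 0, 1, 1],
})
ax = plot_attribute_by_class(toy, "http_version")
```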

Now we can simply call this function for each attribute. Let’s analyze the plot for http_version (the rest of the plots can be found in our Jupyter notebook):

Based on this plot, we can see that the majority of requests in both classes use HTTP version 1.1. However, it seems that some of the malicious requests were made using the old 1.0 version of HTTP, which almost never occurs in correct requests. Similarly, version 2.0 has some share in class 0 and almost no examples in class 1. It looks like the HTTP version might actually be a useful feature for our classifier.

Prepare dev and test set

There is one more thing we need to do before creating our model. We need a way to verify how accurate the model is. To do this we will split our dataset into a training part and a testing part. This way we can train our model on the training set and later check its accuracy on the test set. Actually, we will also create one more set - the development set. It will be used to measure the accuracy of different algorithms and will help us make the final decision about which algorithm to use.

To sum up, we will use the training set to train the models, compare different models on the development set and, when we find the one that performs best, evaluate it on the test set to measure its real performance (this process may seem unnecessary, but it is important: a model chosen because it scores best on the development set may be overfitted to that set, so only an untouched test set gives an honest estimate). To split the data we will use scikit-learn’s train_test_split() function:

This function randomly splits our data into subsets. There are a few things in this code that require explanation:

The dataset is divided into attributes and labels. We must be careful not to include the label in the feature set that the algorithm will use to make a decision.

The test_size parameter is equal to 0.2, which means 80% of our data will be in the training set and 20% in the test set.

We use the stratify parameter to make sure each subset has the same proportion of labels (we want the training and test sets to come from the same distribution).

The random_state parameter is set to 0 - the split is random, but we want it to stay the same each time we call this function (to ensure that our results are reproducible).
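The points above can be illustrated with a short sketch. The parameter values (test_size=0.2, stratify, random_state=0) follow the text; the data and the variable names are made up.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up data: 100 requests, 10 of them malicious.
normal_uris = [f"/page/{i}" for i in range(90)]
attack_uris = [f"/item?id={i}' OR 1=1" for i in range(10)]
df = pd.DataFrame({
    "uri": normal_uris + attack_uris,
    "label": [0] * 90 + [1] * 10,
})

X = df.drop(columns=["label"])  # attributes only, the label is left out
y = df["label"]                 # labels only

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,    # 20% of the rows go to the test set
    stratify=y,       # keep the share of malicious requests equal in both parts
    random_state=0,   # same split on every run
)

print(len(X_train), len(X_test))
```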

Now we can extract the development set from the training set and check how many examples are in each set:
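A sketch of this second split. The article does not state how large the development set is, so test_size=0.25 below is an assumed value, and the arrays are made-up stand-ins for the training portion.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up stand-in for the training portion: 80 examples, 8 of them malicious.
X_train = np.arange(80).reshape(-1, 1)
y_train = np.array([0] * 72 + [1] * 8)

# test_size=0.25 is an assumption, not a value taken from the article.
X_train, X_dev, y_train, y_dev = train_test_split(
    X_train, y_train, test_size=0.25, stratify=y_train, random_state=0
)

for name, part in [("train", y_train), ("dev", y_dev)]:
    print(name, len(part), "examples,", int(part.sum()), "malicious")
```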

Train the model

Finally! We have the training and development sets and are ready to create our first model. To train the model we need to create features that describe each request and that the algorithm will use to make a decision. This is called feature engineering, and it’s usually the most important step in the whole process of building a model. The effectiveness of the model depends heavily on the quality of the features we create.

Our first model will use a very simple set of features. For each request we will count how many times each character occurs in its uri. This will be the only information that the classifier receives about each request. To transform the data into this form we will use scikit-learn’s CountVectorizer:

CountVectorizer can count both words and characters, so in our case we set the analyzer parameter to 'char'. We also set min_df to 10, which drops every character that appears in fewer than 10 uris of the training set (we assume that such rare characters are not important for us). Next, we call the fit_transform() method on the list of uris from the training set.

This method first fits CountVectorizer on our training set (to get the list of all characters that will be used to calculate features) and then transforms the uris into numeric vectors that can be used by the model (each position of such a vector holds the number of occurrences of a specific character). We also need to transform the development set so that our model can make predictions for it. As we can see, the number of unique characters found in the training set is 74, and this is the length of our feature vector.
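The vectorization step described above can be sketched as follows. The URIs are made up and repeated ten times so that every character passes the min_df=10 filter; on real data the vocabulary would of course differ.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Made-up URIs, repeated so that each character appears in at least 10 of them.
train_uris = ["/index.html", "/item?id=1' or 1=1"] * 10
dev_uris = ["/about.html", "/item?id=2"]

vectorizer = CountVectorizer(analyzer="char", min_df=10)

# fit_transform learns the character vocabulary from the training URIs only
# and returns a sparse matrix of per-character counts.
X_train = vectorizer.fit_transform(train_uris)

# transform reuses the learned vocabulary; characters never seen during
# fitting (such as "b" or "2" here) are ignored.
X_dev = vectorizer.transform(dev_uris)

print("vocabulary size:", len(vectorizer.vocabulary_))
print(X_train.shape, X_dev.shape)
```

Note that CountVectorizer lowercases its input by default, so upper- and lowercase letters share one feature unless lowercase=False is passed.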

At last, we can train our model. We are going to use SGDClassifier (which implements linear models with stochastic gradient descent training) from the scikit-learn library. Scikit-learn provides many ready-made machine learning algorithms. SGDClassifier is one of the simplest and fastest models, so we will be able to get some results quickly.
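A minimal end-to-end sketch of training and scoring such a model. The URIs are made up, and the classifier uses scikit-learn's default hyperparameters, which is an assumption because the article does not list the settings it used.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score

# Made-up requests: 40 normal and 10 malicious URIs.
normal = [f"/products/page{i}.html" for i in range(40)]
attacks = [f"/item?id={i}' OR '1'='1" for i in range(10)]

train_uris, y_train = normal[:30] + attacks[:7], [0] * 30 + [1] * 7
dev_uris, y_dev = normal[30:] + attacks[7:], [0] * 10 + [1] * 3

# Character counts are the only features the classifier sees.
vectorizer = CountVectorizer(analyzer="char")
X_train = vectorizer.fit_transform(train_uris)
X_dev = vectorizer.transform(dev_uris)

# Default hyperparameters (assumed); random_state makes the run repeatable.
clf = SGDClassifier(random_state=0)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_dev)
accuracy = accuracy_score(y_dev, y_pred)
print("accuracy:", accuracy)
```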

Amazing! Our simple classifier made the correct decision for over 99.8% of requests in the development set, using only the number of occurrences of each character in the uri. How is that possible? Actually, accuracy (calculated as the fraction of correct predictions for our development set) is not a great metric to use with very imbalanced datasets. To illustrate this we will compute accuracy for scikit-learn's DummyClassifier, which makes predictions using very simple rules. In this case we will use the 'most_frequent' strategy, which simply predicts the most frequent label from the training dataset for each request.
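The effect is easy to reproduce on made-up labels with the same kind of imbalance. In the sketch below the features are all zeros, because the 'most_frequent' strategy ignores them anyway.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

# Made-up labels: 1 malicious request per 1000 in both sets.
y_train = np.array([0] * 999 + [1])
y_dev = np.array([0] * 999 + [1])

# Placeholder features; DummyClassifier does not look at them.
X_train = np.zeros((1000, 1))
X_dev = np.zeros((1000, 1))

dummy = DummyClassifier(strategy="most_frequent")
dummy.fit(X_train, y_train)
y_pred = dummy.predict(X_dev)

# 999 of 1000 predictions are correct, yet no attack is detected.
accuracy = accuracy_score(y_dev, y_pred)
print("accuracy:", accuracy)
print("requests marked as malicious:", int(y_pred.sum()))
```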

This very simple strategy gives us even better accuracy than SGDClassifier. Such a strategy, however, does not make sense for our kind of problem. Predicting the most frequent label means that we are not detecting any malicious requests at all. It turns out we don't really know how good SGDClassifier is. Possibly its behaviour is very similar to the 'most_frequent' strategy of DummyClassifier. To solve our problem we need to use a different evaluation metric for the model.

Some of the popular metrics used in such scenarios are precision and recall. Precision is calculated as the ratio tp / (tp + fp), where tp is the number of true positives and fp the number of false positives in our predictions. Intuitively, precision tells us how often the model was right when it marked a request as malicious. Precision can be very high for models that are very strict and rarely mark requests as malicious. This is why precision is usually used together with a complementary metric called recall. Recall is calculated as tp / (tp + fn), where fn is the number of false negatives, and intuitively tells us what fraction of all malicious samples the classifier found. Let’s calculate precision and recall for our predictions on the dev set:
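Both formulas can be checked on a tiny hand-made example. The labels and predictions below are invented so that tp = 2, fp = 1 and fn = 2.

```python
from sklearn.metrics import precision_score, recall_score

# Invented ground truth and predictions (1 = malicious, 0 = normal).
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]

# tp = 2 (first two requests), fp = 1 (fifth request), fn = 2 (third and fourth).
precision = precision_score(y_true, y_pred)  # tp / (tp + fp) = 2 / 3
recall = recall_score(y_true, y_pred)        # tp / (tp + fn) = 2 / 4

print("precision:", precision)
print("recall:", recall)
```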

It seems our classifier did not perform as well as we initially thought. A precision of 41% means that out of every 100 requests our model marked as malicious, 59 were mistakes. Moreover, recall is equal to 37%, so out of every 100 requests that were truly malicious, the model detected only 37. If one of these metrics is more important to us than the other, we could try to tweak the decision threshold.

SGDClassifier not only makes binary decisions, but also assigns a confidence score to each request. We can get these scores by calling the decision_function() method of SGDClassifier. By increasing the classifier's threshold we could make it more strict. That would typically increase precision, but it would decrease recall. We can get all possible values of precision, recall and threshold by calling scikit-learn's precision_recall_curve() function. Let's plot the result below, this time using the ggplot library:
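A sketch of how the scores and the curve values can be obtained. The data is synthetic and, for brevity, the model is scored on the same points it was trained on; the plotting step is left out. The threshold of 1.0 is an arbitrary illustrative value.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve, precision_score, recall_score

# Synthetic two-feature data with overlapping classes: 200 normal, 40 malicious.
rng = np.random.RandomState(0)
X = np.vstack([rng.normal(0.0, 1.0, (200, 2)), rng.normal(1.5, 1.0, (40, 2))])
y = np.array([0] * 200 + [1] * 40)

clf = SGDClassifier(random_state=0).fit(X, y)

# Signed distance from the decision boundary; predict() marks scores > 0 as 1.
scores = clf.decision_function(X)

# One (precision, recall) pair per candidate threshold, plus a final (1, 0) point.
precision, recall, thresholds = precision_recall_curve(y, scores)

# A stricter, hand-picked threshold marks fewer requests as malicious.
strict_pred = (scores > 1.0).astype(int)
print("flagged at 0.0:", int((scores > 0.0).sum()))
print("flagged at 1.0:", int(strict_pred.sum()))
print("precision at 1.0:", precision_score(y, strict_pred, zero_division=0))
print("recall at 1.0:", recall_score(y, strict_pred))
```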

Plotting the precision-recall curve is a great way to visualise the performance of a model. The closer the curve is to the upper-right corner of the plot, the better our classifier is. As you can see, there is plenty of room for improvement.

In subsequent parts of this series we will show you how to improve the effectiveness of our predictive model.
