Crash course in machine learning

A survey of machine learning algorithms
In this article we aim to clarify some of the terms left undefined in our previous article and, along the way, explore a selection of machine learning algorithms and their applications to information security.
We do not pretend to make an exhaustive list of all machine learning (ML) algorithms and techniques. Rather, we would like to demystify some obscure concepts and dethrone a few buzzwords. Hence this article is far from neutral.
First of all, we can classify ML techniques by the type of input available. If we are trying to develop a system that can identify whether the animal in a picture is a cat or a dog, first we need to train it with pictures of cats and dogs. Now, do we tell the system what each picture contains? If we do, it’s supervised learning and we say we are training the system with labeled data. The opposite is completely unsupervised learning, and there are a couple of stations in between, such as
semi-supervised learning: with partially labeled data,
active learning: the computer has to "pay" for the labels,
reinforcement learning: labels are given as output starts to be produced.
However, each algorithm typically fits more naturally into either the supervised or the unsupervised category, so we will stick to those two.
Next, what do we want to obtain from our learned system? The cat-dog situation above is a typical classification problem: given a new picture, to which category does it belong? A related but different problem is clustering, which tries to divide the inputs into groups according to the features identified in them. The difference is that the groups are not known beforehand, so this is an unsupervised problem.
Both of these problems are discrete in the sense that categories are separate and there are no in-between elements (a picture that is 80% cat and 20% dog). What if my data is continuous? Good old regression to the rescue! Even the humble least squares linear regression we learned in school can be thought of as a machine learning technique, since it learns a couple of parameters from the data and can make predictions based on those.
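To see how little magic there is, the least-squares fit mentioned above takes only a few lines of plain Python. This is a minimal sketch with function names of our own choosing; it "learns" two parameters (slope and intercept) from the data and then predicts deterministically:

```python
# Least-squares linear regression from scratch: learn a slope and an
# intercept from data, then use them to predict new values.

def fit_line(xs, ys):
    """Return (slope, intercept) minimizing the squared error."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    intercept = mean_y - slope * mean_x
    return slope, intercept

def predict(model, x):
    slope, intercept = model
    return slope * x + intercept

# "Training" data: points lying exactly on y = 2x + 1.
model = fit_line([0, 1, 2, 3], [1, 3, 5, 7])
print(predict(model, 10))  # -> 21.0
```

Everything "learned" here is just the pair (slope, intercept); prediction afterwards is ordinary arithmetic.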
Two other interesting problems are estimating the probability distribution of a sample of points (density estimation) and finding a lower-dimensional representation of our inputs (dimensionality reduction):
With these classifications out of the way, let’s go deeper into each particular technique that is interesting for our purposes.
Support vector machines
Much like linear regression tries to draw a line that best joins a set of closely correlated points, support vector machines (SVM) try to draw a line that separates a set of naturally separated points. Since a line divides the plane in two, any new point must be on one of the two sides, and is thus classified as belonging to one class or the other.
More generally, if the inputs are n-dimensional vectors, an SVM tries to find a geometric object of dimension n-1 (a hyperplane) that divides the given inputs into two (or more) groups. To name an application, support vector machines are used to detect spam in images (which is supposed to evade text spam filters) and face detection. In both cases, the translation of image to vector is relatively easy since computer representations of images are merely matrices of three-number vectors (pixels).
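The separating-hyperplane idea can be sketched in plain Python. This toy version trains a linear SVM by stochastic subgradient descent on the hinge loss, rather than the quadratic-programming solvers real SVM libraries use; all names, data, and hyperparameters here are invented for illustration:

```python
import random

# Toy linear SVM: find w, b such that the line w.x + b = 0 separates
# the two classes, by subgradient descent on the hinge loss.

def train_svm(points, labels, epochs=200, lam=0.01, lr=0.1):
    """points: list of 2-D vectors; labels: +1 or -1."""
    w = [0.0, 0.0]
    b = 0.0
    data = list(zip(points, labels))
    rng = random.Random(0)          # fixed seed for reproducibility
    for _ in range(epochs):
        rng.shuffle(data)
        for x, y in data:
            margin = y * (w[0] * x[0] + w[1] * x[1] + b)
            if margin < 1:          # point inside the margin: push it out
                w = [wi + lr * (y * xi - lam * wi) for wi, xi in zip(w, x)]
                b += lr * y
            else:                   # point well classified: only regularize
                w = [wi - lr * lam * wi for wi in w]
    return w, b

def classify(model, x):
    """Which side of the hyperplane is x on?"""
    w, b = model
    return 1 if w[0] * x[0] + w[1] * x[1] + b >= 0 else -1

# Two well-separated clouds on either side of the line x + y = 0.
pos = [(1, 1), (2, 1), (1.5, 2)]
neg = [(-1, -1), (-2, -1), (-1.5, -2)]
model = train_svm(pos + neg, [1, 1, 1, -1, -1, -1])
print(classify(model, (3, 2)))    # expected: 1
print(classify(model, (-3, -2)))  # expected: -1
```

Any new point is classified purely by which side of the learned line it falls on, which is the whole point of the technique.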
Clustering with k-means
Suppose we need to group unlabeled data in a meaningful way. Of course, the number of possible clusterings is very large. How do we choose among them? We need a way to measure cluster compactness. In the k-means technique, we specify the desired number of clusters k beforehand; for every cluster we can then define its centroid, something like its center of mass. A measure of the compactness of a cluster C could then be the sum of the member-to-centroid distances, called the distortion:

distortion(C) = Σ x∈C d(x, μC),   where μC is the centroid of C.
With that defined, we can state the problem clearly as an optimization problem: minimize the sum of all distortions. This problem is NP-hard (computationally very difficult), but good approximations can be obtained via the k-means algorithm. It can be shown, and more importantly it makes intuitive sense, that in an optimal clustering:
1. Each point must be clustered with the nearest centroid.
2. Each centroid is at the center of its cluster.
Condition 1 already suggests an algorithm: if you find a point which is closer to a centroid other than the one it is currently assigned to, reassign it. Where do we begin? The initial clustering could be chosen at random, we could space the centroids evenly, or we could use an ad hoc strategy. In fact, since this is a hill-climbing algorithm, i.e. one that makes small improvements in each iteration, thus guaranteeing a local optimum but perhaps not the global one, different starting points might lead to different results. Several tries with various starting points are therefore recommended.
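The two conditions above translate almost literally into code. A minimal sketch of the standard alternating algorithm (Lloyd's algorithm) on 2-D points; a serious implementation would add the multiple random restarts recommended above:

```python
import math
import random

# Minimal k-means: alternate between the two optimality conditions
# (assign each point to the nearest centroid; recenter each centroid)
# until nothing changes, i.e. a local optimum is reached.

def kmeans(points, k, seed=0):
    rng = random.Random(seed)
    centroids = rng.sample(points, k)       # arbitrary starting choice
    while True:
        # Condition 1: cluster every point with its nearest centroid.
        clusters = [[] for _ in range(k)]
        for p in points:
            i = min(range(k), key=lambda j: math.dist(p, centroids[j]))
            clusters[i].append(p)
        # Condition 2: move each centroid to the center of its cluster.
        new_centroids = [
            tuple(sum(c) / len(cl) for c in zip(*cl)) if cl else centroids[j]
            for j, cl in enumerate(clusters)
        ]
        if new_centroids == centroids:      # converged: local optimum
            return centroids, clusters
        centroids = new_centroids

# Two obvious clouds of three points each.
pts = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
centroids, clusters = kmeans(pts, 2)
print(sorted(len(c) for c in clusters))  # -> [3, 3]
```

Note that k itself is an input: the algorithm never questions our choice of the number of clusters.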
Artificial neural networks and deep learning
Loosely inspired by the massive parallelism animal brains are capable of, these models are highly interconnected graphs whose nodes are (mathematical) functions and whose edges have weights that are adjusted during training. A set of weights is scored by the accuracy of the output it produces on labeled data, and improved in the next step, or epoch, of training through a process called back-propagation (of error): the weights are adjusted in such a way that the measured error decreases. The nodes are arranged in layers, two of which are special layers for input and output, and their functions are typically smooth versions of step functions (i.e. yes/no functions, but with no big jumps). After training, the whole network is fixed, so it’s only a matter of feeding it input and reading off the output.
The networks described above are feed-forward, in the sense that data flows only in the direction from input to output. Without this restriction, we get recurrent neural networks. Convolutional networks replace the plain weighted connections of some layers with the mathematical operation of cross-correlation (closely related to convolution). Deep neural networks owe their name simply to the large number of layers they use.
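The training loop described above can be made concrete with a toy feed-forward network learning the XOR function. This is purely illustrative, with made-up sizes and learning rate; real applications use dedicated libraries and vastly larger networks:

```python
import math
import random

# Tiny feed-forward network (2 inputs, one hidden layer, 1 output)
# trained by back-propagation to learn XOR.

rng = random.Random(42)
H = 4                                   # hidden-layer size (arbitrary)
w1 = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(H)]
b1 = [rng.uniform(-1, 1) for _ in range(H)]
w2 = [rng.uniform(-1, 1) for _ in range(H)]
b2 = rng.uniform(-1, 1)

def sigmoid(z):                         # a smooth version of a step function
    return 1.0 / (1.0 + math.exp(-z))

def forward(x):
    hidden = [sigmoid(sum(w * xi for w, xi in zip(ws, x)) + b)
              for ws, b in zip(w1, b1)]
    out = sigmoid(sum(w * h for w, h in zip(w2, hidden)) + b2)
    return hidden, out

data = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]

def total_error():
    return sum((forward(x)[1] - y) ** 2 for x, y in data)

lr = 0.5
before = total_error()
for _ in range(10000):                  # epochs
    for x, y in data:
        hidden, out = forward(x)
        # Back-propagation: output error first, then hidden errors.
        d_out = (out - y) * out * (1 - out)
        d_hid = [d_out * w2[j] * hidden[j] * (1 - hidden[j]) for j in range(H)]
        # Adjust every weight so the measured error decreases.
        for j in range(H):
            w2[j] -= lr * d_out * hidden[j]
            for i in range(2):
                w1[j][i] -= lr * d_hid[j] * x[i]
            b1[j] -= lr * d_hid[j]
        b2 -= lr * d_out
print(before, "->", total_error())      # the squared error shrinks with training
```

Notice that after training the network is just a fixed function: calling forward on new input involves no further learning.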
While these networks have been quite successful in applications, particularly in video games, they are not perfect:
in contrast to simpler machine learning models, they don’t produce a usable or understandable model; it’s just a black box that computes output given input.
biology is perhaps not the best model for engineering. In Mark Stamp’s words,
Attempting to construct intelligent systems by modeling neural interactions within the brain might one day be seen as akin to trying to build an airplane that flaps its wings.
Decision trees and forests
In stark contrast to the unintelligible models extracted from neural networks, decision trees are simple enough to understand at a glance:
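Indeed, a small decision tree reads as nothing more than nested if/else rules. The features and thresholds below are made up for illustration, but the point is that anyone can follow the logic:

```python
# A toy spam-filtering decision tree, written directly as code.
# Every branch is a human-readable rule; compare this to the opaque
# weight matrices of a neural network.

def classify_email(num_links, has_attachment, sender_known):
    """Return 'spam' or 'ham' for an email with these (toy) features."""
    if num_links > 5:                       # link-heavy mail looks like spam
        return "spam"
    if has_attachment and not sender_known: # unsolicited attachment
        return "spam"
    return "ham"

print(classify_email(num_links=8, has_attachment=False, sender_known=True))  # -> spam
print(classify_email(num_links=1, has_attachment=True, sender_known=True))   # -> ham
```

Training a tree means choosing which feature to test at each node and with what threshold, typically greedily; the result is always as legible as this.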
However, decision trees have a tendency to overfit the training data, i.e., they are sensitive to noise and extreme values in it. Worse, a particular testing point could be predicted differently by two trees built from the same training data but with, say, the order of the features reversed.
These difficulties can be overcome by constructing many trees from different (possibly overlapping) subsets of the training data and making the final decision by taking a vote among all the trees; this is the idea behind random forests. It solves the overfitting, but the intuition obtained from simple trees is lost.
Anomaly detection via k-nearest neighbors
Detecting anomalies is naturally an unsupervised problem and really makes up a whole class of algorithms and techniques, some of which actually make more sense in the context of data mining than machine learning.
The simplest way to detect anomalies could be to compute the average and standard deviation of your data, and declare everything that is more than two standard deviations away from the mean an anomaly (outliers in classical statistics). A slightly more involved approach is to use the k-nearest neighbors algorithm (kNN), which essentially classifies an element according to the k training elements closest to it.
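The two-standard-deviations rule fits in a few lines. A sketch on made-up data (imagine login hours for a user account):

```python
import math

# Classical outlier detection: flag every value more than `threshold`
# standard deviations away from the mean.

def outliers(values, threshold=2.0):
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    return [v for v in values if abs(v - mean) > threshold * std]

# Login times in hours: a 3 a.m. login stands out from the 9-11 a.m. habit.
print(outliers([9, 10, 9, 11, 10, 9, 10, 3]))  # -> [3]
```

Note that there is no training here at all; the "model" is just two summary statistics of the data.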
Variations on the same theme are:
assigning weights to neighbors based on their distance or their relative frequency in the training data;
classifying items based on all the neighbors within a fixed radius.
The kNN algorithm can also be adapted to regression and anomaly detection; in the latter case, for instance, by scoring elements in terms of the distance to their closest neighbor (1NN).
Notice that in kNN there is no separate training phase: the labeled input is both the training data and the model itself. The most natural application of anomaly detection in computer security is in intrusion detection systems.
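A sketch of 1NN anomaly scoring, on hypothetical network-traffic feature vectors (say, packets per second and average packet size; both the features and the numbers are invented):

```python
import math

# 1NN anomaly scoring: the farther a new observation is from its
# nearest neighbor in the "normal" data, the more anomalous it is.
# The normal data IS the model; there is no training step.

def nn_score(train, x):
    """Distance from x to its closest neighbor in train."""
    return min(math.dist(x, t) for t in train)

# Past observations of normal traffic: (packets/s, avg packet size).
normal = [(10, 500), (12, 480), (11, 510), (9, 495)]

print(nn_score(normal, (11, 500)))  # close to known traffic: low score
print(nn_score(normal, (90, 40)))   # unlike anything seen: high score
```

In an intrusion detection setting one would pick a threshold on this score and alert on everything above it.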
I hope this article has served to establish the following general ideas on machine learning:
Even though ML has gained a lot of momentum in the past few years, its basic ideas are quite old.
Fancy names can sometimes be used to dress up simple ideas. The word learning in particular can be misleading, making us think of autonomous machines, when in reality these are just algorithms that extract parameters from training data and later use them in a deterministic way.
ML is not a field of its own, but rather one that sits at the intersection of statistics, optimization, data analysis, and data mining.
Mark Stamp (2018). Introduction to Machine Learning with Applications in Information Security. CRC Press.