After our encounter with the Python machine learning ecosystem, we are now ready for a first attempt at creating a vulnerability classifier from samples of labeled code.
Our objective is to produce a vulnerability classifier for triaging large codebases. When one of our analysts has to audit a new continuous hacking project, usually all they get is access to one or more big repositories, and clients expect vulnerability reports as soon as possible. We would like our predictor to sort the files in these repositories according to the likelihood that they contain a vulnerability, so the analyst can prioritize the manual inspection.
For this particular experiment, we will use the Juliet Java dataset. This dataset has important advantages in training a vulnerability classifier, namely:
- It contains a broad spectrum of vulnerabilities: 112 CWEs out of 808 total. This doesn’t seem like much, but it covers what we usually see in practice. Thus, the dataset is not biased toward any specific kind, but rather might give us a good idea of what vulnerable code "looks like" to a machine learning algorithm.
- For each of them, it has examples of good and bad code, which we can use as labels for training.
- It is well-organized, documented, and big enough for machine learning training: 5,143,930 Java lines of code as counted by cloc, out of which 4,565,713 are in the test cases, as per its documentation.
- It was specifically designed for testing static analysis tools, which is the purpose of a vulnerability classifier.
The bad part is that these are synthetic test cases, which might not resemble real-life production code, and also, they only have datasets for Java and C/C++. However, for a first test, this is not a deal-breaker, and we will stick to the Java version of the dataset only.
The Juliet dataset is organized with a folder for each CWE code, inside of which there are one or more source files, one for each possible kind of vulnerability. In each of those we find one or more "good" and "bad" methods. For example, in the folder for CWE-327: Use of a Broken or Risky Cryptographic Algorithm, there are several files, some for the usage of DES and some for 3DES, the most common broken ciphers still in use. Here is a snipped adaptation of one of them:
Bad and good test cases from Juliet’s CWE327.
public class CWE327_Use_Broken_Crypto__DES_11 extends AbstractTestCase
{
    public void bad() throws Throwable
    {
        if (IO.staticReturnsTrue())
        {
            final String CIPHER_INPUT = "ABCDEFG123456";
            KeyGenerator keyGenerator = KeyGenerator.getInstance("DES");
            /* Perform initialization of KeyGenerator */
            keyGenerator.init(56);
            SecretKey secretKey = keyGenerator.generateKey();
            byte[] byteKey = secretKey.getEncoded();
            /* FLAW: Use a weak crypto algorithm, DES */
            SecretKeySpec secretKeySpec = new SecretKeySpec(byteKey, "DES");
            Cipher desCipher = Cipher.getInstance("DES");
            desCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec);
            byte[] encrypted = desCipher.doFinal(CIPHER_INPUT.getBytes("UTF-8"));
        }
    }
    private void good1() throws Throwable
    {
        final String CIPHER_INPUT = "ABCDEFG123456";
        KeyGenerator keyGenerator = KeyGenerator.getInstance("AES");
        /* Perform initialization of KeyGenerator */
        keyGenerator.init(128);
        SecretKey secretKey = keyGenerator.generateKey();
        byte[] byteKey = secretKey.getEncoded();
        /* FIX: Use a stronger crypto algorithm, AES */
        SecretKeySpec secretKeySpec = new SecretKeySpec(byteKey, "AES");
        Cipher aesCipher = Cipher.getInstance("AES");
        aesCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec);
        byte[] encrypted = aesCipher.doFinal(CIPHER_INPUT.getBytes("UTF-8"));
    }
}
The methods have the labels right in their names. Also, they are nearly identical, except for the actual encryption algorithm invoked and the key size. So we would like our classifier to be able to gloss over the boilerplate code and fixate on those kinds of details.
For this particular experiment, because each function has its own label, we decided to go with function-level granularity. This means that the classifier should take a method as input and output whether it is "good" or "bad", according to what it learns during training with the above kinds of functions. Other classifiers might work at line or even token granularity; perhaps ours will end up less granular and tell us whether a given source code file might contain a vulnerability. But for the Juliet dataset structure, this granularity seemed the most appropriate.
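To make the granularity trade-off concrete, here is a minimal sketch of how method-level predictions could be rolled up into a file-level ranking for the analyst. The file names, the scores and the `rank_files` helper are hypothetical, and taking the maximum method score per file is our assumption for illustration, not something this experiment prescribes.

```python
from collections import defaultdict

def rank_files(method_scores):
    """Rank files by their riskiest method.

    method_scores: iterable of (file_name, probability) pairs, one per method.
    Returns a list of (file_name, score), most suspicious file first.
    """
    file_scores = defaultdict(float)
    for file_name, probability in method_scores:
        # Assumption: a file is as risky as its riskiest method.
        file_scores[file_name] = max(file_scores[file_name], probability)
    return sorted(file_scores.items(), key=lambda item: item[1], reverse=True)

# Hypothetical method-level outputs of a classifier.
scores = [
    ("Login.java", 0.15), ("Login.java", 0.91),
    ("Utils.java", 0.30),
    ("Crypto.java", 0.62), ("Crypto.java", 0.48),
]
ranking = rank_files(scores)
print(ranking)
# [('Login.java', 0.91), ('Crypto.java', 0.62), ('Utils.java', 0.3)]
```

Other aggregations (mean score, count of methods above a threshold) would work with the same structure; the maximum simply keeps one very suspicious method from being diluted by many harmless ones.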
Next step: preprocessing. As this dataset is well-organized and labeled, this step should be minimal. However, there are two obstacles we must overcome:
- Splitting each file into separate methods and grabbing the labels from their names: This is a job for a parser combinator, as regular expressions are ineffective in dealing with multiline matches. Plus, parsers can be extended to work with other languages. This is necessary as the models expect a table with one column per feature, and a list, series or 1D array for the labels, which takes us to the next step:
- Encoding. In this experiment, although we are basically treating code as if it were a bag of disordered tokens and ignoring semantic relationships, encoding is required. Such is the case even for natural language (more information on bag of words, tf-idf, and word2vec encodings for natural language later).
To get each method and label, we used a simple approach: a parser to identify the signatures, grab the code in-between signatures, and remove the comments:
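import re

from pyparsing import NotAny, Regex, oneOf

# A method signature: an access modifier not followed by 'class',
# then the rest of the line.
SIGNATURE = (oneOf('private public')
             + NotAny('class')
             + Regex('.*').setResultsName('actual_signature'))

# Strip /* ... */ comments that open and close on the same line.
remove_comments = lambda x: re.sub(r'/\*.*\*/', '', x)
# Hide the label so the classifier cannot read it from the method name.
remove_labelsig = lambda x: re.sub('bad|good', 'method', x)

def get_methods(source_text):
    results = list(SIGNATURE.scanString(source_text))
    methods = []
    labels = []
    # Each method spans from its signature to the start of the next one,
    # so the last match only marks where the previous method ends.
    for i in range(len(results) - 1):
        signature = results[i][0][1]
        vulnerable = 'bad' in signature
        text_mid = source_text[results[i][1] : results[i+1][1]]
        method = remove_comments(text_mid)
        method = remove_labelsig(method)
        methods.append(method)
        labels.append(vulnerable)
    return methods, labels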
Also, we replace the words "good" and "bad" with "method" in the method signatures. This keeps the classifier from cheating, since it could otherwise make a prediction based on these particular words.
Next, we extend that to read and parse Juliet test cases. We extracted 31,105 methods and labels, of which 19,976 are labeled as not vulnerable while the remaining 11,129 are labeled as vulnerable.
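The two classes are not balanced, which matters when reading accuracy figures. A quick check of the arithmetic, using only the counts above (the "always answer not vulnerable" baseline is our own yardstick here, not part of the original experiment):

```python
# Counts reported above for the parsed Juliet methods.
n_safe = 19_976        # labeled as not vulnerable
n_vulnerable = 11_129  # labeled as vulnerable
n_total = n_safe + n_vulnerable

# Accuracy of a trivial classifier that always answers "not vulnerable".
majority_baseline = n_safe / n_total

print(n_total)                     # 31105
print(f"{majority_baseline:.1%}")  # 64.2%
```

Any useful classifier trained on this data should therefore score clearly above roughly 64% accuracy.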
Onward to encoding, which, in this case, is more parsing. This time the parsing is performed by Keras, the friendly high-level neural networks API for Python. It makes sense to use Keras here since the neural network will be implemented via this library. Now that we have a list, all_methods, in which each element is a method, we can break that up into tokens for analysis:
from keras.preprocessing.text import Tokenizer
NUM_WORDS = 1000
keras_tokenizer = Tokenizer(NUM_WORDS, filters='\t\n', lower=True,
                            split=' ', char_level=False)
keras_tokenizer.fit_on_texts(all_methods)
Much like the actual machine learning models from APIs such as scikit, as seen in our previous article, this tokenizer must be trained, or fit, to the dataset. After that, the object becomes populated with already interesting facts about our language corpus:
>>> keras_tokenizer.word_counts
OrderedDict([('public', 15676),
             ('void', 25995),
             ('method()', 5125),
             ('throws', 26778),
             ('throwable', 26746),
             ('{', 186876),
             ('switch', 1279),
             ('(7)', 405),
             ('case', 1415),
             ('7:', 555),
             ('messagedigest', 658),
             ('hash', 96),
             ('=', 127781),
             ('messagedigest.getinstance("sha-512");', 326),
             ('byte[]', 1250),
             ('hashvalue', 240),
             ('hash.digest("hash', 96),
The most popular tokens are those appearing in the signature. But ignoring those, it is clear that we are dealing with a security-focused dataset: all the following tokens deal with hashing, a common operation when dealing with sensitive data that needs to be masked.
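For intuition, the counting done by fit_on_texts can be approximated in plain Python. This is a simplified sketch of the behavior configured above (tabs and newlines treated as separators, lowercasing, splitting on spaces), not the actual Keras implementation; `count_tokens` and `toy_methods` are made up for illustration.

```python
from collections import Counter

def count_tokens(texts):
    """Count lowercase, space-separated tokens across all texts."""
    counts = Counter()
    for text in texts:
        # Mirror filters='\t\n': turn tabs and newlines into spaces first.
        cleaned = text.lower().replace("\t", " ").replace("\n", " ")
        # Splitting on a single space yields empty strings for repeated
        # spaces, so those are dropped.
        counts.update(token for token in cleaned.split(" ") if token)
    return counts

toy_methods = [
    'public void method() throws Throwable {\n Cipher c = Cipher.getInstance("DES");\n}',
    'private void method1() throws Throwable {\n Cipher c = Cipher.getInstance("AES");\n}',
]
toy_counts = count_tokens(toy_methods)
print(toy_counts["throws"])                      # 2
print(toy_counts['cipher.getinstance("des");'])  # 1
```

Because only whitespace separates tokens, a whole call such as `cipher.getinstance("des");` ends up as a single token, just like `messagedigest.getinstance("sha-512");` in the counts above.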
The Keras tokenizer can perform categorical encoding on these sequences as well. This is perhaps the least sophisticated of all encodings: it simply assigns a number to each of the tokens, and represents a string of them as the list of those numbers.
sequences = keras_tokenizer.texts_to_sequences(all_methods)
Thus, this part of a method:
method = '''public void bad() throws Throwable{
switch (7){
case 7:
MessageDigest hash = MessageDigest.getInstance("SHA-512");
byte[] hashValue = hash.digest("hash me".getBytes("UTF-8"))'''
becomes the sequence:
>>> sequences[0]
[24, 18, 69, 16, 17, 1, 230, 510, 1, 213, 446, 381, 845, 3, 534, 238,
567, 3, 846, 847, 568, 80, 237, 122, 123, 124, 80, 2, 2]
and we can recover its tokens using the index_word attribute of the keras_tokenizer:
>>> [keras_tokenizer.index_word[i] for i in sequences[0]]
['public', 'void', 'bad()', 'throws', 'throwable', '{',
'switch', '(7)', '{', 'case', '7:', 'messagedigest', 'hash',...
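The round trip between tokens and numbers can be sketched without Keras. The functions below are a simplified stand-in under our own assumptions (whitespace splitting, more frequent tokens get smaller numbers, tokens whose number is not below `num_words` are dropped); `build_word_index` and `encode` are hypothetical helper names.

```python
from collections import Counter

def build_word_index(texts):
    """Map each token to an integer; more frequent tokens get smaller numbers."""
    counts = Counter(token for text in texts for token in text.lower().split())
    # Index 0 is left unused so it can later serve as padding.
    return {word: i for i, (word, _) in enumerate(counts.most_common(), start=1)}

def encode(text, word_index, num_words=None):
    """Replace each known token by its number; unknown or rare tokens are dropped."""
    ids = (word_index.get(token) for token in text.lower().split())
    return [i for i in ids if i is not None and (num_words is None or i < num_words)]

corpus = ["public void bad() throws Throwable {",
          "private void good1() throws Throwable {"]
word_index = build_word_index(corpus)
index_word = {i: word for word, i in word_index.items()}

encoded = encode(corpus[0], word_index)
decoded = [index_word[i] for i in encoded]
print(encoded)  # [5, 1, 6, 2, 3, 4]
print(decoded)  # ['public', 'void', 'bad()', 'throws', 'throwable', '{']
```

With a cap such as `num_words=5`, only the four most frequent tokens survive, which is why a vocabulary limit like NUM_WORDS silently discards rare tokens.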
Neural networks also expect features to be vectors of the same size, so we need to pad these sequences by filling them with zeros. Keras also provides a convenient function for this:
from keras.preprocessing.sequence import pad_sequences
PAD_SIZE = max(map(len, sequences))
padded_seqs = pad_sequences(sequences, maxlen=PAD_SIZE, padding='post')
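What post-padding does can be shown with plain lists. This is only an illustrative sketch (`pad_post` and `toy_sequences` are hypothetical); the Keras function returns an array rather than a list of lists.

```python
def pad_post(sequences, pad_value=0):
    """Right-pad every sequence with pad_value up to the longest length."""
    pad_size = max(map(len, sequences))
    return [seq + [pad_value] * (pad_size - len(seq)) for seq in sequences]

toy_sequences = [[24, 18, 69], [24, 18], [5]]
padded = pad_post(toy_sequences)
print(padded)  # [[24, 18, 69], [24, 18, 0], [5, 0, 0]]
```

Since the pad size equals the longest sequence, nothing is truncated; the cost is that every input becomes as long as the longest method in the corpus.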
Finally, we create our neural network. It will be very simple: the input layer, one hidden layer, and the output layer.
from keras.models import Sequential
from keras.layers import Dense, Embedding, Flatten
MODEL = Sequential()
MODEL.add(Embedding(NUM_WORDS, 100, input_length=PAD_SIZE))
MODEL.add(Flatten())
MODEL.add(Dense(1, activation='sigmoid'))
MODEL.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
The process is not that different from specifying a scikit model; we just add a few more lines, one per layer, each with its (tunable) hyperparameters. Finally, we compile the model, where we define the optimizer, the loss function to minimize during training, and the metric to report, which here is the accuracy of the classifier.
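To see what the loss and the metric measure, here is a small sketch that computes binary cross-entropy and accuracy by hand. The labels and probabilities are invented, and the 0.5 decision threshold is our assumption for the illustration.

```python
import math

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean negative log-likelihood of the true labels under the predictions."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

def accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the label."""
    hits = sum((p > threshold) == bool(y) for y, p in zip(y_true, y_prob))
    return hits / len(y_true)

# Hypothetical labels (1 = vulnerable) and sigmoid outputs for four methods.
labels = [1, 0, 1, 0]
probabilities = [0.9, 0.2, 0.4, 0.1]

loss = binary_crossentropy(labels, probabilities)
acc = accuracy(labels, probabilities)
print(round(loss, 4))  # 0.3375
print(acc)             # 0.75
```

The loss rewards confident correct answers and punishes confident wrong ones, while accuracy only counts which side of the threshold each prediction falls on.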
In order to validate our model, it is good practice to reserve a smaller part of the dataset (here 20%) for testing purposes and use the remainder for training. We can do that with scikit:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(padded_seqs, all_labels,
                                                    test_size=0.2,
                                                    random_state=0)
Then we train our model:
MODEL.fit(X_train, y_train, epochs = 20, validation_split = 0.2)
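It helps to keep track of how many samples end up in each subset. The sketch below works out the sizes from the 31,105 extracted methods; the rounding rules are our assumption about how the two splits behave, so treat the validation figures as approximate.

```python
n_methods = 31_105  # methods extracted from Juliet, as reported earlier

# 20% held out for testing (rounded up); the rest is used for training.
n_test = (n_methods * 20 + 99) // 100
n_train = n_methods - n_test

# Assumption: validation_split=0.2 sets aside 20% of the training samples,
# with the share that is actually fitted rounded down.
n_fit = n_train * 80 // 100
n_validation = n_train - n_fit

print(n_test, n_train)      # 6221 24884
print(n_fit, n_validation)  # 19907 4977
```

The test size of 6,221 agrees with the sample count shown in the evaluation output.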
And evaluate it using the reserved part of the dataset:
>>> MODEL.evaluate(X_test, y_test)
6221/6221 [==============================] - 0s 20us/step
[0.22666279486551333, 0.8609548304514416]
The first value is the loss, and the second is the accuracy. In our opinion, an accuracy of 86% is good for a first go at the ML-aided code auditing triage problem, but we hope to raise the bar a bit higher. We can save our model for sharing with others. The 'h5' file can be loaded from Keras as easily as it was saved, just like we did in the adversarial examples article.
>>> MODEL.save('vuln_classifier.h5')
>>> !ls -lh *.h5
-rw-r--r-- 1 r r 4.7M Sep 23 10:02 dog_tree.h5
-rw-r--r-- 1 r r 1.8M Oct 8 11:31 vuln_classifier.h5
This particular model is relatively lightweight compared to the MobileNet-based animal classifier dog_tree.h5. This model could be deployed, for example, on AWS Lambda, ready to make predictions. Just make a request with the source code file and it will, at its best, tell you if it thinks it contains a vulnerability or not.
Download the full notebook here, the Juliet dataset zip here. Running this experiment either as a notebook or script takes around two minutes:
r@x:~$ time jupyter nbconvert --execute parse-juliet-train-simple-nn.ipynb --ExecutePreprocessor.timeout=-1
[NbConvertApp] Converting notebook parse-juliet-train-simple-nn.ipynb to html
[NbConvertApp] Executing notebook with kernel: python3
...
real 2m8.583s
user 2m31.455s
sys 0m3.339s
r@x:~$ time python3 Downloads/parse-juliet-train-simple-nn.py
/* TEMPLATE GENERATED TESTCASE FILE
Filename: CWE760_Predictable_Salt_One_Way_Hash__basic_06.java
...
real 1m59.448s
user 2m24.944s
sys 0m2.970s
Update, July 22, 2022: At Fluid Attacks, we use AI prioritization in our Secure Code Review solution.