How to Build a Spam Filter


We're going to learn how to build a spam filter using some basic probability.

Bayes' Rule

Let's start with a little probability. If you already understand Bayes' rule, you can skip to the next section.

$P(A | B)$ means "the probability of A given B." Let's look at an example to figure out how to calculate this.

$P(Toothache)$ is the probability that you have a toothache, and $P(Cavity)$ is the probability that you have a cavity. Imagine these as two overlapping regions (a Venn diagram) of all possible outcomes.

Now, what is $P(Toothache | Cavity)$? It's the probability that you have a toothache given that you have a cavity. The phrase "given that you have a cavity" means that we want to restrict the domain to the $Cavity$ region.

We can then see that "the probability that you have a toothache given that you have a cavity" can be represented as $P(Toothache | Cavity) = \frac{P(Toothache \cap Cavity)}{P(Cavity)}$.

This can be generalized to $P(A | B) = \frac{P(A \cap B)}{P(B)}$.

After we multiply both sides of the equation by the denominator, we get $P(A | B) \cdot P(B) = P(A \cap B)$. This also means that $P(B | A) \cdot P(A) = P(B \cap A)$ (replacing A with B and B with A).

Since $P(A \cap B) = P(B \cap A)$, we can combine the previous two equations to get $P(A | B) \cdot P(B) = P(B | A) \cdot P(A)$.

After dividing both sides by $P(B)$, we end up with $P(A | B) = \frac{P(B | A) \cdot P(A)}{P(B)}$. This equation is known as Bayes' Rule. This rule is going to be an extremely important part of our spam filter.
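As a quick sanity check, here's the toothache/cavity example with made-up numbers (these probabilities are purely illustrative, not real dental statistics):

```python
# Hypothetical probabilities for the toothache/cavity example
p_cavity = 0.2
p_toothache = 0.25
p_toothache_given_cavity = 0.6

# Bayes' Rule: P(Cavity | Toothache) = P(Toothache | Cavity) * P(Cavity) / P(Toothache)
p_cavity_given_toothache = p_toothache_given_cavity * p_cavity / p_toothache
print(p_cavity_given_toothache)  # ≈ 0.48
```

Notice that $P(Cavity | Toothache)$ and $P(Toothache | Cavity)$ are different quantities; Bayes' Rule is what lets us convert between them.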

Spam Filter

The goal of our spam filter is to figure out whether an email is spam or not. More formally, it marks an email as spam if $P(Spam | Email) > P(Not Spam | Email)$. Let's apply Bayes' rule and simplify.

$$\frac{P(Email | Spam) \cdot P(Spam)}{P(Email)} > \frac{P(Email | Not Spam) \cdot P(Not Spam)}{P(Email)}$$

Since $P(Email)$ appears in both denominators, we can cancel it:

$$P(Email | Spam) \cdot P(Spam) > P(Email | Not Spam) \cdot P(Not Spam)$$

$P(Spam)$ = The probability that an email is spam if we know nothing else about it.

$P(Not Spam)$ = The probability that an email is not spam if we know nothing else about it.

$P(Email | Spam)$ = Assuming a classification of spam, the probability of this email being generated.

$P(Email | Not Spam)$ = Assuming a classification of not spam, the probability of this email being generated.

Based on the definitions above, $P(Spam) = \frac{Count(Spam)}{Count(Spam) + Count(Not Spam)}$ (the number of spam emails we've seen divided by the total number of emails we've seen). Similarly for "Not Spam".

Let's formalize the notion of an $Email$. Since we're building a simple spam filter, we'll just focus on the words in the body of the email. We are also going to make an assumption that every word in the email is independent of every other word in the email. Therefore, we can say that $P(Email | C) = \prod_{w \in W} P(w | C)$ where $W$ is the set of words in the email and $C$ is the class ($Spam$ or $Not Spam$). Basically, this is saying that the probability of an $Email$ given that it is in class $C$ is the product of $P(w | C)$ for each word in the email.

Okay, so what does $P(w | Spam)$ mean? Intuitively, it's how likely the word $w$ is to appear in a $Spam$ email. Therefore $P(w | Spam) = \frac{Count_{Spam}(w)}{\sum_{x \in Spam} Count_{Spam}(x)}$ (the number of times $w$ shows up in spam emails divided by the total number of words in all the spam emails). Similarly for $Not Spam$.
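Here's a toy illustration of this count-based estimate (the emails and words are made up):

```python
from collections import Counter

# Toy training set: each spam email is just a list of words
spam_emails = [["win", "money", "now"], ["win", "a", "prize"]]

# Count every word across all spam emails
spam_word_counts = Counter(word for email in spam_emails for word in email)
total_spam_words = sum(spam_word_counts.values())  # 6 words total

# P(win | Spam) = 2 occurrences of "win" / 6 total words
p_win_given_spam = spam_word_counts["win"] / total_spam_words
print(p_win_given_spam)  # ≈ 0.333
```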

Robin Hood

What if we want to find $P(w | Spam)$ for a word we haven't seen in our training data? Using the current formula, it would be zero. Since we are multiplying the probabilities of all the words in an email, a single word with probability zero would make the whole product zero, which is not ideal. We need some sort of smoothing. Basically, we want to take a little from the "rich" words (words that are used a lot) and give it to the "poor" words (words that are rarely or never used). To do this, we modify the above formula to the following:

$$P(w | Spam) = \frac{Count_{Spam}(w) + 1}{\sum_{x \in Spam} Count_{Spam}(x) + VocabularySize}$$

where $VocabularySize$ is the number of unique words we've seen in our training data.
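Continuing the toy example, here's how the smoothed estimate gives a nonzero probability to a word that never appears in spam (the corpus and the vocabulary size of 8 are made up):

```python
from collections import Counter

spam_emails = [["win", "money", "now"], ["win", "a", "prize"]]
spam_word_counts = Counter(word for email in spam_emails for word in email)
total_spam_words = sum(spam_word_counts.values())  # 6
vocabulary_size = 8  # pretend we've seen 8 unique words across all training data

def p_word_given_spam(word):
    # Add 1 to the count and the vocabulary size to the denominator
    # (a Counter returns 0 for words it has never seen)
    return (spam_word_counts[word] + 1) / (total_spam_words + vocabulary_size)

print(p_word_given_spam("win"))    # (2 + 1) / (6 + 8) ≈ 0.214
print(p_word_given_spam("hello"))  # (0 + 1) / (6 + 8) ≈ 0.071 -- nonzero!
```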

Great! So now we just need to count the words in our training data and we're done! First, let's combine some of the logic above into one formula:

$$P(Spam) \cdot \prod_{w \in W} P(w | Spam) > P(Not Spam) \cdot \prod_{w \in W} P(w | Not Spam)$$

Underflow

Since we're multiplying lots of small numbers, we could run into floating-point underflow. To fix this, let's use the following rules:

If $x > y$ then $\log(x) > \log(y)$

$\log(a \cdot b) = \log(a) + \log(b)$

We can use them to transform

$$P(Spam) \cdot \prod_{w \in W} P(w | Spam) > P(Not Spam) \cdot \prod_{w \in W} P(w | Not Spam)$$

into

$$\log(P(Spam)) + \sum_{w \in W} \log(P(w | Spam)) > \log(P(Not Spam)) + \sum_{w \in W} \log(P(w | Not Spam))$$
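To see why this matters, multiplying a few hundred small probabilities underflows to zero, while summing their logs stays well-behaved:

```python
import math

probs = [1e-5] * 100  # 100 words, each with probability 1e-5

# Naive product: the true value is 1e-500, far below the smallest
# positive double (~1e-308), so the result underflows to 0.0
product = 1.0
for p in probs:
    product *= p
print(product)  # 0.0 -- underflowed

# Sum of logs: same comparison, no underflow
log_sum = sum(math.log(p) for p in probs)
print(log_sum)  # ≈ -1151.29
```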

Bringing it all together

Plug in our results from above to get this formula:

$$\log\left(\frac{Count(Spam)}{Count(Spam) + Count(Not Spam)}\right) + \sum_{w \in W} \log\left(\frac{Count_{Spam}(w) + 1}{\sum_{x \in Spam} Count_{Spam}(x) + VocabularySize}\right)$$

$$>$$

$$\log\left(\frac{Count(Not Spam)}{Count(Spam) + Count(Not Spam)}\right) + \sum_{w \in W} \log\left(\frac{Count_{Not Spam}(w) + 1}{\sum_{x \in Not Spam} Count_{Not Spam}(x) + VocabularySize}\right)$$

We're going to start turning this into code. Let's define the following:

vocabularySize // The number of unique words in our training data

spam.totalEmails // The total number of emails marked spam in our training data
spam.totalWords // The total number of words in all the spam emails
spam.wordToCountMap // A mapping from a word to the number of times it shows up in all spam emails

(Similarly for notSpam)

Our formula now looks like this:

$$\log\left(\frac{spam.totalEmails}{spam.totalEmails + notSpam.totalEmails}\right) + \sum_{w \in W} \log\left(\frac{spam.wordToCountMap.get(w) + 1}{spam.totalWords + vocabularySize}\right)$$

$$>$$

$$\log\left(\frac{notSpam.totalEmails}{spam.totalEmails + notSpam.totalEmails}\right) + \sum_{w \in W} \log\left(\frac{notSpam.wordToCountMap.get(w) + 1}{notSpam.totalWords + vocabularySize}\right)$$

where WW is the set of words in the email we want to classify. At this point, it's pretty easy to turn this theory into code. To train the filter, we just need to populate the variables above (totalEmails, totalWords, and wordToCountMap) for each class (Spam and Not Spam) and keep track of our vocabulary. To classify an email, we just need to check if the inequality above is true or false.

Pseudocode

Pseudocode for training and classifying using the spam filter is included below.

vocabulary = new Set()
spam = new FilterClass()
notSpam = new FilterClass()

class FilterClass:
    int totalWords
    int totalEmails
    Map<String, Integer> wordToCountMap

def train(emails):
    for email in emails:
        vocabulary.addAll(email.words)
        currentClass = spam
        if not email.isSpam:
            currentClass = notSpam
        
        currentClass.totalEmails += 1
        currentClass.totalWords += len(email.words)
        for word in email.words:
            if word not in currentClass.wordToCountMap:
                currentClass.wordToCountMap[word] = 0
            currentClass.wordToCountMap[word] += 1
            
def isSpam(email):
    spamProb = log(spam.totalEmails / (spam.totalEmails + notSpam.totalEmails))
    notSpamProb = log(notSpam.totalEmails / (spam.totalEmails + notSpam.totalEmails))
    for word in email.words:
        # Start counts at 1 (the "+1" from our smoothing formula)
        spamCount = 1
        notSpamCount = 1
        
        if word in spam.wordToCountMap:
            spamCount += spam.wordToCountMap[word]
            
        if word in notSpam.wordToCountMap:
            notSpamCount += notSpam.wordToCountMap[word]
            
        spamCount /= float(spam.totalWords + len(vocabulary))    
        notSpamCount /= float(notSpam.totalWords + len(vocabulary))
        
        spamProb += log(spamCount)
        notSpamProb += log(notSpamCount)
        
    return spamProb > notSpamProb

That's it! A basic spam filter.
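The pseudocode above translates almost directly into runnable Python. Here's a minimal sketch; the `Email` structure and the toy training emails are made up for illustration:

```python
import math
from collections import Counter, namedtuple

Email = namedtuple("Email", ["words", "is_spam"])

class FilterClass:
    def __init__(self):
        self.total_emails = 0
        self.total_words = 0
        self.word_to_count = Counter()  # word -> number of occurrences

vocabulary = set()
spam = FilterClass()
not_spam = FilterClass()

def train(emails):
    for email in emails:
        vocabulary.update(email.words)
        current = spam if email.is_spam else not_spam
        current.total_emails += 1
        current.total_words += len(email.words)
        current.word_to_count.update(email.words)

def is_spam(email):
    total = spam.total_emails + not_spam.total_emails
    spam_prob = math.log(spam.total_emails / total)
    not_spam_prob = math.log(not_spam.total_emails / total)
    for word in email.words:
        # Smoothed word likelihoods, accumulated in log space
        spam_prob += math.log((spam.word_to_count[word] + 1) /
                              (spam.total_words + len(vocabulary)))
        not_spam_prob += math.log((not_spam.word_to_count[word] + 1) /
                                  (not_spam.total_words + len(vocabulary)))
    return spam_prob > not_spam_prob

# Toy training data, made up for this example
train([
    Email(["win", "money", "now"], True),
    Email(["claim", "your", "prize", "now"], True),
    Email(["meeting", "at", "noon"], False),
    Email(["lunch", "at", "noon", "tomorrow"], False),
])

print(is_spam(Email(["win", "a", "prize"], None)))    # True
print(is_spam(Email(["meeting", "tomorrow"], None)))  # False
```

Using a `Counter` for the word counts means unseen words simply return 0, so the smoothing's "+1" handles them with no special-casing.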

Accuracy

This method is around 97% accurate on the test and training data that I used. To increase accuracy, we can take other features into account. For example, we may want to look at the sender's email domain or the words in the subject line. I may write a follow-up post to talk about accuracy improvements.

Let me know what you think in the comments below!