public class TradNaiveBayesClassifier extends Object implements JointClassifier<CharSequence>, ObjectHandler<Classified<CharSequence>>, Serializable, Compilable
TradNaiveBayesClassifier implements a traditional
token-based approach to naive Bayes text classification. It wraps
a tokenization factory to convert character sequences into
sequences of tokens. This implementation supports several
enhancements to simple naive Bayes: priors, length normalization,
and semi-supervised training with EM.
It is the token counts (aka "bag of words") sequence that is actually being classified, not the raw character sequence input. So any character sequences that produce the same bags of tokens are considered equal.
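For illustration, the bag-of-tokens equivalence can be sketched with a plain token-count map (a minimal, framework-free sketch; the whitespace split stands in for whatever tokenizer factory is wrapped):

```java
import java.util.Map;
import java.util.TreeMap;

public class BagOfTokens {

    // Count tokens, discarding order; character sequences that
    // produce equal bags are equivalent inputs to the classifier.
    static Map<String,Integer> bag(String text) {
        Map<String,Integer> counts = new TreeMap<>();
        for (String token : text.split("\\s+"))
            counts.merge(token, 1, Integer::sum);
        return counts;
    }

    public static void main(String[] args) {
        // same bag, different order
        System.out.println(bag("the cat saw the dog").equals(bag("the dog saw the cat"))); // prints true
    }
}
```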
Naive Bayes is trainable online, meaning that it can be given
training instances one at a time, and at any point can be used as a
classifier. Training cases consist of a character sequence and
classification, as dictated by the interface ObjectHandler<Classified<CharSequence>>.
Given a character sequence, a naive Bayes classifier returns
joint probability estimates of categories and tokens; this is
reflected in its implementing the JointClassifier&lt;CharSequence&gt; interface. Note that
this is the joint probability of the token counts, so sums of
probabilities over all input character sequences will exceed 1.0.
Typically, only the conditional probability estimates are used in
practice.
If there is length normalization, the joint probabilities will not sum to 1.0 over all inputs and outputs. The conditional probabilities will always sum to 1.0.
Conditional probabilities are derived by applying Bayes's rule to invert the probability calculation:

p(tokens,cat) = p(tokens|cat) * p(cat)

p(cat|tokens) = p(tokens,cat) / p(tokens)
              = p(tokens|cat) * p(cat) / p(tokens)
The tokens are assumed to be independent (this is the
"naive" step):
p(tokens|cat) = p(tokens[0]|cat) * ... * p(tokens[tokens.length-1]|cat)
              = Π_{i &lt; tokens.length} p(tokens[i]|cat)
Finally, an explicit marginalization allows us to compute the
marginal distribution of tokens:
p(tokens) = Σ_{cat'} p(tokens,cat')
          = Σ_{cat'} p(tokens|cat') * p(cat')
Together, these equations define p(cat|tokens) in terms of two distributions: the conditional
probability of a token given a category, p(token|cat), and the
marginal probability of a category, p(cat) (sometimes called
the category's prior probability, though this shouldn't be confused
with the usual Bayesian prior on model parameters).
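The inversion can be traced numerically in a small, self-contained sketch; the two categories ("spam"/"ham"), the token probabilities, and the priors below are invented purely for the example:

```java
import java.util.Map;

public class BayesInversion {

    // Invented per-category token probabilities p(token|cat) and priors p(cat).
    static final Map<String,Double> P_TOK_SPAM = Map.of("win", 0.3, "cash", 0.4, "now", 0.2);
    static final Map<String,Double> P_TOK_HAM  = Map.of("win", 0.05, "cash", 0.05, "now", 0.1);
    static final double P_SPAM = 0.4, P_HAM = 0.6;

    // p(tokens,cat) = p(cat) * Π_i p(tokens[i]|cat)   (naive independence)
    static double joint(String[] tokens, Map<String,Double> pTok, double pCat) {
        double p = pCat;
        for (String t : tokens)
            p *= pTok.get(t);
        return p;
    }

    // Bayes's rule: p(spam|tokens) = p(tokens,spam) / Σ_cat' p(tokens,cat')
    static double condProbSpam(String[] tokens) {
        double jointSpam = joint(tokens, P_TOK_SPAM, P_SPAM);
        double jointHam = joint(tokens, P_TOK_HAM, P_HAM);
        return jointSpam / (jointSpam + jointHam);
    }

    public static void main(String[] args) {
        System.out.println(condProbSpam(new String[] {"win", "cash", "now"}));
    }
}
```

Note how the marginal p(tokens) appears only as the normalizing sum in the denominator, exactly as in the derivation above.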
Traditional naive Bayes uses a maximum a posteriori (MAP)
estimate of the multinomial distributions: p(cat) over the
set of categories, and for each category cat, the
multinomial distribution p(token|cat) over the set of tokens.
Traditional naive Bayes employs the Dirichlet conjugate prior for
multinomials, which is straightforward to compute by adding a fixed
"prior count" to each count in the training data. This lends the
approach its traditional name, "additive smoothing".
Two sets of counts are sufficient for estimating a traditional
naive Bayes classifier. The first is tokenCount(w,c), the
number of times token w appeared as a token in a training
case for category c. The second is caseCount(c),
which is the number of training cases for category c.
We assume prior counts α for the case counts
and β for the token counts. These values are supplied
in the constructor for this class.
The probability estimates p' for categories and tokens
are most easily understood as proportions:

p'(w|c) ∝ tokenCount(w,c) + β
p'(c) ∝ caseCount(c) + α

The estimates themselves are obtained through the usual normalization:

p'(w|c) = ( tokenCount(w,c) + β ) / Σ_w ( tokenCount(w,c) + β )
p'(c) = ( caseCount(c) + α ) / Σ_c ( caseCount(c) + α )
Although not traditionally used for naive Bayes, maximum
likelihood estimates arise from setting the prior counts equal to
zero (α = β = 0). The prior counts drop
out of the equations to yield the maximum likelihood estimates
p*:
p*(w|c) = tokenCount(w,c) / Σ_w tokenCount(w,c)
p*(c) = caseCount(c) / Σ_c caseCount(c)
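The two estimators can be compared in a small sketch (the token counts and the prior count β are invented for the example; β = 0 recovers the maximum likelihood estimate, and β &gt; 0 gives unseen tokens nonzero probability):

```java
public class AdditiveSmoothing {

    // p'(w|c) = ( tokenCount(w,c) + beta ) / Σ_w ( tokenCount(w,c) + beta )
    static double estimate(double count, double[] allCounts, double beta) {
        double total = 0.0;
        for (double c : allCounts)
            total += c + beta;
        return (count + beta) / total;
    }

    public static void main(String[] args) {
        double[] counts = {3.0, 1.0, 0.0}; // counts for tokens w0, w1, w2 in category c
        System.out.println(estimate(3.0, counts, 0.0)); // MLE: 3/4 = 0.75
        System.out.println(estimate(3.0, counts, 0.5)); // MAP: 3.5/5.5
        System.out.println(estimate(0.0, counts, 0.5)); // unseen token w2 still gets mass
    }
}
```

Note that the smoothed estimate never assigns an unseen token zero probability, which keeps products of token probabilities from collapsing to zero.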
Unlike traditional naive Bayes implementations, this class allows weighted training, including training directly from a conditional classification. When training using a conditional classification, each category is weighted according to its conditional probability.
Weights may be negative, allowing counts to be decremented (e.g. for Gibbs sampling).
Because of the (almost always faulty) assumption that tokens are
independent, which underlies the naive Bayes classifier, the conditional
probability estimates tend toward either 0.0 or 1.0 as the input
grows longer. In practice, it sometimes helps to length normalize
the documents. That is, each document is treated as if it were a fixed
number of tokens long, lengthNorm.
Length normalization can be computed directly on the linear scale:

pn(tokens|cat) = p(tokens|cat)^(lengthNorm/tokens.length)

but is more easily understood on the log scale, where we multiply the length norm by the log probability normalized per token:

log2 pn(tokens|cat) = lengthNorm * log2 p(tokens|cat) / tokens.length

The length normalization parameter is supplied in the constructor, with a Double.NaN value indicating
that no length normalization should be done.
Length normalization will be applied during training, too. Length normalization may be changed using the set method. For instance, this allows training to skip length normalization and classification to use length normalization.
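A minimal sketch of the log-scale computation (the log probability, token count, and lengthNorm values are invented; Double.NaN disables normalization, matching the constructor contract described above):

```java
public class LengthNormSketch {

    // log2 pn(tokens|cat) = lengthNorm * log2 p(tokens|cat) / numTokens,
    // with a NaN lengthNorm meaning no normalization.
    static double log2LengthNorm(double log2Prob, int numTokens, double lengthNorm) {
        if (Double.isNaN(lengthNorm))
            return log2Prob;
        return lengthNorm * log2Prob / numTokens;
    }

    public static void main(String[] args) {
        // a 200-token document scored as if it were 10 tokens long
        System.out.println(log2LengthNorm(-400.0, 200, 10.0));       // prints -20.0
        System.out.println(log2LengthNorm(-400.0, 200, Double.NaN)); // prints -400.0
    }
}
```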
EM is controlled by epoch. Each epoch consists of an expectation (E) step, followed by a maximization (M) step. The expectation step computes expectations which are then fed in as training weights to the maximization step.
The version of EM implemented in this class allows a mixture of supervised and unsupervised data.
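The E and M steps can be illustrated with a toy, framework-free sketch: the E step takes the previous classifier's conditional estimates p(c|x) for an unlabeled item (invented values below), and the M step trains on every category with the estimate as a weight, shown here as weighted token counts:

```java
import java.util.HashMap;
import java.util.Map;

public class EmWeighting {

    // M-step contribution of one unlabeled item for one category:
    // each of the item's tokens is counted with weight p(c|x).
    static Map<String,Double> weightedCounts(String[] tokens, double weight) {
        Map<String,Double> counts = new HashMap<>();
        for (String token : tokens)
            counts.merge(token, weight, Double::sum);
        return counts;
    }

    public static void main(String[] args) {
        String[] x = {"cheap", "cheap", "deal"};
        // E step: p(c|x) from the previous classifier (invented for the example)
        double pSpamGivenX = 0.9, pHamGivenX = 0.1;
        // M step: train each category on x, weighted by p(c|x)
        System.out.println(weightedCounts(x, pSpamGivenX).get("cheap")); // prints 1.8
        System.out.println(weightedCounts(x, pHamGivenX).get("deal"));   // prints 0.1
    }
}
```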
The supervised training data is in the form of a corpus of classifications, implementing Corpus&lt;ObjectHandler&lt;Classified&lt;CharSequence&gt;&gt;&gt;. Unsupervised data is in the form of a corpus of texts, implementing Corpus&lt;ObjectHandler&lt;CharSequence&gt;&gt;. The method also requires a factory with which to produce a new classifier in each epoch, namely an implementation of Factory&lt;TradNaiveBayesClassifier&gt;, as well as an initial classifier, which may be different than the classifiers generated by the factory.

EM works by iteratively training better and better classifiers, using the previous classifier to label the unlabeled data for training. Note that in each round, the new classifier is also trained on the supervised items. In pseudocode, the algorithm is:

set lastClassifier to initialClassifier
for (epoch = 0; epoch < maxEpochs; ++epoch) {
    create classifier using factory
    train classifier on supervised items
    for (x in unsupervised items) {
        compute p(c|x) with lastClassifier
        for (c in category)
            train classifier on c weighted by p(c|x)
    }
    evaluate corpus and model probability under classifier
    set lastClassifier to classifier
    break if converged
}
return lastClassifier

In general, we have found that EM training works best if the
initial classifier does more smoothing than the classifiers
returned by the factory.

Annealing, of a sort, may be built in by having the factory
return a sequence of classifiers with ever longer length
normalizations and/or lower prior counts, both of which attenuate
the posterior predictions of a naive Bayes classifier. With a
short length normalization, classifications are driven closer to
uniform; with longer length normalizations, they are more peaked.

Unsupervised Learning and EM Soft Clustering

It is possible to train a classifier in a completely
unsupervised fashion by having the initial classifier assign
categories at random. Only the number of categories must be fixed.
The algorithm is exactly the same, and the result after
convergence or the maximum number of epochs is a classifier.

Now take the trained classifier and run it over the texts in the
unsupervised text corpus. This will assign probabilities of each
text belonging to each of the categories. This is known as a soft
clustering, and the algorithm overall is known as EM clustering.
If we assign each item to its most likely category, the result
is then a hard clustering.

Serialization and Compilation

A naive Bayes classifier may be serialized. The object read
back in will behave just as the naive Bayes classifier that was
serialized. The tokenizer factory must be serializable in order
to serialize the classifier.

A naive Bayes classifier may be compiled. In order to be
compiled, the tokenizer factory must be either serializable or
compilable. The object read back in will implement
ConditionalClassifier&lt;CharSequence&gt; if the compiled classifier is
binary (i.e., has exactly two categories) and JointClassifier&lt;CharSequence&gt; if the compiled classifier has more
than two categories. The ability to compute joint probabilities in
the binary case is lost due to an optimization in the compiler, so
the resulting class only implements the conditional classifier
interface. A compiled classifier may not be trained. A compiled
classifier is completely thread safe.

Comparison to NaiveBayesClassifier

The naive Bayes classifier implemented in NaiveBayesClassifier differs from this version in smoothing the
token estimates with character language model estimates.
Thread Safety
A TradNaiveBayesClassifier must be synchronized externally
using read/write synchronization (e.g., with a ReadWriteLock). The write methods
include handle(Classified), train(CharSequence,Classification,double), trainConditional(CharSequence,ConditionalClassification,double,double),
and setLengthNorm(double). All other methods are read
methods.
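The read/write discipline might be wrapped as follows (a sketch using java.util.concurrent.locks.ReentrantReadWriteLock; the GuardedModel class and its counter are hypothetical stand-ins for a classifier, to keep the sketch dependency-free):

```java
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class GuardedModel {

    private final ReadWriteLock lock = new ReentrantReadWriteLock();
    private long trainingCases = 0; // stand-in for mutable classifier state

    // Write method (like handle/train/setLengthNorm): exclusive lock.
    public void train() {
        lock.writeLock().lock();
        try {
            trainingCases++;
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Read method (like classify/probCat): shared lock, so concurrent
    // reads proceed in parallel while writes are exclusive.
    public long caseCount() {
        lock.readLock().lock();
        try {
            return trainingCases;
        } finally {
            lock.readLock().unlock();
        }
    }

    public static void main(String[] args) {
        GuardedModel m = new GuardedModel();
        m.train();
        System.out.println(m.caseCount()); // prints 1
    }
}
```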
| Constructor and Description |
|---|
TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory)
Constructs a naive Bayes classifier over the specified
categories, using the specified tokenizer factory.
|
TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory,
double categoryPrior,
double tokenInCategoryPrior,
double lengthNorm)
Constructs a naive Bayes classifier over the specified
categories, using the specified tokenizer factory, priors and
length normalization.
|
| Modifier and Type | Method and Description |
|---|---|
Set<String> |
categorySet()
Returns a set of categories for this classifier.
|
JointClassification |
classify(CharSequence in)
Return the classification of the specified character sequence.
|
void |
compileTo(ObjectOutput out)
Compile this classifier to the specified object output.
|
static Iterator<TradNaiveBayesClassifier> |
emIterator(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount)
Apply the expectation maximization (EM) algorithm to train a traditional
naive Bayes classifier using the specified labeled and unlabeled data,
initial classifier, and factory for creating subsequent classifiers.
|
static TradNaiveBayesClassifier |
emTrain(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount,
int maxEpochs,
double minImprovement,
Reporter reporter)
Apply the expectation maximization (EM) algorithm to train a traditional
naive Bayes classifier using the specified labeled and unlabeled data,
initial classifier, factory for creating subsequent classifiers,
maximum number of epochs, minimum improvement per epoch, and reporter
to which progress reports are sent.
|
void |
handle(Classified<CharSequence> classifiedObject)
Trains the classifier with the specified classified character
sequence.
|
boolean |
isKnownToken(String token)
Returns
true if the token has been seen in
training data. |
Set<String> |
knownTokenSet()
Returns an unmodifiable view of the set of tokens.
|
double |
lengthNorm()
Returns the length normalization factor for this
classifier.
|
double |
log2CaseProb(CharSequence input)
Returns the log (base 2) marginal probability of the specified
input.
|
double |
log2ModelProb()
Returns the log (base 2) of the probability density of this
model in the Dirichlet prior specified by this classifier.
|
double |
probCat(String cat)
Returns the probability estimate for the specified
category.
|
double |
probToken(String token,
String cat)
Returns the probability of the specified token
in the specified category.
|
void |
setLengthNorm(double lengthNorm)
Set the length normalization factor to the specified value.
|
String |
toString()
Return a string representation of this classifier.
|
void |
train(CharSequence cSeq,
Classification classification,
double count)
Trains the classifier with the specified case consisting of
a character sequence and conditional classification with the
specified count.
|
void |
trainConditional(CharSequence cSeq,
ConditionalClassification classification,
double countMultiplier,
double minCount)
Trains this classifier using tokens extracted from the
specified character sequence, using category count multipliers
derived by multiplying the specified count multiplier by the
conditional probability of a category in the specified
classification.
|
public TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory)
Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory, with no length normalization (a lengthNorm of Double.NaN).
See the class documentation above for more information.
Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
Throws:
IllegalArgumentException - If there are fewer than two categories.

public TradNaiveBayesClassifier(Set&lt;String&gt; categorySet, TokenizerFactory tokenizerFactory, double categoryPrior, double tokenInCategoryPrior, double lengthNorm)

Constructs a naive Bayes classifier over the specified categories, using the specified tokenizer factory, priors, and length normalization.

Parameters:
categorySet - Categories for classification.
tokenizerFactory - Factory to convert char sequences to tokens.
categoryPrior - Prior count for categories.
tokenInCategoryPrior - Prior count for tokens per category.
lengthNorm - A positive, finite length norm, or Double.NaN if no length normalization is to be done.
Throws:
IllegalArgumentException - If either prior is negative or not finite, if there are fewer than two categories, or if the length normalization constant is negative, zero, or infinite.

public String toString()

Return a string representation of this classifier.
public Set&lt;String&gt; categorySet()

Returns a set of categories for this classifier.
public void setLengthNorm(double lengthNorm)
Set the length normalization factor to the specified value.

Parameters:
lengthNorm - Length normalization or Double.NaN to turn off normalization.
Throws:
IllegalArgumentException - If the length norm is infinite, zero, or negative.

public JointClassification classify(CharSequence in)

Return the classification of the specified character sequence.

Specified by:
classify in interface BaseClassifier&lt;CharSequence&gt;
classify in interface ConditionalClassifier&lt;CharSequence&gt;
classify in interface JointClassifier&lt;CharSequence&gt;
classify in interface RankedClassifier&lt;CharSequence&gt;
classify in interface ScoredClassifier&lt;CharSequence&gt;
Parameters:
in - Character sequence being classified.

public double lengthNorm()

Returns the length normalization factor for this classifier.
public boolean isKnownToken(String token)
Returns true if the token has been seen in training data.

Parameters:
token - Token to test.
Returns:
true if the token has been seen in training data.

public Set&lt;String&gt; knownTokenSet()

Returns an unmodifiable view of the set of tokens.
public double probToken(String token, String cat)
Returns the probability of the specified token in the specified category.

Throws:
IllegalArgumentException - If the category is not known or the token is not known.

public void compileTo(ObjectOutput out) throws IOException

Compile this classifier to the specified object output.

Specified by:
compileTo in interface Compilable
Parameters:
out - Object output to which this classifier is compiled.
Throws:
IOException - If there is an underlying I/O error during the write.

public double probCat(String cat)

Returns the probability estimate for the specified category.

Parameters:
cat - Category whose probability is returned.
Throws:
IllegalArgumentException - If the category is not known.

public void handle(Classified&lt;CharSequence&gt; classifiedObject)
Trains the classifier with the specified classified character sequence. See also trainConditional(CharSequence,ConditionalClassification,double,double).

Specified by:
handle in interface ObjectHandler&lt;Classified&lt;CharSequence&gt;&gt;
Parameters:
classifiedObject - Classified character sequence.

public void trainConditional(CharSequence cSeq, ConditionalClassification classification, double countMultiplier, double minCount)

Trains this classifier using tokens extracted from the specified character sequence, using category count multipliers derived by multiplying the specified count multiplier by the conditional probability of a category in the specified classification.

Parameters:
cSeq - Character sequence being trained.
classification - Conditional classification to train.
countMultiplier - Count multiplier of training instance.
minCount - Minimum count for which a category is trained for this character sequence.
Throws:
IllegalArgumentException - If the count multiplier is not finite or is negative, or if the min count is below zero or not a number.

public void train(CharSequence cSeq, Classification classification, double count)

Trains the classifier with the specified case consisting of a character sequence and classification with the specified count.
If the count value is negative, counts are subtracted rather than added. If any of the counts fall below zero, an illegal argument exception will be thrown and the classifier will be reverted to the counts in place before the method was called. Cleanup after errors requires the tokenizer factory to return the same tokenizer given the same string, but no check is made that it does.
Parameters:
cSeq - Character sequence on which to train.
classification - Classification to train with the character sequence.
count - How many instances the classification will count as for training purposes.
Throws:
IllegalArgumentException - If the count is negative and increments cause accumulated counts to fall below zero.

public double log2CaseProb(CharSequence input)

Returns the log (base 2) marginal probability of the specified input, obtained by summing the joint probabilities over all categories:
p(x) = Σ_{c in cats} p(c,x)

Note that this value is normalized by the number of tokens in the input, so that

Σ_{x : length(x) = n} p(x) = 1.0
Parameters:
input - Input character sequence.

public double log2ModelProb()

Returns the log (base 2) of the probability density of this model in the Dirichlet prior specified by this classifier.
The result is the sum of the log density of the multinomial over categories and the log density of the per-category multinomials over tokens.
For a definition of the probability function for each
category's multinomial and the overall category multinomial,
see Statistics.dirichletLog2Prob(double,double[]).
public static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount) throws IOException
This method lets the client take control over assessing convergence, so there are no convergence-related arguments.
Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
Throws:
IOException

public static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier initialClassifier, Factory&lt;TradNaiveBayesClassifier&gt; classifierFactory, Corpus&lt;ObjectHandler&lt;Classified&lt;CharSequence&gt;&gt;&gt; labeledData, Corpus&lt;ObjectHandler&lt;CharSequence&gt;&gt; unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter) throws IOException
Parameters:
initialClassifier - Initial classifier to bootstrap.
classifierFactory - Factory for creating subsequent classifiers.
labeledData - Labeled data for supervised training.
unlabeledData - Unlabeled data for unsupervised training.
minTokenCount - Min count for a word to not be pruned.
maxEpochs - Maximum number of epochs to run training.
minImprovement - Minimum relative improvement per epoch.
reporter - Reporter to which intermediate results are reported, or null for no reporting.
Throws:
IOException

Copyright © 2019 Alias-i, Inc. All rights reserved.