public class CategoricalDistribution extends Object
| Constructor and Description |
|---|
| CategoricalDistribution(org.nd4j.autodiff.samediff.SDVariable logits, org.nd4j.autodiff.samediff.SameDiff sd) Construct the categorical distribution. |
| Modifier and Type | Method and Description |
|---|---|
| org.nd4j.autodiff.samediff.SDVariable | getProb() |
| org.nd4j.autodiff.samediff.SDVariable | klDivergence(io.siddhi.extension.execution.streamingml.bayesian.model.Distribution distribution) Returns the KL divergence with respect to the given distribution. |
| org.nd4j.autodiff.samediff.SDVariable | logProbability(org.nd4j.autodiff.samediff.SDVariable values) Computes the categorical log probability based on softmax cross-entropy, using the formula given in the method detail. |
| org.nd4j.autodiff.samediff.SDVariable | sample() Returns a random sample from the distribution. |
| org.nd4j.autodiff.samediff.SDVariable | sample(int n) Returns random samples from the distribution. |
public CategoricalDistribution(org.nd4j.autodiff.samediff.SDVariable logits,
                               org.nd4j.autodiff.samediff.SameDiff sd)
logits - should be 2-dimensional. The dimensions should follow the order (input_size, num_classes).
sd - SameDiff context

public org.nd4j.autodiff.samediff.SDVariable logProbability(org.nd4j.autodiff.samediff.SDVariable values)
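The log probability is computed as

$$\log p(y) = \sum_{c=1}^{\text{num\_classes}} y_c \, \log \operatorname{softmax}(\text{logits})_c$$

However, log(softmax(logits)) evaluates to negative infinity whenever a class probability underflows to zero, and multiplying that value by a zero label yields NaN. Hence, unnecessary log computations are avoided by applying the labels before taking the logarithm:

$$\log p(y) = \log \sum_{c=1}^{\text{num\_classes}} \operatorname{softmax}(\text{logits})_c \, y_c$$

The two formulas produce the same output when y is one-hot encoded, because both reduce to the log probability of the labelled class.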
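To illustrate the two log-probability formulas above, the following plain-Java sketch evaluates both on double arrays rather than SameDiff variables, reading the transformed formula as the log of the label-masked sum. The class and method names (CategoricalLogProbSketch, logProbDirect, logProbMasked) are hypothetical and not part of this API. Each logits row follows the (input_size, num_classes) layout expected by the constructor.

```java
import java.util.Arrays;

// Hypothetical sketch on plain arrays; it does not use ND4J or the class documented here.
public class CategoricalLogProbSketch {

    // Row-wise softmax, shifted by the row maximum for numerical stability.
    public static double[] softmax(double[] row) {
        double max = Arrays.stream(row).max().orElse(0.0);
        double[] out = new double[row.length];
        double sum = 0.0;
        for (int c = 0; c < row.length; c++) {
            out[c] = Math.exp(row[c] - max);
            sum += out[c];
        }
        for (int c = 0; c < row.length; c++) {
            out[c] /= sum;
        }
        return out;
    }

    // Direct formula: sum_c y_c * log(softmax_c). Takes a log of every class.
    public static double logProbDirect(double[] logits, double[] y) {
        double[] p = softmax(logits);
        double total = 0.0;
        for (int c = 0; c < p.length; c++) {
            total += Math.log(p[c]) * y[c];
        }
        return total;
    }

    // Masked formula: log(sum_c softmax_c * y_c). The log is taken once, after masking.
    public static double logProbMasked(double[] logits, double[] y) {
        double[] p = softmax(logits);
        double masked = 0.0;
        for (int c = 0; c < p.length; c++) {
            masked += p[c] * y[c];
        }
        return Math.log(masked);
    }

    public static void main(String[] args) {
        // Two rows, i.e. input_size = 2 and num_classes = 3.
        double[][] logits = {{2.0, 1.0, 0.1}, {0.0, 1000.0, 0.0}};
        double[][] labels = {{1, 0, 0}, {0, 1, 0}};
        for (int i = 0; i < logits.length; i++) {
            System.out.println(logProbDirect(logits[i], labels[i]) + " vs " + logProbMasked(logits[i], labels[i]));
        }
    }
}
```

For the first row both formulas agree. For the second row the output is NaN vs 0.0: exp(-1000) underflows to zero, so the direct formula multiplies log(0) = -Infinity by a zero label, while the masked formula only takes the log of the labelled class's probability.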
values - one-hot encoded labels

public org.nd4j.autodiff.samediff.SDVariable sample()
public org.nd4j.autodiff.samediff.SDVariable sample(int n)
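The sketch below shows one common way to draw categorical samples: inverse-CDF sampling over the softmax probabilities, using only java.util.Random. It is not necessarily how sample() or sample(int n) are implemented here, and this documentation does not state whether those methods return class indices or one-hot vectors; the sketch returns indices. All names (CategoricalSamplingSketch, sampleOnce, sample) are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch; not the SameDiff-based implementation documented here.
public class CategoricalSamplingSketch {

    // Converts one row of logits to probabilities with a max-shifted softmax.
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int c = 0; c < logits.length; c++) {
            p[c] = Math.exp(logits[c] - max);
            sum += p[c];
        }
        for (int c = 0; c < p.length; c++) p[c] /= sum;
        return p;
    }

    // Draws one class index: walk the cumulative probabilities until they exceed
    // a uniform draw in [0, 1).
    public static int sampleOnce(double[] probs, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0.0;
        for (int c = 0; c < probs.length; c++) {
            cumulative += probs[c];
            if (u < cumulative) return c;
        }
        // Guards against rounding leaving the cumulative sum slightly below 1.
        return probs.length - 1;
    }

    // Draws n independent class indices, playing the role of sample(int n).
    public static int[] sample(double[] logits, int n, Random rng) {
        double[] probs = softmax(logits);
        int[] out = new int[n];
        for (int i = 0; i < n; i++) out[i] = sampleOnce(probs, rng);
        return out;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int[] draws = sample(new double[]{2.0, 1.0, 0.1}, 10_000, rng);
        int[] counts = new int[3];
        for (int d : draws) counts[d]++;
        // Counts should be roughly proportional to softmax([2.0, 1.0, 0.1]) ≈ [0.66, 0.24, 0.10].
        System.out.println(Arrays.toString(counts));
    }
}
```

A single draw, as in sample(), corresponds to calling sampleOnce once on the probabilities.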
n - number of samples

public org.nd4j.autodiff.samediff.SDVariable klDivergence(io.siddhi.extension.execution.streamingml.bayesian.model.Distribution distribution)
throws io.siddhi.core.exception.SiddhiAppCreationException
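A minimal sketch of the KL divergence between two categorical distributions, assuming the divergence is taken as KL(q || p) with this distribution as q and the reference distribution p(x) as p, and that both are parameterised by logits over the same classes. It works on plain double arrays rather than SDVariable, and the names (CategoricalKlSketch, logSoftmax) are hypothetical. It throws IllegalArgumentException on mismatched class counts; the conditions under which klDivergence throws SiddhiAppCreationException are not documented here.

```java
// Hypothetical sketch; assumes KL(q || p) with q = this distribution, p = reference.
public class CategoricalKlSketch {

    // Numerically stable log-softmax: l_c - logsumexp(l).
    public static double[] logSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double sumExp = 0.0;
        for (double v : logits) sumExp += Math.exp(v - max);
        double logZ = max + Math.log(sumExp);
        double[] out = new double[logits.length];
        for (int c = 0; c < logits.length; c++) out[c] = logits[c] - logZ;
        return out;
    }

    // KL(q || p) = sum_c q_c * (log q_c - log p_c), computed from logits in log space.
    public static double klDivergence(double[] qLogits, double[] pLogits) {
        if (qLogits.length != pLogits.length) {
            throw new IllegalArgumentException("both distributions need the same number of classes");
        }
        double[] logQ = logSoftmax(qLogits);
        double[] logP = logSoftmax(pLogits);
        double kl = 0.0;
        for (int c = 0; c < logQ.length; c++) {
            kl += Math.exp(logQ[c]) * (logQ[c] - logP[c]);
        }
        return kl;
    }

    public static void main(String[] args) {
        double[] q = {2.0, 1.0, 0.1};
        double[] uniform = {0.0, 0.0, 0.0}; // equal logits give a uniform reference distribution
        System.out.println(klDivergence(q, q));       // identical distributions
        System.out.println(klDivergence(q, uniform)); // positive, since q is not uniform
    }
}
```

Working in log space through log-softmax avoids the log(0) problem discussed for logProbability. For a uniform reference the result equals log(num_classes) minus the entropy of q.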
distribution - reference distribution p(x)
io.siddhi.core.exception.SiddhiAppCreationException

public org.nd4j.autodiff.samediff.SDVariable getProb()
Copyright © 2019 WSO2. All rights reserved.