public class SectorTagger extends Tagger
| Modifier and Type | Field and Description |
|---|---|
| protected Encoder | bagEncoder |
| protected Encoder | embEncoder |
| protected ModelEvaluation | eval |
| protected org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor | ff2rnn |
| protected Encoder | flagEncoder |
| protected static org.slf4j.Logger | log |
| protected boolean | requireSubsampling |
| protected Encoder | targetEncoder |
| protected int | workers |
batchSize, embeddingLayerSize, embeddingVectorSize, inputVectorSize, maxTimeSeriesLength, net, numEpochs, numExamples, outputVectorSize, randomize
id, model, modelAvailable, name, timer
| Constructor and Description |
|---|
| SectorTagger() — Used by the XML deserializer. |
| SectorTagger(Resource modelPath) |
| SectorTagger(String id) |
| Modifier and Type | Method and Description |
|---|---|
| void | attachVectors(Collection<Document> docs, AbstractMultiDataSetIterator.Stage stage, Class<? extends Encoder> targetClass) |
| protected void | attachVectors(DocumentSentenceIterator.DocumentBatch batch, Class<? extends Encoder> targetClass) |
| SectorTagger | buildSECTORModel(int ffwLayerSize, int lstmLayerSize, int embeddingLayerSize, int iterations, double learningRate, double dropout, org.nd4j.linalg.lossfunctions.ILossFunction lossFunc, org.nd4j.linalg.activations.Activation activation) |
| protected static void | clearLayerStates(org.deeplearning4j.nn.graph.ComputationGraph net) — Clears layer states to avoid leaks. |
| void | enableTrainingUI() |
| Map<String,org.nd4j.linalg.api.ndarray.INDArray> | encodeMatrix(DocumentSentenceIterator.DocumentBatch batch) |
| static Map<String,org.nd4j.linalg.api.ndarray.INDArray> | feedForward(org.deeplearning4j.nn.graph.ComputationGraph net, org.nd4j.linalg.dataset.api.MultiDataSet next) |
| List<Encoder> | getEncoders() |
| org.deeplearning4j.nn.conf.ComputationGraphConfiguration | getGraphConfiguration() |
| org.deeplearning4j.nn.graph.ComputationGraph | getNN() |
| Encoder | getTargetEncoder() |
| boolean | isRequireSubsampling() |
| void | loadModel(Resource modelFile) |
| void | saveModel(Resource modelPath, String name) — Saves the model to modelPath. |
| void | setEncoders(List<Encoder> encoders) |
| void | setGraphConfiguration(org.nd4j.shade.jackson.databind.JsonNode conf) |
| void | setInputEncoders(Encoder bagEncoder, Encoder embEncoder, Encoder flagEncoder) |
| void | setRequireSubsampling(boolean requireSubsampling) |
| void | setTargetEncoder(Encoder targetEncoder) |
| SectorTagger | setWorkspaceParams(int workers) |
| void | tag(Collection<Document> docs) |
| void | testModel(Dataset dataset) |
| void | trainModel(Dataset dataset) |
| org.deeplearning4j.earlystopping.EarlyStoppingResult<org.deeplearning4j.nn.graph.ComputationGraph> | trainModel(Dataset train, Dataset validation, org.deeplearning4j.earlystopping.EarlyStoppingConfiguration conf) |
| void | trainModel(Dataset dataset, int numEpochs) |
| protected void | triggerEpochListeners(boolean epochStart, int epochNum) |
getBatchSize, getEmbeddingLayerSize, getLayerConfiguration, getNumEpochs, isModelAvailableInChildren, isRandomize, loadConf, saveUpdater, setBatchSize, setEmbeddingLayerSize, setLayerConfiguration, setListeners, setNumEpochs, setRandomize, setTrainingParams, tag
appendTestLog, appendTestLog, appendTrainLog, appendTrainLog, clearTestLog, clearTrainLog, getConf, getId, getModel, getName, getTestLog, getTrainLog, isModelAvailable, setConf, setId, setModel, setModelAvailable, setModelFilename, setName
protected static final org.slf4j.Logger log
protected Encoder bagEncoder
protected Encoder embEncoder
protected Encoder flagEncoder
protected Encoder targetEncoder
protected int workers
protected boolean requireSubsampling
protected ModelEvaluation eval
protected final org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor ff2rnn
public SectorTagger()
public SectorTagger(String id)
public SectorTagger(Resource modelPath)
public boolean isRequireSubsampling()
public void setRequireSubsampling(boolean requireSubsampling)
public void setInputEncoders(Encoder bagEncoder, Encoder embEncoder, Encoder flagEncoder)
public void setTargetEncoder(Encoder targetEncoder)
public SectorTagger setWorkspaceParams(int workers)
public List<Encoder> getEncoders()
getEncoders in interface IComponent
getEncoders in class Tagger
public Encoder getTargetEncoder()
public void setEncoders(List<Encoder> encoders)
setEncoders in interface IComponent
setEncoders in class Tagger
public SectorTagger buildSECTORModel(int ffwLayerSize, int lstmLayerSize, int embeddingLayerSize, int iterations, double learningRate, double dropout, org.nd4j.linalg.lossfunctions.ILossFunction lossFunc, org.nd4j.linalg.activations.Activation activation)
public void trainModel(Dataset dataset)
trainModel in class Tagger
public void trainModel(Dataset dataset, int numEpochs)
public org.deeplearning4j.earlystopping.EarlyStoppingResult<org.deeplearning4j.nn.graph.ComputationGraph> trainModel(Dataset train, Dataset validation, org.deeplearning4j.earlystopping.EarlyStoppingConfiguration conf)
public void tag(Collection<Document> docs)
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> encodeMatrix(DocumentSentenceIterator.DocumentBatch batch)
public static Map<String,org.nd4j.linalg.api.ndarray.INDArray> feedForward(org.deeplearning4j.nn.graph.ComputationGraph net, org.nd4j.linalg.dataset.api.MultiDataSet next)
protected void triggerEpochListeners(boolean epochStart, int epochNum)
public void attachVectors(Collection<Document> docs, AbstractMultiDataSetIterator.Stage stage, Class<? extends Encoder> targetClass)
protected void attachVectors(DocumentSentenceIterator.DocumentBatch batch, Class<? extends Encoder> targetClass)
protected static void clearLayerStates(org.deeplearning4j.nn.graph.ComputationGraph net)
public void enableTrainingUI()
public void saveModel(Resource modelPath, String name)
saveModel in interface IComponent
saveModel in class Tagger
modelPath -
name -
public void loadModel(Resource modelFile)
loadModel in interface IComponent
loadModel in class Tagger
public org.deeplearning4j.nn.conf.ComputationGraphConfiguration getGraphConfiguration()
getGraphConfiguration in class Tagger
public void setGraphConfiguration(org.nd4j.shade.jackson.databind.JsonNode conf)
setGraphConfiguration in class Tagger
Copyright © 2019. All rights reserved.