
Text classification using DIGITS and Torch7


Introduction

This example follows the Crepe implementation of the following paper:

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015)

This example shows how to create a feed-forward convolutional neural network that classifies text with high accuracy. The network operates at the character level and requires no feature engineering besides converting characters to arbitrary numbers.

Dataset Creation

We will use the DBPedia ontology dataset. This dataset is available in .csv format on @zhangxiangxiao's Google Drive storage. Download the file dbpedia_csv.tar.gz and extract its contents into a folder which we will later refer to as $DBPEDIA.

The following sample is an example from the "company" class:

"E. D. Abbott Ltd"," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."

The first step in creating the dataset is to convert the .csv files into a format that DIGITS can use:

$ cd $DIGITS_ROOT/examples/text-classification
$ ./create_dataset.py $DBPEDIA/dbpedia_csv/train.csv dbpedia/train --labels $DBPEDIA/dbpedia_csv/classes.txt --create-images

This script parses train.csv and generates one sample per entry in the file. Every entry is converted into a 1024-byte vector. Characters are converted using a very simple mapping: strings are first converted to lower case, then each character is replaced with its index in the extended alphabet (abc...xyz0...9 plus a set of punctuation and special characters), with indices starting at 2. All other characters, including those used for padding, are replaced with 1. Note that we have not implemented the backward quantization order mentioned in section 2.2 of the paper.

The script then reshapes each vector into a 32x32 matrix and saves it into an unencoded LMDB file. The --create-images option in the command above additionally saves each sample as an actual image file. Image files are saved into sub-folders named after the sample's class, which makes it possible to have DIGITS proceed as if we were creating an image classification network. We will see later how that step may be skipped.
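The encoding and reshaping steps can be sketched in Python as follows. This is not the code of create_dataset.py: the ALPHABET string is an assumption modeled on Crepe's alphabet, truncating text longer than 1024 characters is assumed, and encode and to_matrix are hypothetical helper names.

```python
# Assumed alphabet, modeled on Crepe's; the exact character set used by
# create_dataset.py may differ. Lower-case letters, digits, then punctuation.
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+=<>()[]{}"
FEATURE_LEN = 1024   # 32 x 32
OTHER = 1            # code for unknown characters and padding
CHAR_TO_CODE = {c: i + 2 for i, c in enumerate(ALPHABET)}  # codes start at 2

def encode(text):
    """Map text to a 1024-byte vector (truncating long text is an assumption)."""
    codes = [CHAR_TO_CODE.get(c, OTHER) for c in text.lower()[:FEATURE_LEN]]
    codes += [OTHER] * (FEATURE_LEN - len(codes))  # pad with the "other" code
    return bytes(codes)

def to_matrix(vec, side=32):
    """Reshape a flat vector row by row into a side x side matrix (list of rows)."""
    return [list(vec[r * side:(r + 1) * side]) for r in range(side)]

vec = encode("E. D. Abbott Ltd")
print(list(vec[:6]))  # [6, 41, 1, 5, 41, 1]: 'e', '.', space (other), 'd', '.', space
```

Since every code fits in one byte, the 1024 codes fill a 32x32 single-channel matrix exactly, which is consistent with the grayscale, 32x32 image settings used when creating the dataset.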

On the DIGITS homepage, click New Dataset > Images > Classification then:

  • change the image type to grayscale and the image height and width to 32,
  • point to the location of your dataset,
  • use 10% of samples for validation and 1% for testing,
  • make sure the image encoding is set to PNG (lossless),
  • give your dataset a name then click the "Create" button.

image classification dataset

Model Creation

If you haven't done so already, install the dpnn Lua package:

luarocks install dpnn

On the DIGITS homepage, click New Model > Images > Classification then:

  • select the dataset you just created,
  • set the Mean Subtraction method to "None",
  • select the "Custom Network" pane then click "Torch",
  • in the Custom Network field, paste this network definition,
  • give your model a name.

Optionally, for better results:

  • set the number of training epochs to 15,
  • set the validation interval to 0.25,
  • click "Show advanced learning rate options",
  • set the learning rate policy to "Exponential Decay",
  • set Gamma to 0.98.

The model resembles a typical image classification convolutional neural network, with convolutional layers, max pooling, dropout and a linear classifier. The main difference is that each character is one-hot encoded into a vector and 1D (temporal) convolutions are used instead of 2D (spatial) convolutions.
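To illustrate the difference, the following pure-Python toy one-hot encodes character codes and applies a temporal convolution whose kernels span every input feature and slide only along the character axis. It is not the Torch network definition used by this example; the function names, the code-to-column mapping and the weight layout are assumptions made for illustration.

```python
def one_hot(codes, num_symbols):
    """Turn character codes (1..num_symbols) into a list of one-hot frames.

    Code c sets column c - 1, so the "other/padding" code 1 gets its own column
    (an assumption; the paper quantizes unknown characters as all-zero vectors).
    """
    frames = []
    for c in codes:
        frame = [0.0] * num_symbols
        frame[c - 1] = 1.0
        frames.append(frame)
    return frames

def temporal_conv(frames, weights, bias):
    """Valid 1D convolution along time, stride 1.

    weights[o][k][i] is the weight for output feature o, kernel offset k and
    input feature i. Each kernel spans all input features, so it only slides
    along the time (character) axis.
    """
    kernel_w = len(weights[0])
    n_in = len(frames[0])
    out = []
    for t in range(len(frames) - kernel_w + 1):
        out.append([
            bias[o] + sum(weights[o][k][i] * frames[t + k][i]
                          for k in range(kernel_w) for i in range(n_in))
            for o in range(len(weights))
        ])
    return out
```

For example, temporal_conv(one_hot([2, 1, 3], 4), [[[1, 2, 3, 4], [10, 20, 30, 40]]], [0.5]) returns [[12.5], [31.5]]: a kernel of width 2 over three characters yields two output frames.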

When you are ready, click the "Create" button.

After a few hours of training, your network loss and accuracy may look like:

[Figure: training loss and accuracy]

Verification

At the bottom of the model page, select the model snapshot that achieved the best validation accuracy (this is not necessarily the last one). Then, in the "Test a list of images" section, upload the test.txt file from your dataset job folder. This text file was created by DIGITS during dataset creation and is formatted in a way that allows DIGITS to extract the ground truth and compute accuracy and a confusion matrix. The results page also shows Top-1 and Top-5 average accuracy, as well as per-class accuracy:

[Figure: test results with Top-1, Top-5 and per-class accuracy]
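For reference, metrics of this kind can be computed from per-class scores as in the following Python sketch. It is not DIGITS's implementation: labels are assumed to be 0-based class indices, ties go to the lower class index, and the function names are hypothetical.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for s, y in zip(scores, labels):
        # Stable sort: among equal scores, lower class indices come first.
        ranked = sorted(range(len(s)), key=lambda c: s[c], reverse=True)
        if y in ranked[:k]:
            hits += 1
    return hits / len(labels)

def confusion_matrix(scores, labels, num_classes):
    """m[true][predicted] counts, with the prediction taken as the top-1 class."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for s, y in zip(scores, labels):
        pred = max(range(num_classes), key=lambda c: s[c])
        m[y][pred] += 1
    return m

def per_class_accuracy(m):
    """Diagonal of the confusion matrix divided by each row total."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(m)]
```

Top-5 accuracy is top_k_accuracy(scores, labels, 5), and the average Top-1 accuracy equals the trace of the confusion matrix divided by the number of samples.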

Alternative Method

If you think creating image files to represent text is overkill, you might be interested in this: you can create LMDB files yourself and use them in DIGITS directly. When you created the dataset with create_dataset.py, the script also created an LMDB database out of train.csv. You can use the same script to create another database out of test.csv (from the DBPedia ontology dataset) for validation purposes:

./create_dataset.py $DBPEDIA/dbpedia_csv/test.csv dbpedia/val

On the DIGITS homepage, click New Dataset > Images > Other then:

  • in the "Images LMDB" column, select the paths to your train and validation databases, respectively (note that the labels are encoded in these databases so you don't need to specify an alternative database for labels)
  • give your dataset a name then click "Create".

[Figure: generic dataset creation]

What difference does it make to use this alternative method? The main difference is that you do not need to create image files, since you can pass LMDB files to DIGITS directly. This can save a significant amount of time. It also gives you more freedom in the data formats you use, as long as you stick to 3D blobs of data: you may, for example, choose to work with 16-bit or 32-bit data, or with blobs that have a non-standard or unsupported number of channels. There is a downside, though: since DIGITS is not told that you are creating a classification model, it does not process the network outputs in any way. For classification models, DIGITS extracts the predicted class by identifying which class had the highest probability in the SoftMax layer. For generic ("other") models, DIGITS only shows the raw network output. Moreover, quality metrics such as accuracy or confusion matrices are not computed automatically for those models.

To create the model, click New Model > Images > Other on the DIGITS homepage, then proceed exactly as you did when creating the image classification model.

After training you can test samples using the "Test a Database" section. You just need to point to the location of an LMDB database, for example the validation database.

The following snapshot shows the first 5 inference outputs from the validation dataset:

[Figure: Generic Inference]

Each line shows the contents of the LogSoftMax layer for one sample. The LMDB key format used by create_dataset.py is '%d_%d' % (index, class): for the first item in the database, the key is 0_1, which means that the item is from class 1. You can see that the output reaches its maximum at index 1 (indices start from 1), so the sample was correctly classified. The predicted probability for class 1 is math.exp(-5.72204590e-06) = 0.99999427797047, a high degree of confidence.
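This interpretation can be automated. The Python sketch below parses keys in the '%d_%d' % (index, class) format, turns a LogSoftMax vector into a 1-based predicted index and its probability, and computes accuracy over a dictionary shaped like the "outputs" object returned by the REST API. It assumes, as in the example above, that the 1-based output index matches the class number in the key. The helper names are hypothetical, and FIRST_SAMPLE reuses the output values of the first sample shown on this page.

```python
import json
import math

def parse_key(key):
    """Split a create_dataset.py key of the form '%d_%d' % (index, class)."""
    index, label = key.split("_")
    return int(index), int(label)

def interpret(log_probs):
    """Return (1-based index of the max LogSoftMax output, its probability)."""
    best = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return best + 1, math.exp(log_probs[best])

def accuracy(outputs):
    """outputs maps LMDB keys to {"output": [...]}, as in the REST API's "outputs"."""
    correct = 0
    for key, entry in outputs.items():
        _, label = parse_key(key)
        predicted, _ = interpret(entry["output"])
        if predicted == label:  # assumes 1-based output index == class number
            correct += 1
    return correct / len(outputs)

def load_outputs(path):
    """Read the "outputs" object from a saved infer_db.json response."""
    with open(path) as f:
        return json.load(f)["outputs"]

FIRST_SAMPLE = {"0_1": {"output": [
    -5.7220458984375e-06, -12.106060028076, -25.820121765137, -29.935920715332,
    -27.315780639648, -17.158786773682, -22.92654800415, -15.421851158142,
    -23.10737991333, -29.26469039917, -16.862657546997, -28.460214614868,
    -21.428464889526, -17.860265731812]}}

print(interpret(FIRST_SAMPLE["0_1"]["output"]))  # index 1, probability about 0.9999943
print(accuracy(FIRST_SAMPLE))
```

Once the curl response below has been saved, accuracy(load_outputs("predictions.txt")) would give the accuracy over the whole database, under the same indexing assumption.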

You may also use the REST API to download predictions in JSON format. This can be useful if you wish to post-process the data. In the command below, replace the job_id and db_path values with your job ID and LMDB path respectively:

curl localhost:5000/models/images/generic/infer_db.json -XPOST -F job_id=20160414-040451-9cc5 -F db_path="/path/to/dbpedia/test" > predictions.txt

Running this command will dump inference data in a format similar to:

{
  "outputs": {
    "0_1": {
      "output": [
        -5.7220458984375e-06,
        -12.106060028076,
        -25.820121765137,
        -29.935920715332,
        -27.315780639648,
        -17.158786773682,
        -22.92654800415,
        -15.421851158142,
        -23.10737991333,
        -29.26469039917,
        -16.862657546997,
        -28.460214614868,
        -21.428464889526,
        -17.860265731812
      ]
    },
    ...