墨滴

希仔

2021/04/11

tf1-movie_review

Sentiment Analysis

This notebook trains a sentiment analysis model to classify movie reviews as positive or negative based on the text of the review. This is an example of binary (two-class) classification, an important and widely applicable kind of machine learning problem.

We'll use the Large Movie Review Dataset, which contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.

Specify the TensorFlow version (1.x). The %tensorflow_version line is a Google Colab magic command:

%tensorflow_version 1.x
import tensorflow as tf
print(tf.__version__)
TensorFlow 1.x selected.
1.15.2
import numpy as np
from tensorflow.python.ops.rnn import static_rnn
from tensorflow.python.ops.rnn_cell_impl import BasicLSTMCell
from tensorflow import keras

Download the dataset

imdb = keras.datasets.imdb

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
print(train_data.shape, test_data.shape)
print(train_labels.shape,test_labels.shape)
print(train_data[0])
print(train_labels[0])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
17465344/17464789 [==============================] - 0s 0us/step


/tensorflow-1.15.2/python3.7/tensorflow_core/python/keras/datasets/imdb.py:129: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])


(25000,) (25000,)
(25000,) (25000,)
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
1


/tensorflow-1.15.2/python3.7/tensorflow_core/python/keras/datasets/imdb.py:130: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
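The integers above are word indices, not raw text. As far as I understand the Keras loader's defaults (start_char=1, oov_char=2, index_from=3), each word's frequency rank from imdb.get_word_index() is shifted up by 3, every review starts with 1, and any word whose shifted index is not below num_words becomes 2. That explains why the first review begins with 1 and contains many 2s. The sketch below mimics this scheme with a tiny made-up vocabulary. toy_word_index, encode and decode are hypothetical names, not the real IMDB index or a Keras API. It also shows how to decode a sequence back into words.

```python
# Assumed Keras loader defaults: 0 = padding, 1 = start marker, 2 = out-of-vocabulary,
# and real word ranks are shifted up by INDEX_FROM = 3.
PAD, START, OOV, INDEX_FROM = 0, 1, 2, 3

# Hypothetical toy vocabulary in the format of imdb.get_word_index():
# word -> frequency rank, where rank 1 is the most frequent word.
toy_word_index = {"the": 1, "movie": 2, "was": 3, "great": 4, "plot": 5}


def encode(words, word_index, num_words):
    """Turn a list of words into indices the way this sketch assumes load_data does."""
    ids = [START]
    for w in words:
        shifted = word_index[w] + INDEX_FROM
        # Indices that are not below num_words are replaced by the OOV marker.
        ids.append(shifted if shifted < num_words else OOV)
    return ids


def decode(ids, word_index):
    """Map indices back to words, showing the reserved markers explicitly."""
    reverse = {rank + INDEX_FROM: w for w, rank in word_index.items()}
    reverse.update({PAD: "<PAD>", START: "<START>", OOV: "<UNK>"})
    return " ".join(reverse.get(i, "<UNK>") for i in ids)


encoded = encode(["the", "movie", "was", "great"], toy_word_index, num_words=7)
print(encoded)                          # [1, 4, 5, 6, 2]
print(decode(encoded, toy_word_index))  # <START> the movie was <UNK>
```

With the real data, you can pass keras.datasets.imdb.get_word_index() as word_index to decode. This should turn train_data[0] back into readable text, with <UNK> wherever a word fell outside the top 10,000.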

Data preprocessing

1. Make all reviews the same length (256 tokens), padding shorter ones with 0s at the end.

train_data = keras.preprocessing.sequence.pad_sequences(train_data,
                                                        value=0,
                                                        padding='post',
                                                        maxlen=256)

test_data = keras.preprocessing.sequence.pad_sequences(test_data,
                                                       value=0,
                                                       padding='post',
                                                       maxlen=256)
len(train_data[0]), len(train_data[1])
(256, 256)
print(train_data[:2])
[[   1   14   22   16   43  530  973 1622 1385   65  458 4468   66 3941
     4  173   36  256    5   25  100   43  838  112   50  670    2    9
    35  480  284    5  150    4  172  112  167    2  336  385   39    4
   172 4536 1111   17  546   38   13  447    4  192   50   16    6  147
  2025   19   14   22    4 1920 4613  469    4   22   71   87   12   16
    43  530   38   76   15   13 1247    4   22   17  515   17   12   16
   626   18    2    5   62  386   12    8  316    8  106    5    4 2223
  5244   16  480   66 3785   33    4  130   12   16   38  619    5   25
   124   51   36  135   48   25 1415   33    6   22   12  215   28   77
    52    5   14  407   16   82    2    8    4  107  117 5952   15  256
     4    2    7 3766    5  723   36   71   43  530  476   26  400  317
    46    7    4    2 1029   13  104   88    4  381   15  297   98   32
  2071   56   26  141    6  194 7486   18    4  226   22   21  134  476
    26  480    5  144   30 5535   18   51   36   28  224   92   25  104
     4  226   65   16   38 1334   88   12   16  283    5   16 4472  113
   103   32   15   16 5345   19  178   32    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0]
 [   1  194 1153  194 8255   78  228    5    6 1463 4369 5012  134   26
     4  715    8  118 1634   14  394   20   13  119  954  189  102    5
   207  110 3103   21   14   69  188    8   30   23    7    4  249  126
    93    4  114    9 2300 1523    5  647    4  116    9   35 8163    4
   229    9  340 1322    4  118    9    4  130 4901   19    4 1002    5
    89   29  952   46   37    4  455    9   45   43   38 1543 1905  398
     4 1649   26 6853    5  163   11 3215    2    4 1153    9  194  775
     7 8255    2  349 2637  148  605    2 8003   15  123  125   68    2
  6853   15  349  165 4362   98    5    4  228    9   43    2 1157   15
   299  120    5  120  174   11  220  175  136   50    9 4373  228 8255
     5    2  656  245 2350    5    4 9837  131  152  491   18    2   32
  7464 1212   14    9    6  371   78   22  625   64 1382    9    8  168
   145   23    4 1690   15   16    4 1355    5   28    6   52  154  462
    33   89   78  285   16  145   95    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0]]
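The pad_sequences call pads with 0s after the review text (padding='post'). Because truncating is left at its default of 'pre', any review longer than 256 tokens is cut from the front, so it loses its opening words rather than its ending. The NumPy sketch below reproduces that combination. pad_post is a hypothetical helper name, not part of Keras.

```python
import numpy as np


def pad_post(sequences, maxlen, value=0):
    """Hypothetical helper mirroring pad_sequences(padding='post', truncating='pre')."""
    out = np.full((len(sequences), maxlen), value, dtype=np.int32)
    for i, seq in enumerate(sequences):
        kept = list(seq)[-maxlen:]  # long sequences keep only their last maxlen items
        out[i, :len(kept)] = kept   # short sequences are followed by `value`
    return out


print(pad_post([[1, 2, 3], [4, 5, 6, 7, 8]], maxlen=4))
# [[1 2 3 0]
#  [5 6 7 8]]
```

Passing truncating='post' to pad_sequences would instead keep the first 256 tokens of each long review.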

Model

max_document_length = 256
vocab_size = 10000
embedding_size = 50
num_classes = 2
# placeholders for the padded word indices and the integer labels (0 or 1)
datas_placeholder = tf.placeholder(tf.int32, [None, max_document_length])
labels_placeholder = tf.placeholder(tf.int32, [None])

# word embedding
embeddings_ = tf.get_variable("embeddings_", [vocab_size, embedding_size], initializer=tf.truncated_normal_initializer())

# mapping the input: [None, max_document_length] => [None, max_document_length, embedding_size]
embedded = tf.nn.embedding_lookup(embeddings_, datas_placeholder)

# Convert to the input format static_rnn expects: a list of max_document_length tensors,
# one per time step, each of shape [None, embedding_size]
rnn_input = tf.unstack(embedded, max_document_length, axis=1)
# LSTM with 20 hidden units, unrolled over all max_document_length time steps
lstm_cell = BasicLSTMCell(20, forget_bias=1.0)
rnn_outputs, rnn_states = static_rnn(lstm_cell, rnn_input, dtype=tf.float32)

# classify using the LSTM output at the last time step
logits = tf.layers.dense(rnn_outputs[-1], num_classes)
predicted_labels = tf.argmax(logits, axis=1)

# loss and optimizer
losses = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)
mean_loss = tf.reduce_mean(losses)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)
WARNING:tensorflow:From <ipython-input-9-07958f873640>:2: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
WARNING:tensorflow:From <ipython-input-9-07958f873640>:3: static_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API
WARNING:tensorflow:From /tensorflow-1.15.2/python3.7/tensorflow_core/python/ops/rnn_cell_impl.py:735: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.7/tensorflow_core/python/ops/rnn_cell_impl.py:739: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From <ipython-input-9-07958f873640>:6: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.Dense instead.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.7/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.__call__` method instead.
WARNING:tensorflow:From <ipython-input-9-07958f873640>:12: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.
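To make the shapes in the model concrete, here is a NumPy sketch of one forward pass through the same architecture:

- embedding lookup;
- unstacking into 256 time steps;
- an LSTM with 20 units;
- a dense layer on the last output.

The weights are random stand-ins and the batch size of 3 is an arbitrary choice. As I understand it, the LSTM update follows the rule used by TensorFlow's BasicLSTMCell: the gates are split in the order input, candidate, forget, output, and forget_bias is added to the forget gate. Treat it as an illustration, not an exact reimplementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sizes from the notebook; batch_size = 3 is an arbitrary illustration value.
batch_size, max_document_length = 3, 256
vocab_size, embedding_size, num_units, num_classes = 10000, 50, 20, 2
forget_bias = 1.0


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Random stand-ins for the trainable variables (not the trained values).
embeddings = rng.normal(size=(vocab_size, embedding_size))
kernel = rng.normal(scale=0.1, size=(embedding_size + num_units, 4 * num_units))
bias = np.zeros(4 * num_units)
dense_w = rng.normal(scale=0.1, size=(num_units, num_classes))
dense_b = np.zeros(num_classes)

# A fake batch of padded word indices: [batch, time].
data = rng.integers(0, vocab_size, size=(batch_size, max_document_length))

# Embedding lookup: [batch, time] -> [batch, time, embedding_size]
embedded = embeddings[data]

# Unstack along axis 1: a list of `time` arrays, each [batch, embedding_size]
rnn_input = [embedded[:, t, :] for t in range(max_document_length)]

# Unrolled LSTM (what static_rnn does), starting from a zero state.
c = np.zeros((batch_size, num_units))
h = np.zeros((batch_size, num_units))
outputs = []
for x_t in rnn_input:
    gates = np.concatenate([x_t, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)  # input, candidate, forget, output
    c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h = np.tanh(c) * sigmoid(o)
    outputs.append(h)

# Dense layer on the last time step's output -> logits [batch, num_classes]
logits = outputs[-1] @ dense_w + dense_b
predicted_labels = np.argmax(logits, axis=1)

print(embedded.shape, len(rnn_input), rnn_input[0].shape, logits.shape)
# (3, 256, 50) 256 (3, 50) (3, 2)
```

The reviews are padded at the end. So for a short review, outputs[-1] is produced only after the LSTM has read all of the trailing 0 (padding) tokens. Whatever it learned from the real words has to persist through those extra steps.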

Compute accuracy

correct_prediction = tf.equal(predicted_labels, tf.cast(labels_placeholder, dtype=tf.int64))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Model Training

with tf.Session() as sess:
    # initialize variables
    sess.run(tf.global_variables_initializer())

    # each step feeds the full training set (full-batch gradient descent)
    feed_dict_training = {
        datas_placeholder: train_data,
        labels_placeholder: train_labels
    }

    # the test set is used to report accuracy after every step
    feed_dict_test = {
        datas_placeholder: test_data,
        labels_placeholder: test_labels
    }

    print("Start Training")
    for step in range(100):
        # fetch the loss under a new name so the mean_loss tensor is not overwritten
        _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=feed_dict_training)

        acc = sess.run(accuracy, feed_dict=feed_dict_test)
        print("Epoch = {}\tmean loss = {}\tacc_val = {}".format(step + 1, mean_loss_val, acc))

Start Training
Epoch = 1	mean loss = 0.7018250226974487	acc_val = 0.5061600208282471
Epoch = 2	mean loss = 0.8433809876441956	acc_val = 0.5053200125694275
Epoch = 3	mean loss = 0.6917682886123657	acc_val = 0.5091599822044373
Epoch = 4	mean loss = 0.707111120223999	acc_val = 0.5129200220108032
Epoch = 5	mean loss = 0.6967856287956238	acc_val = 0.5148400068283081
Epoch = 6	mean loss = 0.684948205947876	acc_val = 0.5206000208854675
Epoch = 7	mean loss = 0.6836138963699341	acc_val = 0.522599995136261
Epoch = 8	mean loss = 0.6858319044113159	acc_val = 0.5242400169372559
Epoch = 9	mean loss = 0.684241771697998	acc_val = 0.5269200205802917
Epoch = 10	mean loss = 0.6788839101791382	acc_val = 0.5261600017547607
Epoch = 11	mean loss = 0.6737165451049805	acc_val = 0.5271599888801575
Epoch = 12	mean loss = 0.6746671795845032	acc_val = 0.5289599895477295
Epoch = 13	mean loss = 0.6716814041137695	acc_val = 0.5308799743652344
Epoch = 14	mean loss = 0.6670898199081421	acc_val = 0.5317999720573425
Epoch = 15	mean loss = 0.6637671589851379	acc_val = 0.5389999747276306
Epoch = 16	mean loss = 0.6608981490135193	acc_val = 0.5406399965286255
Epoch = 17	mean loss = 0.6578797698020935	acc_val = 0.5421199798583984
Epoch = 18	mean loss = 0.6545631885528564	acc_val = 0.5432000160217285
Epoch = 19	mean loss = 0.6508879065513611	acc_val = 0.5453600287437439
Epoch = 20	mean loss = 0.6468613147735596	acc_val = 0.5423600077629089
Epoch = 21	mean loss = 0.6425195336341858	acc_val = 0.543720006942749
Epoch = 22	mean loss = 0.6378709673881531	acc_val = 0.5448399782180786
Epoch = 23	mean loss = 0.632870078086853	acc_val = 0.5478799939155579
Epoch = 24	mean loss = 0.6274447441101074	acc_val = 0.5503600239753723
Epoch = 25	mean loss = 0.621570885181427	acc_val = 0.5526000261306763
Epoch = 26	mean loss = 0.6153502464294434	acc_val = 0.5549600124359131
Epoch = 27	mean loss = 0.6090404391288757	acc_val = 0.5598000288009644
Epoch = 28	mean loss = 0.6030464172363281	acc_val = 0.5671600103378296
Epoch = 29	mean loss = 0.5980799198150635	acc_val = 0.5714799761772156
Epoch = 30	mean loss = 0.5941561460494995	acc_val = 0.5737599730491638
Epoch = 31	mean loss = 0.5900091528892517	acc_val = 0.5759999752044678
Epoch = 32	mean loss = 0.5856810212135315	acc_val = 0.5705599784851074
Epoch = 33	mean loss = 0.5771172642707825	acc_val = 0.5710399746894836
Epoch = 34	mean loss = 0.5724096894264221	acc_val = 0.5867199897766113
Epoch = 35	mean loss = 0.5676168203353882	acc_val = 0.5832800269126892
Epoch = 36	mean loss = 0.5637449026107788	acc_val = 0.5854399800300598
Epoch = 37	mean loss = 0.557736873626709	acc_val = 0.5997599959373474
Epoch = 38	mean loss = 0.5516165494918823	acc_val = 0.5705199837684631
Epoch = 39	mean loss = 0.543196976184845	acc_val = 0.5724800229072571
Epoch = 40	mean loss = 0.531811535358429	acc_val = 0.6820399761199951
Epoch = 41	mean loss = 0.5182607173919678	acc_val = 0.7416800260543823
Epoch = 42	mean loss = 0.5001638531684875	acc_val = 0.7403600215911865
Epoch = 43	mean loss = 0.4822198152542114	acc_val = 0.7567600011825562
Epoch = 44	mean loss = 0.47088053822517395	acc_val = 0.7705600261688232
Epoch = 45	mean loss = 0.46142587065696716	acc_val = 0.7749199867248535
Epoch = 46	mean loss = 0.4502145051956177	acc_val = 0.7807999849319458
Epoch = 47	mean loss = 0.4386557340621948	acc_val = 0.7857599854469299
Epoch = 48	mean loss = 0.42622530460357666	acc_val = 0.769320011138916
Epoch = 49	mean loss = 0.4302798807621002	acc_val = 0.7838000059127808
Epoch = 50	mean loss = 0.41005149483680725	acc_val = 0.7864000201225281
Epoch = 51	mean loss = 0.39099031686782837	acc_val = 0.7799599766731262
Epoch = 52	mean loss = 0.39037975668907166	acc_val = 0.7923600077629089
Epoch = 53	mean loss = 0.3648592233657837	acc_val = 0.7934399843215942
Epoch = 54	mean loss = 0.35888051986694336	acc_val = 0.7990400195121765
Epoch = 55	mean loss = 0.3380703926086426	acc_val = 0.8023999929428101
Epoch = 56	mean loss = 0.3224000930786133	acc_val = 0.8049600124359131
Epoch = 57	mean loss = 0.3076469898223877	acc_val = 0.8029199838638306
Epoch = 58	mean loss = 0.3019159734249115	acc_val = 0.8050400018692017
Epoch = 59	mean loss = 0.28536394238471985	acc_val = 0.807200014591217
Epoch = 60	mean loss = 0.27080798149108887	acc_val = 0.8105999827384949
Epoch = 61	mean loss = 0.26150450110435486	acc_val = 0.8113999962806702
Epoch = 62	mean loss = 0.25392431020736694	acc_val = 0.8105599880218506
Epoch = 63	mean loss = 0.24787749350070953	acc_val = 0.8091199994087219
Epoch = 64	mean loss = 0.24503426253795624	acc_val = 0.809719979763031
Epoch = 65	mean loss = 0.2484210729598999	acc_val = 0.8029599785804749
Epoch = 66	mean loss = 0.2573583424091339	acc_val = 0.8084800243377686
Epoch = 67	mean loss = 0.25407230854034424	acc_val = 0.8083199858665466
Epoch = 68	mean loss = 0.22917018830776215	acc_val = 0.8041599988937378
Epoch = 69	mean loss = 0.2343709021806717	acc_val = 0.8134400248527527
Epoch = 70	mean loss = 0.2283281683921814	acc_val = 0.8120800256729126
Epoch = 71	mean loss = 0.22515082359313965	acc_val = 0.7999200224876404
Epoch = 72	mean loss = 0.23339244723320007	acc_val = 0.8056399822235107
Epoch = 73	mean loss = 0.21400490403175354	acc_val = 0.817520022392273
Epoch = 74	mean loss = 0.20288775861263275	acc_val = 0.8111600279808044
Epoch = 75	mean loss = 0.2456928938627243	acc_val = 0.8085200190544128
Epoch = 76	mean loss = 0.20023442804813385	acc_val = 0.7849199771881104
Epoch = 77	mean loss = 0.28102487325668335	acc_val = 0.8102800250053406
Epoch = 78	mean loss = 0.2016381025314331	acc_val = 0.788640022277832
Epoch = 79	mean loss = 0.28656795620918274	acc_val = 0.8080800175666809
Epoch = 80	mean loss = 0.21067172288894653	acc_val = 0.8108000159263611
Epoch = 81	mean loss = 0.18842840194702148	acc_val = 0.7950800061225891
Epoch = 82	mean loss = 0.24920979142189026	acc_val = 0.8184800148010254
Epoch = 83	mean loss = 0.17711809277534485	acc_val = 0.8131600022315979
Epoch = 84	mean loss = 0.20713195204734802	acc_val = 0.8133999705314636
Epoch = 85	mean loss = 0.1933300793170929	acc_val = 0.8186799883842468
Epoch = 86	mean loss = 0.16914387047290802	acc_val = 0.8082000017166138
Epoch = 87	mean loss = 0.17878127098083496	acc_val = 0.8019999861717224
Epoch = 88	mean loss = 0.19565638899803162	acc_val = 0.8085200190544128
Epoch = 89	mean loss = 0.1741904467344284	acc_val = 0.817799985408783
Epoch = 90	mean loss = 0.16046839952468872	acc_val = 0.8190000057220459
Epoch = 91	mean loss = 0.15992751717567444	acc_val = 0.8195199966430664
Epoch = 92	mean loss = 0.16215763986110687	acc_val = 0.8206400275230408
Epoch = 93	mean loss = 0.15712639689445496	acc_val = 0.8241599798202515
Epoch = 94	mean loss = 0.14683911204338074	acc_val = 0.8224400281906128
Epoch = 95	mean loss = 0.14932098984718323	acc_val = 0.8205999732017517
Epoch = 96	mean loss = 0.14783647656440735	acc_val = 0.8217999935150146
Epoch = 97	mean loss = 0.1401595175266266	acc_val = 0.8216400146484375
Epoch = 98	mean loss = 0.13524788618087769	acc_val = 0.8198400139808655
Epoch = 99	mean loss = 0.1349489688873291	acc_val = 0.8183599710464478
Epoch = 100	mean loss = 0.13359633088111877	acc_val = 0.8202800154685974
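The mean loss and acc_val columns come from two computations. The first is softmax cross-entropy on one-hot labels. The second compares argmax predictions with the labels. The NumPy sketch below recomputes both by hand; the function names are my own, not TensorFlow's. It also shows why the first epoch's loss sits near 0.69. A two-class model that outputs equal logits assigns probability 0.5 to each class, and -ln(0.5) = ln 2, which is about 0.693.

```python
import numpy as np


def mean_softmax_cross_entropy(logits, labels, num_classes=2):
    """Mean softmax cross-entropy with integer labels converted to one-hot."""
    one_hot = np.eye(num_classes)[labels]
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(np.mean(-(one_hot * log_probs).sum(axis=1)))


def accuracy(logits, labels):
    """Fraction of rows whose argmax matches the label."""
    return float(np.mean(np.argmax(logits, axis=1) == labels))


labels = np.array([1, 0, 1, 1])
uninformative = np.zeros((4, 2))  # equal logits: p = 0.5 for both classes
confident = np.array([[-2.0, 2.0], [3.0, -1.0], [0.5, 1.5], [1.0, 0.0]])

print(mean_softmax_cross_entropy(uninformative, labels))  # ln 2, about 0.6931
print(accuracy(confident, labels))                        # 0.75
```

In the notebook, each step feeds the full arrays. Each reported loss is therefore an average over all 25,000 training reviews, and each accuracy is an average over all 25,000 test reviews.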
