Super-simple biLSTM Keras layer: Python cheat code

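This time, instead of loading pretrained global vectors (GloVe), we simply add an Embedding() layer and let the model learn the embedding during training. Pretrained global embeddings usually give better accuracy. In a specialized domain full of unusual words, however, it is common either to train the embedding yourself or to start from the global embedding and keep updating its values during training (fine-tuning).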
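For comparison with the pretrained route, here is a minimal NumPy sketch of the usual first step: copying pretrained vectors into an embedding matrix indexed by your own vocabulary. Domain words with no pretrained vector keep a small random row that training can then adjust. The function name `build_embedding_matrix`, the toy `word_index`/`pretrained` dictionaries and the word "tokenomics" are hypothetical, not taken from this post or from the Keras API.

```python
import numpy as np

def build_embedding_matrix(word_index, pretrained, num_words, dim, seed=0):
    """Fill a (num_words, dim) matrix from pretrained vectors (sketch).

    Rows for words that have a pretrained vector are copied in; words without
    one (e.g. rare domain terms) keep a small random vector; row 0 is zero.
    """
    rng = np.random.default_rng(seed)
    matrix = rng.normal(scale=0.05, size=(num_words, dim)).astype("float32")
    matrix[0] = 0.0  # index 0 is the padding value used by pad_sequences
    hits = 0
    for word, idx in word_index.items():
        if idx >= num_words:  # outside the capped vocabulary, skip it
            continue
        vec = pretrained.get(word)
        if vec is not None:
            matrix[idx] = vec
            hits += 1
    return matrix, hits

# Hypothetical toy data: "tokenomics" stands in for a rare domain word.
word_index = {"the": 1, "movie": 2, "tokenomics": 3}
pretrained = {"the": np.ones(4), "movie": np.full(4, 2.0)}
matrix, hits = build_embedding_matrix(word_index, pretrained, num_words=5, dim=4)
print(matrix.shape, hits)  # (5, 4) 2
```

Such a matrix could then be passed to Keras with something like `Embedding(num_words, dim, embeddings_initializer=keras.initializers.Constant(matrix), trainable=True)`. Setting `trainable=False` would freeze the global vectors instead of fine-tuning them.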

biLSTM example 1
In [1]:
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Dropout, Embedding, LSTM, Bidirectional
from keras.datasets import imdb
Using TensorFlow backend.
In [2]:
from keras.layers import CuDNNLSTM
In [3]:
import numpy as np
import pandas as pd
In [4]:
max_features = 20000
maxlen = 100
batch_size = 128

(train_x, train_y), (test_x, test_y) = imdb.load_data(num_words=max_features)
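`num_words=max_features` keeps only the 20,000 most frequent words. The pure-Python sketch below shows how I understand `imdb.load_data` to encode a review with its default arguments (`start_char=1`, `oov_char=2`, `index_from=3`, `skip_top=0`): raw frequency ranks are shifted by `index_from`, a start marker is prepended, and any index at or above `num_words` becomes the out-of-vocabulary marker. The helper `encode_like_imdb` is hypothetical and is not part of Keras.

```python
def encode_like_imdb(ranks, num_words, start_char=1, oov_char=2, index_from=3):
    """Mimic how imdb.load_data encodes one review (sketch, default arguments).

    ranks: 1-based frequency ranks of the words in the review.
    """
    # Shift every rank by index_from so the low indices stay free for
    # padding (0), start (1) and OOV (2); then mark the start of the review.
    seq = [start_char] + [r + index_from for r in ranks]
    # num_words caps the vocabulary: anything at or above it becomes oov_char.
    return [w if w < num_words else oov_char for w in seq]

print(encode_like_imdb([1, 5, 19996, 19997], num_words=20000))
# [1, 4, 8, 19999, 2]
```

Because every index ends up below `max_features`, `Embedding(max_features, ...)` later has a row for each one.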
In [5]:
train_x = sequence.pad_sequences(train_x, maxlen=maxlen)
test_x = sequence.pad_sequences(test_x, maxlen=maxlen)
In [6]:
print('train_x shape:', train_x.shape)
print('test_x shape:', test_x.shape)
train_x shape: (25000, 100)
test_x shape: (25000, 100)
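Both arrays come out as (25000, 100) because `pad_sequences` forces every review to exactly `maxlen` tokens. With the Keras defaults (`padding='pre'`, `truncating='pre'`, `value=0`), short reviews are left-padded with zeros and long ones keep only their last `maxlen` tokens. The NumPy sketch below reproduces that behaviour for illustration. `pad_pre` is a hypothetical name, not the Keras function itself.

```python
import numpy as np

def pad_pre(sequences, maxlen, value=0):
    """Sketch of pad_sequences with padding='pre' and truncating='pre'."""
    out = np.full((len(sequences), maxlen), value, dtype="int32")
    for i, seq in enumerate(sequences):
        kept = seq[-maxlen:]  # truncating='pre': drop tokens from the front
        if len(kept):
            out[i, -len(kept):] = kept  # padding='pre': zeros on the left
    return out

print(pad_pre([[1, 2, 3], [4, 5, 6, 7, 8]], maxlen=4).tolist())
# [[0, 1, 2, 3], [5, 6, 7, 8]]
```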
In [7]:
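model = Sequential()
# Map each word index (< max_features) to a trainable 128-dim vector
model.add(Embedding(max_features, 128, input_length=maxlen))
# Read the sequence forward and backward with 64 units each; the two outputs are concatenated (128 dims)
# CuDNNLSTM runs only on a GPU with cuDNN; on a CPU, use LSTM(64) instead
model.add(Bidirectional(CuDNNLSTM(64)))
model.add(Dropout(0.5))
# A single sigmoid unit: probability that the review is positive
model.add(Dense(1, activation='sigmoid'))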
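To make the model above concrete, here is a plain-NumPy sketch of its forward pass at inference time, with random weights and no training. The Embedding is a row lookup. Bidirectional runs one LSTM over the sequence and a second LSTM over the reversed sequence, then concatenates their final states (64 + 64 = 128). Dropout is skipped because it is the identity at inference. The helper names and the Keras-style gate order (i, f, c, o) are my assumptions for illustration. The parameter counts follow a standard LSTM. My understanding is that `CuDNNLSTM` stores an extra bias vector, so `model.summary()` for the cuDNN version may report 512 more parameters; compare with your own summary.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_lstm(input_dim, units):
    # Weights for the four gates (Keras-style order i, f, c, o) packed together
    return {
        "W": rng.normal(scale=0.1, size=(input_dim, 4 * units)),
        "U": rng.normal(scale=0.1, size=(units, 4 * units)),
        "b": np.zeros(4 * units),
    }

def lstm_last_state(x, p):
    """Run an LSTM over x of shape (batch, time, features); return the final h."""
    batch, steps, _ = x.shape
    units = p["U"].shape[0]
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    for t in range(steps):
        z = x[:, t] @ p["W"] + h @ p["U"] + p["b"]
        i = sigmoid(z[:, :units])
        f = sigmoid(z[:, units:2 * units])
        g = np.tanh(z[:, 2 * units:3 * units])
        o = sigmoid(z[:, 3 * units:])
        c = f * c + i * g
        h = o * np.tanh(c)
    return h

def count(p):
    return sum(v.size for v in p.values())

max_features, emb_dim, units, maxlen = 20000, 128, 64, 100
E = rng.normal(scale=0.05, size=(max_features, emb_dim))  # Embedding weights
fwd, bwd = init_lstm(emb_dim, units), init_lstm(emb_dim, units)
W_out, b_out = rng.normal(scale=0.1, size=(2 * units, 1)), np.zeros(1)

tokens = rng.integers(0, max_features, size=(2, maxlen))  # two fake padded reviews
x = E[tokens]  # Embedding: row lookup -> (2, 100, 128)
h = np.concatenate(
    [lstm_last_state(x, fwd), lstm_last_state(x[:, ::-1], bwd)], axis=1
)  # Bidirectional, concat mode -> (2, 128)
prob = sigmoid(h @ W_out + b_out)  # Dense(1, sigmoid) -> (2, 1)

total = E.size + count(fwd) + count(bwd) + W_out.size + b_out.size
print(h.shape, prob.shape)  # (2, 128) (2, 1)
print(count(fwd), total)    # 49408 2658945
```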
In [8]:
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])
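`binary_crossentropy` pairs with the single sigmoid output: each review contributes `-(y*log(p) + (1-y)*log(1-p))`, and the loss is the mean over the batch. With one sigmoid unit, accuracy treats predictions above 0.5 as positive. The NumPy sketch below shows both calculations. The clipping epsilon and the function names are my own choices, not the exact Keras implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions on the correct side of the threshold."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean((y_pred > threshold) == y_true))

y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.3]
print(binary_accuracy(y_true, y_pred))  # 1.0
print(round(binary_crossentropy(y_true, y_pred), 4))
```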
In [ ]:
model.fit(train_x, train_y, batch_size=batch_size, epochs=4, validation_data=(test_x, test_y))
Train on 25000 samples, validate on 25000 samples
Epoch 1/4
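With `batch_size = 128`, each epoch walks through the 25,000 training reviews in mini-batches, and the last batch is smaller. The short sketch below does the arithmetic. `batch_sizes` is a hypothetical helper for illustration, not a Keras function.

```python
import math

def batch_sizes(n_samples, batch_size):
    """Sizes of the mini-batches in one epoch; the last one may be partial."""
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

sizes = batch_sizes(25000, 128)
steps = math.ceil(25000 / 128)
print(steps, sizes[-1])  # 196 40
```

So each of the 4 epochs performs 196 weight updates: 195 full batches of 128 reviews plus a final batch of 40.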