簡介
主要內(nèi)容包括
如何將文本處理為Tensorflow LSTM的輸入
如何定義LSTM
用訓(xùn)練好的LSTM進(jìn)行文本分類
代碼
導(dǎo)入相關(guān)庫
#coding=utf-8
importtensorflowastf
fromtensorflow.contribimportlearn
importnumpyasnp
fromtensorflow.python.ops.rnnimportstatic_rnn
fromtensorflow.python.ops.rnn_cell_implimportBasicLSTMCell
數(shù)據(jù)
# 數(shù)據(jù)
positive_texts=[
"我 今天 很 高興",
"我 很 開心",
"他 很 高興",
"他 很 開心"
]
negative_texts=[
"我 不 高興",
"我 不 開心",
"他 今天 不 高興",
"他 不 開心"
]
label_name_dict={
0:"正面情感",
1:"負(fù)面情感"
}
配置信息
配置信息
embedding_size=50
num_classes=2
將文本和label數(shù)值化
# 將文本和label數(shù)值化
all_texts=positive_texts+negative_textslabels=[0]*len(positive_texts)+[1]*len(negative_texts)
max_document_length=4
vocab_processor=learn.preprocessing.VocabularyProcessor(max_document_length)
datas=np.array(list(vocab_processor.fit_transform(all_texts)))
vocab_size=len(vocab_processor.vocabulary_)
定義placeholder(容器),存放輸入輸出
# 容器,存放輸入輸出
datas_placeholder=tf.placeholder(tf.int32, [None, max_document_length])
labels_placeholder=tf.placeholder(tf.int32, [None])
詞向量處理
# 詞向量表
embeddings=tf.get_variable("embeddings", [vocab_size, embedding_size],initializer=tf.truncated_normal_initializer)
# 將詞索引號轉(zhuǎn)換為詞向量[None, max_document_length] => [None, max_document_length, embedding_size]
embedded=tf.nn.embedding_lookup(embeddings, datas_placeholder)
將數(shù)據(jù)處理為LSTM的輸入格式
# 轉(zhuǎn)換為LSTM的輸入格式,要求是數(shù)組,數(shù)組的每個元素代表某個時間戳一個Batch的數(shù)據(jù)
rnn_input=tf.unstack(embedded, max_document_length,axis=1)
定義LSTM
# 定義LSTM
lstm_cell=BasicLSTMCell(20,forget_bias=1.0)
rnn_outputs, rnn_states=static_rnn(lstm_cell, rnn_input,dtype=tf.float32)
#利用LSTM最后的輸出進(jìn)行預(yù)測
logits=tf.layers.dense(rnn_outputs[-1], num_classes)
predicted_labels=tf.argmax(logits,axis=1)
定義損失和優(yōu)化器
# 定義損失和優(yōu)化器
losses=tf.nn.softmax_cross_entropy_with_logits(
labels=tf.one_hot(labels_placeholder, num_classes),
logits=logits
)
mean_loss=tf.reduce_mean(losses)
optimizer=tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)
執(zhí)行
withtf.Session()assess:
# 初始化變量
sess.run(tf.global_variables_initializer())
訓(xùn)練# 定義要填充的數(shù)據(jù)
feed_dict={
datas_placeholder: datas,
labels_placeholder: labels
}
print("開始訓(xùn)練")
forstepinrange(100):
_, mean_loss_val=sess.run([optimizer, mean_loss],feed_dict=feed_dict)
ifstep%10==0:
print("step ={}tmean loss ={}".format(step, mean_loss_val))
預(yù)測
print("訓(xùn)練結(jié)束,進(jìn)行預(yù)測")
predicted_labels_val=sess.run(predicted_labels,feed_dict=feed_dict)
fori, textinenumerate(all_texts):
label=predicted_labels_val[i]
label_name=label_name_dict[label]
print("{}=>{}".format(text, label_name))
審核編輯 黃昊宇
-
LSTM
+關(guān)注
關(guān)注
0文章
59瀏覽量
3767
發(fā)布評論請先 登錄
相關(guān)推薦
評論