Quick, Draw!是一款游戲;在這個游戲中,玩家要接受一項挑戰:繪制幾個圖形,看看計算機能否識別玩家繪制的是什么。
Quick, Draw!的識別操作 由一個分類器執行,它接收用戶輸入(用 (x, y) 中的點筆畫序列表示),然后識別用戶嘗試涂鴉的圖形所屬的類別。
在本教程中,我們將展示如何針對此問題構建基于 RNN 的識別器。該模型將結合使用卷積層、LSTM 層和 softmax 輸出層對涂鴉進行分類:
上圖顯示了我們將在本教程中構建的模型的結構。輸入為一個涂鴉,用 (x, y, n) 中的點筆畫序列表示,其中 n 表示點是否為新筆畫的第一個點。
然后,模型將應用一系列一維卷積,接下來,會應用 LSTM 層,并將所有 LSTM 步的輸出之和饋送到 softmax 層,以便根據我們已知的涂鴉類別來決定涂鴉的分類。
本教程使用的數據來自真實的Quick, Draw!游戲,這些數據是公開提供的。此數據集包含 5000 萬幅涂鴉,涵蓋 345 個類別。
運行教程代碼
要嘗試本教程的代碼,請執行以下操作:
安裝 TensorFlow(如果尚未安裝的話)
下載教程代碼
下載數據(TFRecord格式),然后解壓縮。如需詳細了解如何獲取原始 Quick, Draw!數據以及如何將數據轉換為TFRecord文件,請參閱下文
使用以下命令執行教程代碼,以訓練本教程中所述的基于 RNN 的模型。請務必調整路徑,使其指向第 3 步中下載的解壓縮數據
python train_model.py \ --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \ --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \ --classes_file=rnn_tutorial_data/training.tfrecord.classes
教程詳情
下載數據
我們將本教程中要使用的數據放在了包含TFExamples的TFRecord文件中。您可以從以下位置下載這些數據:http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz(大約 1GB)。
或者,您也可以從 Google Cloud 下載ndjson格式的原始數據,并將這些數據轉換為包含TFExamples的TFRecord文件,如下一部分中所述。
可選:下載完整的 QuickDraw 數據
完整的Quick, Draw!數據集可在 Google Cloud Storage 上找到,此數據集是按類別劃分的ndjson文件。您可以在 Cloud Console 中瀏覽文件列表。
要下載數據,我們建議使用gsutil下載整個數據集。請注意,原始 .ndjson 文件需要下載約 22GB 的數據。
然后,使用以下命令檢查 gsutil 安裝是否成功以及您是否可以訪問數據存儲分區:
gsutil ls -r "gs://quickdraw_dataset/full/simplified/*"
系統會輸出一長串文件,如下所示:
gs://quickdraw_dataset/full/simplified/The Eiffel Tower.ndjsongs://quickdraw_dataset/full/simplified/The Great Wall of China.ndjsongs://quickdraw_dataset/full/simplified/The Mona Lisa.ndjsongs://quickdraw_dataset/full/simplified/aircraft carrier.ndjson...
之后,創建一個文件夾并在其中下載數據集。
mkdir rnn_tutorial_datacd rnn_tutorial_datagsutil -m cp "gs://quickdraw_dataset/full/simplified/*" .
下載過程需要花費一段時間,且下載的數據量略超 23GB。
可選:轉換數據
要將ndjson文件轉換為TFRecord文件(包含tf.train.Example樣本),請運行以下命令。
python create_dataset.py --ndjson_path rnn_tutorial_data \ --output_path rnn_tutorial_data
此命令會將數據存儲在TFRecord文件的 10 個分片中,每個類別有 10000 項用于訓練數據,有 1000 項用于評估數據。
下文詳細說明了該轉換過程。
原始 QuickDraw 數據的格式為ndjson文件,其中每行包含一個如下所示的 JSON 對象:
{"word":"cat","countrycode":"VE","timestamp":"2017-03-02 23:25:10.07453 UTC","recognized":true,"key_id":"5201136883597312","drawing":[ [ [130,113,99,109,76,64,55,48,48,51,59,86,133,154,170,203,214,217,215,208,186,176,162,157,132], [72,40,27,79,82,88,100,120,134,152,165,184,189,186,179,152,131,114,100,89,76,0,31,65,70] ],[ [76,28,7], [136,128,128] ],[ [76,23,0], [160,164,175] ],[ [87,52,37], [175,191,204] ],[ [174,220,246,251], [134,132,136,139] ],[ [175,255], [147,168] ],[ [171,208,215], [164,198,210] ],[ [130,110,108,111,130,139,139,119], [129,134,137,144,148,144,136,130] ],[ [107,106], [96,113] ]]}
在構建我們的分類器時,我們只關注 “word” 和 “drawing” 字段。在解析 ndjson 文件時,我們使用一個函數逐行處理它們,該函數可將drawing字段中的筆畫轉換為大小為[number of points, 3](包含連續點的差異)的張量。此函數還會以字符串形式返回類別名稱。
def parse_line(ndjson_line): """Parse an ndjson line and return ink (as np array) and classname.""" sample = json.loads(ndjson_line) class_name = sample["word"] inkarray = sample["drawing"] stroke_lengths = [len(stroke[0]) for stroke in inkarray] total_points = sum(stroke_lengths) np_ink = np.zeros((total_points, 3), dtype=np.float32) current_t = 0 for stroke in inkarray: for i in [0, 1]: np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i] current_t += len(stroke[0]) np_ink[current_t - 1, 2] = 1 # stroke_end # Preprocessing. # 1. Size normalization. lower = np.min(np_ink[:, 0:2], axis=0) upper = np.max(np_ink[:, 0:2], axis=0) scale = upper - lower scale[scale == 0] = 1 np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale # 2. Compute deltas. np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2] return np_ink, class_name
由于我們希望數據在寫入時進行隨機處理,因此我們以隨機順序從每個類別文件中讀取數據并寫入隨機分片。
對于訓練數據,我們讀取每個類別的前 10000 項;對于評估數據,我們讀取每個類別接下來的 1000 項。
然后,將這些數據變形為[num_training_samples, max_length, 3]形狀的張量。接下來,我們用屏幕坐標確定原始涂鴉的邊界框并標準化涂鴉的尺寸,使涂鴉具有單位高度。
最后,我們計算連續點之間的差異,并將它們存儲為VarLenFeature(位于tensorflow.Example中的ink鍵下)。另外,我們將class_index存儲為單一條目FixedLengthFeature,將ink的shape存儲為長度為 2 的FixedLengthFeature。
定義模型
要定義模型,我們需要創建一個新的Estimator。如需詳細了解 Estimator,建議您閱讀此教程。
要構建模型,我們需要執行以下操作:
將輸入調整回原始形狀,其中小批次通過填充達到其內容的最大長度。除了 ink 數據之外,我們還擁有每個樣本的長度和目標類別。這可通過函數_get_input_tensors實現
將輸入傳遞給_add_conv_layers中的一系列卷積層
將卷積的輸出傳遞到_add_rnn_layers中的一系列雙向 LSTM 層。最后,將每個時間步的輸出相加,針對輸入生成一個固定長度的緊湊嵌入
在_add_fc_layers中使用 softmax 層對此嵌入進行分類
代碼如下所示:
inks, lengths, targets = _get_input_tensors(features, targets)convolved = _add_conv_layers(inks)final_state = _add_rnn_layers(convolved, lengths)logits =_add_fc_layers(final_state)
_get_input_tensors
要獲得輸入特征,我們先從特征字典獲得形狀,然后創建大小為[batch_size](包含輸入序列的長度)的一維張量。ink 作為稀疏張量存儲在特征字典中,我們將其轉換為密集張量,然后變形為[batch_size, ?, 3]。最后,如果傳入目標,我們需要確保它們存儲為大小為[batch_size]的一維張量。
代碼如下所示:
shapes = features["shape"]lengths = tf.squeeze( tf.slice(shapes, begin=[0, 0], size=[params["batch_size"], 1]))inks = tf.reshape( tf.sparse_tensor_to_dense(features["ink"]), [params["batch_size"], -1, 3])if targets is not None: targets = tf.squeeze(targets)
_add_conv_layers
您可以通過params字典中的參數num_conv和conv_len配置所需的卷積層數量和過濾器長度。
輸入是一個每個點維數都是 3 的序列。我們將使用一維卷積,將 3 個輸入特征視為通道。這意味著輸入為[batch_size, length, 3]張量,而輸出為[batch_size, length, number_of_filters]張量。
convolved = inksfor i in range(len(params.num_conv)): convolved_input = convolved if params.batch_norm: convolved_input = tf.layers.batch_normalization( convolved_input, training=(mode == tf.estimator.ModeKeys.TRAIN)) # Add dropout layer if enabled and not first convolution layer. if i > 0 and params.dropout: convolved_input = tf.layers.dropout( convolved_input, rate=params.dropout, training=(mode == tf.estimator.ModeKeys.TRAIN)) convolved = tf.layers.conv1d( convolved_input, filters=params.num_conv[i], kernel_size=params.conv_len[i], activation=None, strides=1, padding="same", name="conv1d_%d" % i)return convolved, lengths
_add_rnn_layers
我們將卷積的輸出傳遞給雙向 LSTM 層,對此我們使用 contrib 的輔助函數。
outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn( cells_fw=[cell(params.num_nodes) for _ in range(params.num_layers)], cells_bw=[cell(params.num_nodes) for _ in range(params.num_layers)], inputs=convolved, sequence_length=lengths, dtype=tf.float32, scope="rnn_classification")
請參閱代碼以了解詳情以及如何使用CUDA加速實現。
要創建一個固定長度的緊湊嵌入,我們需要將 LSTM 的輸出相加。我們首先將其中的序列不含數據的批次區域設為 0。
mask = tf.tile( tf.expand_dims(tf.sequence_mask(lengths, tf.shape(outputs)[1]), 2), [1, 1, tf.shape(outputs)[2]])zero_outside = tf.where(mask, outputs, tf.zeros_like(outputs))outputs = tf.reduce_sum(zero_outside, axis=1)
_add_fc_layers
將輸入的嵌入傳遞至全連接層,之后將此層用作 softmax 層。
tf.layers.dense(final_state, params.num_classes)
損失、預測和優化器
最后,我們需要添加一個損失函數、一個訓練操作和預測來創建ModelFn:
cross_entropy = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=targets, logits=logits))# Add the optimizer.train_op = tf.contrib.layers.optimize_loss( loss=cross_entropy, global_step=tf.train.get_global_step(), learning_rate=params.learning_rate, optimizer="Adam", # some gradient clipping stabilizes training in the beginning. clip_gradients=params.gradient_clipping_norm, summaries=["learning_rate", "loss", "gradients", "gradient_norm"])predictions = tf.argmax(logits, axis=1)return model_fn_lib.ModelFnOps( mode=mode, predictions={"logits": logits, "predictions": predictions}, loss=cross_entropy, train_op=train_op, eval_metric_ops={"accuracy": tf.metrics.accuracy(targets, predictions)})
訓練和評估模型
要訓練和評估模型,我們可以借助EstimatorAPI 的功能,并使用ExperimentAPI 輕松運行訓練和評估操作:
estimator = tf.estimator.Estimator( model_fn=model_fn, model_dir=output_dir, config=config, params=model_params) # Train the model. tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.TRAIN, tfrecord_pattern=FLAGS.training_data, batch_size=FLAGS.batch_size), train_steps=FLAGS.steps, eval_input_fn=get_input_fn( mode=tf.contrib.learn.ModeKeys.EVAL, tfrecord_pattern=FLAGS.eval_data, batch_size=FLAGS.batch_size), min_eval_frequency=1000)
請注意,本教程只是用一個相對較小的數據集進行簡單演示,目的是讓您熟悉遞歸神經網絡和 Estimator 的 API。如果在大型數據集上嘗試,這些模型可能會更強大。
當模型完成 100 萬個訓練步后,分數最高的候選項的準確率預計會達到 70% 左右。請注意,這種程度的準確率足以構建 Quick, Draw! 游戲,由于該游戲的動態特性,用戶可以在系統準備好識別之前調整涂鴉。此外,如果目標類別顯示的分數高于固定閾值,該游戲不會僅使用分數最高的候選項,而且會將某個涂鴉視為正確的涂鴉。
-
神經網絡
+關注
關注
42文章
4778瀏覽量
101009 -
識別器
+關注
關注
0文章
20瀏覽量
7619
原文標題:Quick, Draw! 涂鴉分類遞歸神經網絡
文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論