在這篇文章中,我們將利用 TensorFlow.js,D3.js 和 Web 的力量使訓練模型的過程可視化,以預測棒球數據中的壞球(藍色區域)和好球(橙色區域)。 隨著我們的進展,我們將模型在整個訓練過程中理解的打擊區域可視化。您可以通過訪問此 Observable 筆記本在瀏覽器中運行此模型。
注:Observable鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d
如果你不熟悉棒球的擊球區,這里有一篇詳細的文章。
上面的 GIF 可視化神經網絡學習調用壞球(藍色區域)和好球(橙色區域)在每個訓練步驟之后,熱圖會根據模型的預測進行更新
使用 Observable 直接在瀏覽器中運行此模型。
注:文章鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d
體育運動中的高級指標
當今的職業體育環境中充斥著大量的數據。這些數據被團隊,業余愛好者和粉絲應用于各種用例中。感謝像 TensorFlow 這樣的框架 - 這些數據集已準備好應用于機器學習。
美國職業棒球大聯盟先進媒體(MLBAM)的 PITCHf/x
美國職業棒球大聯盟先進媒體(MLBAM)發布了一個可供公眾研究的大型數據集。該數據集包含有關過去幾年在美國職業棒球大聯盟比賽中投擲的投球的傳感器信息。 利用這個數據集,我們已編寫了一個包含 5,000 個樣本的訓練集(2,500 個壞球和 2,500 個好球)。
以下是訓練數據中前幾個字段的示例:
注:示例鏈接
https://gist.github.com/nkreeger/01b5386b522b0cd1f22bc864320f3084#file-baseball-training-data-sample-csv
以下是針對打擊區域繪制的訓練數據的樣子。藍點標記為壞球,橙點標記為好球(此為大聯盟裁判員稱謂):
利用 TensorFlow.js 構建模型
TensorFlow.js 將機器學習引入 JavaScript 和 Web。 我們將利用這個很棒的框架來構建一個深度神經網絡模型。這個模型將能夠按大聯盟裁判的精準度來稱呼好球和壞球。
輸入 Input
該模型在 PITCHf / x 的以下字段中進行了訓練:
協調球越過本壘的位置('px'和'pz')。
擊球手站在壘的哪一側。
擊球區(擊球手的軀干)的高度,以英尺為單位。
擊球區底部的高度(擊球手的膝蓋)以英尺為單位。
裁判所稱的投球(好球或壞球)的實際標簽。
結構 Architecture
該模型將通過使用 TensorFlow.js 圖層 API 定義。Layers API 基于 Keras,對以前使用過該框架的人來說應該很熟悉:
1const model = tf.sequential();
2
3// Two fully connected layers with dropout between each:
4model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
5model.add(tf.layers.dropout({rate: 0.01}));
6model.add(tf.layers.dense({units: 16, activation: 'relu'}));
7model.add(tf.layers.dropout({rate: 0.01}));
8
9// Only two classes: "strike" and "ball":
10model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
11
12model.compile({
13optimizer: tf.train.adam(0.01),
14loss: 'categoricalCrossentropy',
15metrics: ['accuracy']
16});
加載和準備數據
精選的訓練集可通過GitHub gist 獲得。需要下載此數據集才能開始將 CSV 數據轉換為 TensorFlow.js 用于訓練的格式。
注:GitHub gist 鏈接
https://gist.github.com/nkreeger/43edc6e6daecc2cb02a2dd3293a08f29
1const data = [];
2csvData.forEach((values) => {
3// 'logit' data uses the 5 fields:
4const x = [];
5x.push(parseFloat(values.px));
6x.push(parseFloat(values.pz));
7x.push(parseFloat(values.sz_top));
8x.push(parseFloat(values.sz_bot));
9x.push(parseFloat(values.left_handed_batter));
10// The label is simply 'is strike' or 'is ball':
11const y = parseInt(values.is_strike, 10);
12data.push({x: x, y: y});
13});
14// Shuffle the contents to ensure the model does not always train on the same
15// sequence of pitch data:
16tf.util.shuffle(data);
解析 CSV 數據后,需要將 JS 類型轉換為 Tensor 批次進行培訓和評估。有關此過程的詳細信息,請參閱代碼實驗室。TensorFlow.js 團隊正在開發一種新的 Data API,以便將來更容易獲取。
注:代碼實驗室
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#batches
訓練模型
讓我們把這一切都整合在一起吧。定義了模型,準備好了訓練數據,現在我們已經準備好開始訓練了。以下異步方法訓練一批訓練樣本并更新熱圖:
1// Trains and reports loss+accuracy for one batch of training data:
2async function trainBatch(index) {
3const history = await model.fit(batches[index].x, batches[index].y, {
4epochs: 1,
5shuffle: false,
6validationData: [batches[index].x, batches[index].y],
7batchSize: CONSTANTS.BATCH_SIZE
8});
9
10// Don't block the UI frame by using tf.nextFrame()
11await tf.nextFrame();
12updateHeatmap();
13await tf.nextFrame();
14}
可視化模型的準確性
使用來自均勻放置在本壘板上方的 4 x 4 英尺柵格的預測矩陣來構建熱圖。在每個訓練步驟之后將該矩陣傳遞到模型中以檢查模型的準確度。使用 D3 庫將該預測的結果呈現為熱圖。
構建預測矩陣
熱圖中使用的預測矩陣從本壘板的中間開始,向左和向右各延伸 2 英尺。它的范圍也從本壘板的底部到 4 英尺高。擊打區樣本位于本壘板上方 1.5 至 3.5 英尺之間。下圖有助于讓這些 2d 窗格可視化:
該視覺顯示了打擊區域和預測矩陣與本壘板和游戲區域相關的位置
將預測矩陣與模型一起使用
每個批次在模型中訓練之后,預測矩陣被傳遞到模型中用以請求矩陣中的好球或壞球預測:
1function predictZone() {
2const predictions = model.predictOnBatch(predictionMatrix.data);
3const values = predictions.dataSync();
4
5// Sort each value so the higher prediction is the first element in the array:
6const results = [];
7let index = 0;
8for (let i = 0; i < values.length; i++) { ? ?
9let list = [];
10list.push({value: values[index++], strike: 0});
11list.push({value: values[index++], strike: 1});
12list = list.sort((a, b) => b.value - a.value);
13results.push(list);
14}
15return results;
16}
熱圖與 D3
現在可以使用 D3 顯示預測結果。 來自 50x50 網格中的每一個元素將在 SVG 中呈現為 10px x 10px 的矩形。每個矩形的顏色取決于預測結果(好球或者壞球)以及模型對該結果的確定程度(范圍從 50%-100%)。 以下代碼段顯示了如何從 D3 svg 矩形分組更新數據:
1function updateHeatmap() {
2rects.data(generateHeatmapData());
3rects
4.attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
5.attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
6.attr('width', CONSTANTS.HEATMAP_SIZE)
7.attr('height', CONSTANTS.HEATMAP_SIZE)
8.style('fill', (coord) => {
9if (coord.strike) {
10return strikeColorScale(coord.value);
11} else {
12return ballColorScale(coord.value);
13}
14});
15}
有關使用 D3 繪制熱圖的完整詳細信息,請參閱此部分。
注:此部分鏈接
https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#colorDomain
總結
網絡上有許多令人驚嘆的第三方庫和工具,可用于創建視覺效果。將這些與機器學習的強大功能與 TensorFlow.js 相結合,開發人員能夠創建一些非常新奇有趣的演示。
-
神經網絡
+關注
關注
42文章
4776瀏覽量
100935 -
機器學習
+關注
關注
66文章
8428瀏覽量
132827 -
tensorflow
+關注
關注
13文章
329瀏覽量
60557
原文標題:棒球比賽中是好球還是壞球?TensorFlow.js 已經知道
文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論