目標檢測領域沒有MNIST和Fashion-MNIST這樣的小型數據集。為了快速演示對象檢測模型,我們收集并標記了一個小型數據集。首先,我們從辦公室拍攝了免費香蕉的照片,并生成了 1000 張不同旋轉和大小的香蕉圖像。然后我們將每個香蕉圖像放置在一些背景圖像上的隨機位置。最后,我們為圖像上的那些香蕉標記了邊界框。
14.6.2。讀取數據集
我們將在 read_data_bananas
下面的函數中讀取香蕉檢測數據集。數據集包括一個 csv 文件,用于對象類標簽和左上角和右下角的地面實況邊界框坐標。
#@save
def read_data_bananas(is_train=True):
"""Read the banana detection dataset images and labels."""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
# Here `target` contains (class, upper-left x, upper-left y,
# lower-right x, lower-right y), where all the images have the same
# banana class (index 0)
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
#@save
def read_data_bananas(is_train=True):
"""Read the banana detection dataset images and labels."""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(image.imread(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
# Here `target` contains (class, upper-left x, upper-left y,
# lower-right x, lower-right y), where all the images have the same
# banana class (index 0)
targets.append(list(target))
return images, np.expand_dims(np.array(targets), 1) / 256
通過使用read_data_bananas
函數讀取圖像和標簽,下面的BananasDataset
類將允許我們創建一個自定義Dataset
實例來加載香蕉檢測數據集。
#@save
class BananasDataset(torch.utils.data.Dataset):
"""A customized dataset to load the banana detection dataset."""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
#@save
class BananasDataset(gluon.data.Dataset):
"""A customized dataset to load the banana detection dataset."""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].astype('float32').transpose(2, 0, 1),
self.labels[idx])
def __len__(self):
return len(self.features)
最后,我們定義load_data_bananas
函數為訓練集和測試集返回兩個數據迭代器實例。對于測試數據集,不需要隨機讀取。
讓我們讀取一個 minibatch 并打印這個 minibatch 中圖像和標簽的形狀。圖像小批量的形狀(批量大小、通道數、高度、寬度)看起來很熟悉:它與我們之前的圖像分類任務相同。label minibatch的shape是(batch size,m, 5), 其中m是任何圖像在數據集中具有的最大可能數量的邊界框。
雖然 minibatch 的計算效率更高,但它要求所有圖像示例都包含相同數量的邊界框,以通過連接形成一個 minibatch。通常,圖像可能具有不同數量的邊界框;因此,圖像少于m 邊界框將被非法邊界框填充,直到 m到達了。然后每個邊界框的標簽用一個長度為5的數組表示,數組的第一個元素是邊界框中物體的類,其中-1表示填充的非法邊界框。數組的其余四個元素是 (x,y)-邊界框左上角和右下角的坐標值(范圍在0到1之間)。對于香蕉數據集,由于每張圖像上只有一個邊界框,我們有m=1.
read 1000 training examples
read 100 validation examples
(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))
14.6.3。示范
讓我們演示十張帶有標記的真實邊界框的圖像。我們可以看到香蕉的旋轉、大小和位置在所有這些圖像中都不同。當然,這只是一個簡單的人工數據集。實際上,真實世界的數據集通常要復雜得多。
imgs = (batch[0][:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][:10]):
d2l.show_bboxes(ax, [label[0
評論
查看更多