【深度学习】深度学习中的单元测试

机器学习初学者

共 13484字,需浏览 27分钟

 · 2021-06-16




作者 | Manpreet Singh Minhas

编译 | VK
来源 | Towards Data Science

深度学习/机器学习工作流程通常不同于人们对正常软件开发过程的期望。但这并不意味着人们不应该从这些年来不断发展的软件开发中汲取灵感并进行实践。

在本文中,我将讨论单元测试以及为什么以及如何在代码中包含这些测试。我们将首先简要介绍单元测试,然后是一个深度学习中的单元测试示例,以及如何通过命令行和VS代码测试资源管理器运行这些测试。

介绍

单元测试是软件开发人员熟悉的概念。这是一种非常有用的技术,可以帮助你防止代码中出现明显的错误和bug。它包括测试源代码的各个单元,如函数、方法和类,以确定它们是否满足要求并具有预期的行为。

单元测试通常很小,执行起来不需要太多时间。测试的输入范围很广,通常包括边界和边缘情况。这些输入的输出通常由开发人员手动计算,以测试被测试单元的输出。

例如,对于加法器函数,我们将有如下测试用例。(稍后我们将看到一个深度学习的示例。)

你可以用正输入、零输入、负输入、正输入和负输入测试用例。

如果我们正在测试的函数/方法的输出与单元测试中为所有输入案例定义的输出相等,那么你的单元将通过测试,否则它将失败。你将确切地知道哪个测试用例失败。可以进一步调查,找出问题所在。

如果有多个开发人员正在处理一个大型项目。假设有人基于某些假设和数据大小编写了一段代码,而新的开发人员更改了代码库中不再满足这些假设的内容。那么代码肯定会失败。单元测试允许避免这种情况。

下面是单元测试的一些好处。

  • 强制你编写具有明确定义的输入和输出的模块化和可重用代码。因此,你的代码将更易于集成。

  • 提高了更改/维护代码的信心。它有助于识别代码更改引入的bug。

  • 提高了对单元本身的信心,因为如果它通过了单元测试,我们可以确定逻辑没有明显的错误,并且单元按预期运行。

  • 调试变得更容易,因为你可以知道哪个单元失败了,以及哪些特定的测试用例失败了。

Python中的单元测试

每种语言都有自己的工具和包可用于进行单元测试。Python还提供了一些单元测试框架。unittest包是标准Python库的一部分。

我将讨论如何通过命令行/bash和VS Code UI界面来使用这个框架。它的灵感来自JUnit,与其他语言中的主要单元测试框架有相似的风格。它支持测试自动化、共享测试的设置和关闭代码、将测试聚合到集合中以及独立于测试的报告框架[4]。

在这个框架中,单元测试的基本构建块是测试用例——必须设置并检查其正确性的场景。在unittest中,测试用例是unittest.TestCase。要生成测试用例,必须编写TestCase的子类。

TestCase实例的测试用例应该是自包含的,这样它可以单独运行,也可以与任何数量的其他测试用例任意组合运行。TestCase子类的测试方法应该在名称中有test前缀,并执行特定的测试代码。

为了执行测试,TestCase基类有几个assert方法,允许你对照被测试单元的输出检查测试用例的输出。如果测试失败,将引发异常并给出解释性消息,unittest将测试用例标识为失败。任何其他异常都将被视为错误。

有两种类型的setup方法可用于为测试设置类。

  1. setUp -这将在类中的每个测试方法之前调用。

  2. setUpClass-整个类只运行一次。这是你应该用来做深度学习测试的方法。在此方法中加载模型,以避免在执行每个测试方法之前重新加载模型。这将节省模型重新加载时间。

请注意,各种测试的运行顺序是通过根据字符串的内置顺序对测试方法名称进行排序来确定的。

现在让我们看看我为一个项目的PyTorch数据加载器而创建的单元测试。代码如下所示。

import unittest
from pathlib import Path

import torch
from PIL import Image
from segdataset import SegmentationDataset
from torch.utils.data import DataLoader
from torchvision import transforms


class Test_TestSegmentationDataset(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        seg_dataset = SegmentationDataset("CrackForest",
                                          "Images",
                                          "Masks",
                                          transforms=transforms.Compose(
                                              [transforms.ToTensor()]))
        seg_dataloader = DataLoader(seg_dataset,
                                    batch_size=4,
                                    shuffle=False,
                                    num_workers=8)
        cls.samples = next(iter(seg_dataloader))

    def test_image_tensor_dimensions(self):
        image_tensor_shape = Test_TestSegmentationDataset.samples[
            'image'].shape
        self.assertEqual(image_tensor_shape[0], 4)
        self.assertEqual(image_tensor_shape[1], 3)
        self.assertEqual(image_tensor_shape[2], 320)
        self.assertEqual(image_tensor_shape[3], 480)

    def test_mask_tensor_dimensions(self):
        mask_tensor_shape = Test_TestSegmentationDataset.samples['mask'].shape
        self.assertEqual(mask_tensor_shape[0], 4)
        self.assertEqual(mask_tensor_shape[1], 1)
        self.assertEqual(mask_tensor_shape[2], 320)
        self.assertEqual(mask_tensor_shape[3], 480)

    def test_mask_img_pair(self):
        ref_image_tensor = transforms.ToTensor()(Image.open(
            Path("CrackForest/Images/001.jpg")))
        ref_mask_tensor = transforms.ToTensor()(Image.open(
            Path("CrackForest/Masks/001_label.PNG")))
        datagen_image_tensor = Test_TestSegmentationDataset.samples['image'][0]
        datagen_mask_tensor = Test_TestSegmentationDataset.samples['mask'][0]
        self.assertTrue(torch.equal(ref_image_tensor, datagen_image_tensor))
        self.assertTrue(torch.equal(ref_mask_tensor, datagen_mask_tensor))
© 2021 GitHub, Inc.

被测试的分割数据集需要批量加载相应的图像和mask对。将正确的图像映射到正确的mask是至关重要的。

为此,通常,图像和mask的名称中都有相同的数字。如果你正在通过一些增强来调整图像的大小,那么你的结果大小应该与预期的一样。对于PyTorch,数据加载器返回的张量应该是BxCxHxW形式,其中B是批大小,C是通道数,H是高度,W是宽度。

现在,我来解释代码中发生了什么。我创建了一个从unittest.TestCase测试用例基类。如前所述,我创建了一个setUpClass方法,它是一个类方法,用于确保初始化只执行一次。

class Test_TestSegmentationDataset(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        seg_dataset = SegmentationDataset("CrackForest",
                                          "Images",
                                          "Masks",
                                          transforms=transforms.Compose(
                                              [transforms.ToTensor()]))
        seg_dataloader = DataLoader(seg_dataset,
                                    batch_size=4,
                                    shuffle=False,
                                    num_workers=8)
        cls.samples = next(iter(seg_dataloader))

这里需要注意的一点是,为了测试,我在dataloader中禁用了shuffle。因为我希望名称中带有001的映像和mask出现在dataloader创建的第一批的索引0中。

从不同的批次中检查不同的样本索引将是一个更好的测试,因为你将确保不同批次的顺序是一致的。我把第一批储存在cls作为类属性。

现在初始化完成了,我们来看看各个测试。

在第一个测试中,我检查dataloader返回的图像张量维度。因为我没有调整大小的图像,我希望大小为320x480和这些图像正在读取为RGB,所以应该有3个通道。在setUpClass方法中,我将批大小指定为4,因此张量的第一个维度应该是4。如果尺寸有问题,这个测试就会失败。

    def test_image_tensor_dimensions(self):
        image_tensor_shape = Test_TestSegmentationDataset.samples[
            'image'].shape
        self.assertEqual(image_tensor_shape[0], 4)
        self.assertEqual(image_tensor_shape[1], 3)
        self.assertEqual(image_tensor_shape[2], 320)
        self.assertEqual(image_tensor_shape[3], 480)

下一个测试是完全相同的,除了它是为mask张量。在这个特定的数据集中,mask只有一个通道。所以我希望通道数是1。批量大小应为4。mask形状应为320x480。

    def test_mask_tensor_dimensions(self):
        mask_tensor_shape = Test_TestSegmentationDataset.samples['mask'].shape
        self.assertEqual(mask_tensor_shape[0], 4)
        self.assertEqual(mask_tensor_shape[1], 1)
        self.assertEqual(mask_tensor_shape[2], 320)
        self.assertEqual(mask_tensor_shape[3], 480)

最后一个测试检查两件事。首先是通过手动应用dataloader中指定的变换获得的张量是否产生与dataloader相同的结果。其次是图像和mask对是正确的。

要直接应用torchvision变换,需要实例化transform并将图像作为输入传递给该实例。如果transform需要一个PIL图像或numpy数组(对于ToTensor就是这种情况),任何其他格式都会导致错误。

    def test_mask_img_pair(self):
        ref_image_tensor = transforms.ToTensor()(Image.open(
            Path("CrackForest/Images/001.jpg")))
        ref_mask_tensor = transforms.ToTensor()(Image.open(
            Path("CrackForest/Masks/001_label.PNG")))
        datagen_image_tensor = Test_TestSegmentationDataset.samples['image'][0]
        datagen_mask_tensor = Test_TestSegmentationDataset.samples['mask'][0]
        self.assertTrue(torch.equal(ref_image_tensor, datagen_image_tensor))
        self.assertTrue(torch.equal(ref_mask_tensor, datagen_mask_tensor))

现在我们已经准备好了unittest,让我们先看看如何通过命令行运行这个测试。

可以使用以下命令:

python -m unittest discover -s Tests -p "test_*"

一旦指定了搜索目录和搜索模式,Unittest就可以发现测试。

-s或--start directory directory:它指定开始发现目录。在我们的例子中,由于测试位于tests文件夹中,所以我们将该文件夹指定为该标志的值。

-p或--pattern:它指定匹配模式。我指定了一个自定义模式,只是为了向你展示这个功能是可用的。因为默认模式是test*.py,所以它在默认情况下适用于我们的测试脚本。

-v或--verbose:如果你指定这个值,你将获得测试类中每个测试方法的输出。

非详细输出和详细输出如下所示。如果所有的测试方法都通过了,那么最后会收到一条OK消息。

但是,如果任何一个测试方法失败,你将得到一条失败消息,其中指定了失败的测试。你会知道哪个断言失败了。如前所述,这对调试和查找破坏代码的原因非常有帮助。在本例中,我更改了正在读取的图像,但没有更改正在比较的张量,这导致了错误。

你可以将此测试执行行包含在任何自动批处理或bash文件中,这些文件可用于自动部署。例如,我们在GitHub操作中使用类似的测试,在更新版本自动推送到包存储库之前自动验证代码是否工作。

接下来,我将向你展示如何使用VS代码测试资源管理器通过UI运行这些测试。

在VS Code[3]中运行Python单元测试

在VS代码中,Python中的测试在默认情况下是禁用的。

要启用测试,请在命令Pallete上使用Python:configuretests命令。此命令提示你选择测试框架、包含测试的文件夹以及用于标识测试文件的模式。

最后两个输入与我们用于通过命令行运行单元测试的输入完全相同。Unittest框架不需要进一步安装。但是,如果你选择的框架包没有安装在你的环境中,VS代码会提示你安装它。

一旦发现被正确设置,我们将在VS代码活动栏中看到带有图标的测试资源管理器。测试资源管理器帮助你可视化、导航和运行测试。

你还可以在测试脚本中看到直接可用的运行测试和调试测试选项。你可以从该视图运行所有或单个测试,还可以导航到不同类中的单个测试方法。

如果测试失败,我会出现一个红色的十字而不是绿色的勾号。如果你想节省时间,你可以选择只运行失败的测试,而不是再次运行所有测试。

结论

本文结束了关于深度学习单元测试的文章。我们简要地了解了什么是单元测试以及它们的好处。

接下来,我们介绍了一个使用unittest包框架用PyTorch编写的数据加载器单元的实际示例。我们学习了如何通过命令行和Python测试资源管理器从VS代码运行这些测试。

我希望你开始为代码编写单元测试并从中获益!谢谢你阅读这篇文章。代码位于:https://github.com/msminhas93/deeplabv3finetunning

参考引用

[1]https://softwaretestingfundamentals.com/unit-testing/

[2]https://www.tutorialspoint.com/unittest_framework/unittest_framework_overview.htm

[3]https://code.visualstudio.com/docs/python/testing

[4]https://docs.python.org/3/library/unittest.html

[5]https://stackoverflow.com/questions/23667610/what-is-the-difference-between-setup-and-setupclass-in-python-unittest/23670844

往期精彩回顾





本站qq群851320808,加入微信群请扫码:

浏览 31
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报