自定义采样器

训练二分类数据集时,其正负样本数目相差巨大,同时要求每次批量128个,其中32个正样本,96个负样本。除了使用加权随机采样外,还可以通过自定义采样器实现

自定义数据集

class CustomDataSet(Dataset):
    def __init__(self):
        """
        2类数据集,前950个元素是负样本,后50个元素是正样本
        """
        self.data = list(range(1000))

    def __getitem__(self, index: int):
        return self.data[index], int(index > 950)

    def __len__(self) -> int:
        return len(self.data)

自定义采样器

class CustomSampler(Sampler):

    def __init__(self) -> None:
        self.idx_list = list(range(1000))

    def __iter__(self):
        """
        2分类数据集
        每次批量处理128个数据,其中32个正样本,96个负样本
        共进行7(或者8)次批量处理
        """
        sampler_list = list()
        for i in range(7):
            sampler_list.extend(random.sample(self.idx_list[:950], 96))
            sampler_list.extend(random.sample(self.idx_list[950:], 32))
        return iter(sampler_list)

    def __len__(self) -> int:
        return 128 * 7

共批量采样7次,每次从负样本中随机采集96个样本,从正样本中随机采集32个样本,保证每次批量数据符合训练要求

测试

实现如下:

if __name__ == '__main__':
    data_set = CustomDataSet()

    sampler = CustomSampler()

    data_loader = DataLoader(data_set, batch_size=128, sampler=sampler, shuffle=False, num_workers=8)
    print('epochs: %d' % len(data_loader))
    for item in data_loader:
        inputs, targets = item
        print(inputs)
        print(torch.sum(targets))

输出训练结果:

epochs: 7
tensor([344, 291, 336, 741, 218,   3,  38, 746,  17, 500, 734, 283, 855, 750,
        292, 602, 647, 902, 375, 776, 328, 831, 447, 623, 348, 788, 792, 881,
        264, 509, 493, 210,  60, 476, 520,  59, 326, 157, 643, 143, 607,  20,
        251, 347, 323, 825, 578, 145, 942, 636, 482, 178, 926, 591, 135, 826,
        820, 929, 113, 698,  72, 835, 778, 791, 937, 111, 556, 934, 170, 435,
        471, 576, 311, 134, 872, 613, 313, 675, 771, 247, 949, 309, 421, 864,
        774, 927, 194, 886, 179,  32, 138, 232, 316, 478, 255, 752, 963, 990,
        992, 953, 996, 985, 966, 999, 994, 961, 972, 951, 980, 954, 988, 959,
        981, 967, 976, 986, 977, 956, 969, 991, 960, 957, 997, 984, 989, 983,
        970, 982])
tensor(32)
tensor([395, 611, 145, 649, 634,  38, 679, 683, 665,  42, 579, 777, 847, 372,
        872, 906, 893, 834, 408, 625, 144, 909, 191, 217, 349,  44, 767,  99,
        537,  13, 593, 403, 380, 352, 743, 889, 695, 412, 512, 162, 327, 801,
        867, 258, 818, 629, 770, 234, 488, 785, 141, 303, 675, 632, 942, 708,
        822, 300, 107, 504, 317, 286, 131, 239, 741, 486,  24, 299, 920, 235,
        706, 605, 597, 904, 803, 417, 284, 656, 638, 929, 206, 383, 444, 464,
        535, 154, 150, 795, 947, 291, 631, 732,  75, 771, 487, 887, 988, 974,
        967, 989, 961, 991, 977, 970, 996, 954, 956, 984, 980, 975, 952, 964,
        972, 960, 985, 998, 992, 983, 987, 976, 978, 990, 953, 965, 951, 966,
        999, 963])
tensor(32)
tensor([432, 561, 691, 757, 337, 716, 642, 462, 917, 810, 657, 199, 632, 278,
        393, 159, 274, 818, 389, 712, 411, 805, 779, 131, 611, 769, 109, 521,
        413, 276, 750, 680, 674, 265, 867,  67, 751, 463, 250, 858, 322, 927,
        775, 949, 422, 807, 192, 323, 499, 759,  81, 219, 118, 798, 826, 733,
        386, 864, 875, 478, 298, 945, 613, 895, 835, 564, 237, 232, 676, 569,
        221,  55,  28, 874, 547,  89, 268, 397, 519, 129, 253,  12, 217, 309,
         87, 362, 246, 312, 765, 692,  51, 405, 602, 815, 392, 301, 974, 995,
        966, 990, 959, 960, 954, 962, 965, 981, 952, 982, 996, 969, 977, 964,
        986, 950, 979, 993, 978, 987, 970, 963, 951, 973, 983, 999, 953, 994,
        958, 980])
tensor(32)
tensor([322, 922, 517, 447, 602, 563, 376, 253, 825, 934, 772, 168,  87, 684,
        241, 312, 362, 449, 829, 639, 218, 739, 872,  48, 755,  46, 297, 686,
        711, 628, 697,  38, 305, 518, 664, 583, 883, 381, 671, 339, 421, 696,
        783,  47, 933, 191,  35, 345, 460, 899, 614, 747, 179, 879, 644,  18,
        523, 430, 295, 145, 843, 636, 681, 214, 238,  80, 756, 753, 175, 142,
        109, 271, 814, 412,  62, 427, 231,  67, 431,  44, 201, 651, 348, 718,
        199, 156, 875,  39, 597, 942, 163, 346, 733, 654, 780, 207, 955, 956,
        995, 982, 984, 993, 974, 985, 964, 980, 991, 972, 999, 975, 954, 957,
        992, 970, 951, 950, 981, 969, 976, 986, 979, 967, 971, 961, 973, 966,
        990, 958])
tensor(32)
tensor([589, 293, 425, 554, 323, 162, 301, 600, 805, 499, 172, 665, 568, 566,
        797, 785, 245, 851, 332, 762, 570, 621, 871, 662, 404, 667, 616, 140,
        590,  53, 946, 507, 423, 396, 101, 522, 571, 197, 781, 904, 424, 574,
        795, 392, 731, 500, 278, 260, 312, 942, 151, 815, 888, 302, 923, 253,
        615, 793, 833, 437, 109, 607, 399, 218, 913, 115, 732, 702, 131, 738,
        884, 311, 835,   1, 111, 859, 749, 408, 420, 409, 484,  46,  13, 541,
        877, 764, 920, 857, 750, 684, 909, 291, 588, 686, 864, 180, 966, 958,
        973, 970, 952, 959, 962, 950, 981, 992, 989, 955, 960, 961, 975, 978,
        994, 991, 953, 977, 995, 979, 998, 984, 967, 954, 983, 971, 976, 964,
        951, 999])
tensor(32)
tensor([643, 433, 614, 244, 263, 479, 353,   7, 478, 507, 936, 848,  89, 880,
        850, 192, 435, 762, 678, 726, 405,  55,  86, 591, 809, 426, 680, 237,
        296, 688, 175, 529, 100, 797, 255, 946, 248, 724, 165,  61, 721, 108,
        368, 424, 116, 775, 634, 466, 418, 874, 580, 562, 879, 459, 298, 687,
        740, 603, 355, 271,  92, 899, 456, 905, 620, 528, 916,   3, 273, 441,
        585, 416,  41, 903, 577, 682, 385, 284, 431, 297, 922,  67, 702, 758,
        636,  14,  16,  73, 464, 605,  46, 885, 815, 746, 454, 587, 964, 963,
        961, 981, 962, 970, 983, 976, 980, 959, 998, 991, 974, 972, 953, 979,
        990, 973, 989, 960, 950, 988, 975, 977, 956, 957, 999, 958, 966, 986,
        951, 996])
tensor(32)
tensor([ 11, 788,  27, 144, 516, 618, 386, 337, 174, 831, 482, 312, 916, 946,
        371, 454, 852, 572, 763, 221, 240, 343, 395, 522, 199, 134, 643, 707,
        559, 135, 737,  67, 658, 418, 150, 715, 351, 217,  40, 213, 623,  38,
        205, 787, 360, 890, 676, 744, 160,  37, 264, 151, 530, 443, 871, 635,
        689, 577, 861, 766, 729, 381, 156, 379, 704, 415, 432, 465, 246, 644,
        497,  72, 885, 291, 357, 451, 214, 316, 655, 755, 362,  64, 856, 256,
        382, 734, 446, 854, 815, 897, 869, 939, 889, 741, 912, 813, 980, 961,
        995, 973, 991, 990, 954, 977, 988, 958, 952, 985, 997, 999, 976, 969,
        974, 992, 965, 994, 956, 975, 978, 987, 951, 957, 962, 998, 970, 968,
        972, 993])
tensor(32)