自定义采样器¶
训练二分类数据集时,其正负样本数目相差巨大,同时要求每次批量128
个,其中32
个正样本,96
个负样本。除了使用加权随机采样外,还可以通过自定义采样器实现
自定义数据集¶
class CustomDataSet(Dataset):
def __init__(self):
"""
2类数据集,前950个元素是负样本,后50个元素是正样本
"""
self.data = list(range(1000))
def __getitem__(self, index: int):
return self.data[index], int(index > 950)
def __len__(self) -> int:
return len(self.data)
自定义采样器¶
class CustomSampler(Sampler):
def __init__(self) -> None:
self.idx_list = list(range(1000))
def __iter__(self):
"""
2分类数据集
每次批量处理128个数据,其中32个正样本,96个负样本
共进行7(或者8)次批量处理
"""
sampler_list = list()
for i in range(7):
sampler_list.extend(random.sample(self.idx_list[:950], 96))
sampler_list.extend(random.sample(self.idx_list[950:], 32))
return iter(sampler_list)
def __len__(self) -> int:
return 128 * 7
共批量采样7
次,每次从负样本中随机采集96
个样本,从正样本中随机采集32
个样本,保证每次批量数据符合训练要求
测试¶
实现如下:
if __name__ == '__main__':
data_set = CustomDataSet()
sampler = CustomSampler()
data_loader = DataLoader(data_set, batch_size=128, sampler=sampler, shuffle=False, num_workers=8)
print('epochs: %d' % len(data_loader))
for item in data_loader:
inputs, targets = item
print(inputs)
print(torch.sum(targets))
输出训练结果:
epochs: 7
tensor([344, 291, 336, 741, 218, 3, 38, 746, 17, 500, 734, 283, 855, 750,
292, 602, 647, 902, 375, 776, 328, 831, 447, 623, 348, 788, 792, 881,
264, 509, 493, 210, 60, 476, 520, 59, 326, 157, 643, 143, 607, 20,
251, 347, 323, 825, 578, 145, 942, 636, 482, 178, 926, 591, 135, 826,
820, 929, 113, 698, 72, 835, 778, 791, 937, 111, 556, 934, 170, 435,
471, 576, 311, 134, 872, 613, 313, 675, 771, 247, 949, 309, 421, 864,
774, 927, 194, 886, 179, 32, 138, 232, 316, 478, 255, 752, 963, 990,
992, 953, 996, 985, 966, 999, 994, 961, 972, 951, 980, 954, 988, 959,
981, 967, 976, 986, 977, 956, 969, 991, 960, 957, 997, 984, 989, 983,
970, 982])
tensor(32)
tensor([395, 611, 145, 649, 634, 38, 679, 683, 665, 42, 579, 777, 847, 372,
872, 906, 893, 834, 408, 625, 144, 909, 191, 217, 349, 44, 767, 99,
537, 13, 593, 403, 380, 352, 743, 889, 695, 412, 512, 162, 327, 801,
867, 258, 818, 629, 770, 234, 488, 785, 141, 303, 675, 632, 942, 708,
822, 300, 107, 504, 317, 286, 131, 239, 741, 486, 24, 299, 920, 235,
706, 605, 597, 904, 803, 417, 284, 656, 638, 929, 206, 383, 444, 464,
535, 154, 150, 795, 947, 291, 631, 732, 75, 771, 487, 887, 988, 974,
967, 989, 961, 991, 977, 970, 996, 954, 956, 984, 980, 975, 952, 964,
972, 960, 985, 998, 992, 983, 987, 976, 978, 990, 953, 965, 951, 966,
999, 963])
tensor(32)
tensor([432, 561, 691, 757, 337, 716, 642, 462, 917, 810, 657, 199, 632, 278,
393, 159, 274, 818, 389, 712, 411, 805, 779, 131, 611, 769, 109, 521,
413, 276, 750, 680, 674, 265, 867, 67, 751, 463, 250, 858, 322, 927,
775, 949, 422, 807, 192, 323, 499, 759, 81, 219, 118, 798, 826, 733,
386, 864, 875, 478, 298, 945, 613, 895, 835, 564, 237, 232, 676, 569,
221, 55, 28, 874, 547, 89, 268, 397, 519, 129, 253, 12, 217, 309,
87, 362, 246, 312, 765, 692, 51, 405, 602, 815, 392, 301, 974, 995,
966, 990, 959, 960, 954, 962, 965, 981, 952, 982, 996, 969, 977, 964,
986, 950, 979, 993, 978, 987, 970, 963, 951, 973, 983, 999, 953, 994,
958, 980])
tensor(32)
tensor([322, 922, 517, 447, 602, 563, 376, 253, 825, 934, 772, 168, 87, 684,
241, 312, 362, 449, 829, 639, 218, 739, 872, 48, 755, 46, 297, 686,
711, 628, 697, 38, 305, 518, 664, 583, 883, 381, 671, 339, 421, 696,
783, 47, 933, 191, 35, 345, 460, 899, 614, 747, 179, 879, 644, 18,
523, 430, 295, 145, 843, 636, 681, 214, 238, 80, 756, 753, 175, 142,
109, 271, 814, 412, 62, 427, 231, 67, 431, 44, 201, 651, 348, 718,
199, 156, 875, 39, 597, 942, 163, 346, 733, 654, 780, 207, 955, 956,
995, 982, 984, 993, 974, 985, 964, 980, 991, 972, 999, 975, 954, 957,
992, 970, 951, 950, 981, 969, 976, 986, 979, 967, 971, 961, 973, 966,
990, 958])
tensor(32)
tensor([589, 293, 425, 554, 323, 162, 301, 600, 805, 499, 172, 665, 568, 566,
797, 785, 245, 851, 332, 762, 570, 621, 871, 662, 404, 667, 616, 140,
590, 53, 946, 507, 423, 396, 101, 522, 571, 197, 781, 904, 424, 574,
795, 392, 731, 500, 278, 260, 312, 942, 151, 815, 888, 302, 923, 253,
615, 793, 833, 437, 109, 607, 399, 218, 913, 115, 732, 702, 131, 738,
884, 311, 835, 1, 111, 859, 749, 408, 420, 409, 484, 46, 13, 541,
877, 764, 920, 857, 750, 684, 909, 291, 588, 686, 864, 180, 966, 958,
973, 970, 952, 959, 962, 950, 981, 992, 989, 955, 960, 961, 975, 978,
994, 991, 953, 977, 995, 979, 998, 984, 967, 954, 983, 971, 976, 964,
951, 999])
tensor(32)
tensor([643, 433, 614, 244, 263, 479, 353, 7, 478, 507, 936, 848, 89, 880,
850, 192, 435, 762, 678, 726, 405, 55, 86, 591, 809, 426, 680, 237,
296, 688, 175, 529, 100, 797, 255, 946, 248, 724, 165, 61, 721, 108,
368, 424, 116, 775, 634, 466, 418, 874, 580, 562, 879, 459, 298, 687,
740, 603, 355, 271, 92, 899, 456, 905, 620, 528, 916, 3, 273, 441,
585, 416, 41, 903, 577, 682, 385, 284, 431, 297, 922, 67, 702, 758,
636, 14, 16, 73, 464, 605, 46, 885, 815, 746, 454, 587, 964, 963,
961, 981, 962, 970, 983, 976, 980, 959, 998, 991, 974, 972, 953, 979,
990, 973, 989, 960, 950, 988, 975, 977, 956, 957, 999, 958, 966, 986,
951, 996])
tensor(32)
tensor([ 11, 788, 27, 144, 516, 618, 386, 337, 174, 831, 482, 312, 916, 946,
371, 454, 852, 572, 763, 221, 240, 343, 395, 522, 199, 134, 643, 707,
559, 135, 737, 67, 658, 418, 150, 715, 351, 217, 40, 213, 623, 38,
205, 787, 360, 890, 676, 744, 160, 37, 264, 151, 530, 443, 871, 635,
689, 577, 861, 766, 729, 381, 156, 379, 704, 415, 432, 465, 246, 644,
497, 72, 885, 291, 357, 451, 214, 316, 655, 755, 362, 64, 856, 256,
382, 734, 446, 854, 815, 897, 869, 939, 889, 741, 912, 813, 980, 961,
995, 973, 991, 990, 954, 977, 988, 958, 952, 985, 997, 999, 976, 969,
974, 992, 965, 994, 956, 975, 978, 987, 951, 957, 962, 998, 970, 968,
972, 993])
tensor(32)