Начните с создания простой сети, чтобы проиллюстрировать последующие примеры.
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6,8,kernel_size=3,padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self,x):
x = self.conv1(1)
x = self.bn1(x)
x = self.conv2(x)
x = self.relu(x)
return x
net = Net()
net.parameters(), вы можете получить параметры в конкретной модели сети:
for para in net.parameters():
print(para)
print(para.shape)
print()
'''
输出为:
Parameter containing:
tensor([[[[ 0.0794, 0.1070, 0.0415],
[ 0.0037, -0.0850, 0.0919],
[ 0.0039, 0.0899, -0.1446]],
[[-0.0642, 0.0251, -0.1055],
[ 0.1085, 0.0627, 0.0388],
[-0.0878, -0.1305, 0.1335]],
[[-0.0907, 0.0113, 0.1400],
[ 0.0051, -0.0605, -0.1085],
[ 0.0544, -0.0649, 0.0847]]],
[[[-0.1319, 0.0152, -0.0736],
[ 0.1796, 0.0857, 0.1668],
[ 0.0586, -0.1508, -0.1571]],
[[-0.1053, 0.0372, 0.1596],
[ 0.1509, 0.1125, -0.1773],
[ 0.0960, -0.0507, 0.0569]],
[[-0.0640, 0.0070, -0.1253],
[-0.1739, 0.0552, 0.1892],
[ 0.1232, -0.0811, 0.1263]]],
[[[ 0.0483, -0.1212, -0.0870],
[-0.0915, 0.0072, 0.1581],
[ 0.1184, -0.0907, 0.1109]],
[[-0.0024, 0.0980, -0.1080],
[-0.0311, -0.1013, -0.0581],
[ 0.1855, -0.0202, 0.0950]],
[[-0.1640, -0.0848, 0.0254],
[ 0.0318, 0.0538, 0.0277],
[ 0.0641, 0.0298, 0.0352]]],
[[[-0.0955, 0.0569, -0.0565],
[-0.1186, -0.0177, 0.0604],
[ 0.0305, -0.0398, -0.1165]],
[[-0.1532, 0.0179, 0.0317],
[ 0.0910, 0.1470, -0.1013],
[-0.0165, 0.0095, -0.0887]],
[[-0.0314, 0.1790, -0.1142],
[ 0.1710, -0.1628, 0.1342],
[-0.0781, 0.0194, -0.0568]]],
[[[-0.1903, -0.1659, -0.1797],
[ 0.1109, 0.0686, 0.1767],
[-0.0777, -0.0341, -0.1549]],
[[ 0.0615, -0.1309, -0.1492],
[ 0.1291, -0.1705, 0.1749],
[ 0.0173, -0.1587, 0.0072]],
[[-0.1669, -0.0803, 0.0378],
[ 0.1880, -0.0338, 0.1056],
[-0.0171, 0.0892, -0.0090]]],
[[[-0.1615, 0.1901, -0.1313],
[-0.0775, -0.0043, -0.0902],
[-0.0786, 0.0501, 0.0921]],
[[ 0.1332, 0.1698, 0.1657],
[ 0.0244, 0.0792, -0.1830],
[ 0.0519, -0.1610, 0.0821]],
[[-0.1437, 0.0229, -0.0810],
[-0.1200, 0.1311, 0.0776],
[ 0.0772, -0.0238, -0.0981]]]], requires_grad=True)
torch.Size([6, 3, 3, 3])
Parameter containing:
tensor([ 0.0599, -0.1511, -0.0591, 0.1000, 0.1050, 0.0743],
requires_grad=True)
torch.Size([6])
Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
torch.Size([6])
Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
torch.Size([6])
Parameter containing:
tensor([[[[-1.0620e-01, 5.6997e-02, -7.9542e-03],
[-6.6638e-02, -1.0529e-02, 1.3376e-01],
[ 7.1680e-02, 1.3388e-01, 1.2293e-01]],
[[ 9.2092e-02, 2.4215e-02, -1.2708e-01],
[ 1.9943e-03, -8.7654e-02, 1.0564e-01],
[-1.2967e-01, -1.2077e-01, -4.4365e-02]],
[[ 9.9798e-04, -7.9709e-02, 2.7571e-02],
[-1.4309e-02, 1.1243e-01, -1.1661e-01],
[ 7.5213e-02, 7.6132e-02, 1.4844e-02]],
[[ 1.2713e-01, -7.3697e-02, 9.4301e-02],
[ 7.7325e-02, 9.6845e-02, -1.0990e-01],
[ 6.2486e-02, 1.0107e-01, 3.0378e-02]],
[[-1.0599e-01, 2.7444e-02, -8.8193e-02],
[-1.0384e-01, 1.2580e-01, 4.1619e-02],
[ 1.3596e-01, -1.2098e-01, 8.2317e-02]],
[[-1.0979e-01, 9.2484e-02, -5.2828e-03],
[ 7.7915e-02, 6.0981e-02, 9.0634e-02],
[ 8.3001e-02, 7.1535e-02, -1.6206e-02]]],
[[[ 1.1561e-01, -2.1935e-02, -8.5694e-03],
[-4.9740e-03, -2.1594e-02, 9.7255e-02],
[ 1.2904e-01, 7.2028e-02, 9.6564e-02]],
[[-7.6498e-02, -1.2666e-01, -3.2563e-02],
[ 9.0076e-02, -8.3288e-02, 1.1785e-01],
[-4.3596e-02, 3.6950e-03, -5.0087e-02]],
[[-2.9787e-02, -5.2824e-02, -9.9231e-02],
[ 9.1963e-02, 7.7965e-02, 1.1397e-01],
[ 1.3667e-02, 1.1007e-01, -4.1288e-02]],
[[ 9.4790e-02, -6.8296e-02, -4.3310e-02],
[-6.3128e-02, 2.3350e-02, -6.3908e-02],
[-1.2005e-01, -6.2899e-02, -7.2392e-02]],
[[-1.1934e-01, -4.5716e-02, -5.7582e-02],
[ 8.1211e-06, 9.6752e-02, -4.1839e-02],
[ 9.9383e-02, -4.9952e-02, -4.1875e-02]],
[[ 1.0271e-01, -9.7970e-02, -2.5481e-02],
[ 1.2039e-01, 1.7195e-02, -2.2504e-02],
[ 6.3394e-02, -1.0446e-02, 9.7013e-02]]],
[[[-6.2230e-02, -8.0188e-02, -4.3593e-02],
[ 9.6622e-02, 7.5777e-02, 1.9751e-02],
[ 4.6756e-02, 8.1505e-02, 2.1734e-02]],
[[-4.0420e-02, -4.7027e-02, 2.7860e-02],
[-4.5530e-04, 1.0848e-01, -9.7263e-02],
[ 4.0441e-02, -2.3740e-03, -1.1751e-01]],
[[-1.0342e-01, 1.4509e-02, 3.5800e-02],
[-7.3109e-02, -4.4676e-02, 1.1477e-01],
[ 1.0436e-01, -1.1468e-01, 1.1279e-01]],
[[ 1.2757e-01, -5.4175e-02, 3.9229e-02],
[ 1.2238e-01, -4.1751e-02, 1.0329e-02],
[ 1.1175e-01, -1.3469e-01, 9.0738e-02]],
[[-1.2890e-01, 1.0985e-01, -3.5065e-02],
[-1.0353e-02, -1.1117e-01, -1.0932e-01],
[ 2.3825e-02, -5.1328e-02, 1.0952e-01]],
[[-1.2119e-01, -1.1721e-01, 3.9911e-02],
[-9.3294e-02, 3.6181e-02, -9.2453e-02],
[-1.0519e-01, 5.3727e-02, 4.4648e-03]]],
...,
[[[ 6.6163e-02, -1.0531e-01, -1.0589e-01],
[ 7.9671e-02, -3.3005e-02, -1.0760e-01],
[ 1.4868e-02, 1.4420e-02, -9.6573e-02]],
[[ 2.2414e-02, -1.5715e-02, 2.4232e-02],
[ 2.3479e-02, -8.7212e-02, -1.8911e-02],
[ 9.3712e-02, 1.0342e-01, 5.4269e-02]],
[[-9.8044e-02, 7.1834e-02, -1.0760e-01],
[-9.7597e-02, 9.9367e-02, -9.9010e-02],
[ 2.6155e-02, -1.3208e-01, 1.0316e-02]],
[[ 7.7097e-02, 1.0838e-01, 2.7527e-02],
[-4.3391e-02, 1.3416e-01, -1.1440e-01],
[-3.8224e-02, -2.7650e-03, -5.9436e-03]],
[[ 6.5886e-02, 1.1016e-02, -1.0989e-01],
[ 4.2206e-02, -9.2878e-02, 7.4586e-02],
[ 1.1299e-01, -1.1260e-01, -7.2581e-02]],
[[ 8.6093e-03, 3.0288e-02, 7.8243e-02],
[-6.7512e-03, -8.5671e-02, 8.3012e-02],
[-2.4528e-02, 1.7389e-02, 2.0112e-02]]],
[[[ 3.9985e-02, 6.4231e-03, 1.3579e-01],
[ 8.8007e-02, -1.8449e-02, 2.9483e-02],
[-5.8890e-02, 3.1275e-02, 1.1129e-01]],
[[ 9.9826e-02, -1.0343e-01, 1.7781e-02],
[-1.5528e-02, -1.2074e-01, -5.4819e-02],
[-8.1487e-02, 3.7535e-02, -6.7128e-02]],
[[-2.2612e-02, -4.7612e-02, -1.3335e-01],
[ 3.7972e-02, -1.2762e-01, 5.4009e-02],
[ 9.0579e-02, 5.4727e-02, -9.1461e-02]],
[[ 8.0858e-02, 1.4411e-03, -1.2739e-01],
[ 1.0097e-01, 8.3857e-02, -8.0914e-02],
[-1.9743e-02, 1.1509e-01, 8.2933e-02]],
[[-3.0184e-02, 1.0409e-01, 2.2486e-02],
[-7.8506e-02, -7.7744e-02, -2.8042e-02],
[-3.3265e-02, 9.1861e-02, 4.7874e-02]],
[[ 3.1688e-02, 1.2607e-01, 8.8575e-02],
[ 1.0217e-01, 2.8618e-02, 8.4546e-02],
[ 2.8103e-02, 1.2679e-01, 2.4444e-02]]],
[[[ 7.9484e-02, -1.1017e-02, -2.9063e-02],
[ 5.4235e-02, 1.1226e-01, -1.0663e-01],
[ 9.8365e-02, -2.1643e-02, 6.3686e-02]],
[[ 3.0368e-03, 1.2335e-03, 1.3460e-02],
[-5.6941e-02, -9.9266e-02, 3.3269e-02],
[ 8.6997e-02, 1.1879e-01, -1.2027e-02]],
[[ 3.4441e-02, 1.3346e-01, 1.4495e-03],
[ 6.1219e-02, 8.4678e-02, -4.3233e-02],
[ 1.3061e-01, -1.1880e-01, -1.2782e-01]],
[[ 3.4226e-02, 7.5535e-02, -7.4717e-02],
[ 8.2468e-03, -9.3862e-02, -5.3166e-02],
[ 1.3202e-01, 7.6724e-02, 6.3903e-02]],
[[-5.8022e-02, -7.8344e-02, -4.7197e-02],
[ 3.7977e-02, 8.6118e-02, 1.1670e-02],
[-1.3180e-01, -3.9207e-02, 1.3028e-01]],
[[-5.4157e-03, -7.3742e-02, 4.5027e-02],
[-2.8969e-02, -2.3086e-02, -3.3792e-02],
[ 7.5957e-02, 3.4847e-02, 1.3248e-01]]]], requires_grad=True)
torch.Size([32, 6, 3, 3])
Parameter containing:
tensor([-0.0801, 0.0075, 0.0469, -0.0886, 0.0583, -0.0399, -0.0551, 0.0094,
-0.0457, 0.1121, 0.0496, -0.0684, 0.1093, 0.0834, -0.0910, -0.1112,
-0.0711, -0.0641, -0.0981, 0.0356, 0.1234, -0.0284, 0.0813, 0.0188,
-0.0063, -0.0851, 0.1308, -0.0041, -0.0926, -0.0906, 0.1180, 0.0142],
requires_grad=True)
torch.Size([32])
'''
net.named_parameters() вернет две части, а именно имя атрибута в модели и соответствующее значение параметра:
for name, para in net.named_parameters():
print(name)
print(para)
print(para.shape)
print()
'''
输出例子:
conv1.weight
torch.Size([6, 3, 3, 3])
Parameter containing:
tensor([[[[ 0.0215, 0.1517, 0.1218],
[ 0.1887, -0.0702, 0.1366],
[-0.0947, 0.0794, -0.1096]],
[[-0.0045, 0.0683, -0.0814],
[ 0.0367, -0.0305, -0.1630],
[ 0.0413, 0.0197, 0.1726]],
[[ 0.0212, 0.1100, 0.0536],
[ 0.1513, 0.0163, 0.1070],
[-0.1378, -0.1698, 0.1431]]],
'''
Выход conv1.weight просто соответствует параметру веса в conv1 в self.conv1 = nn.Conv2d(3,6,kernal_size=1,padding=1), определенному выше.
Укажите, как параметр обновляется:
net = Net()
ignored_params = list( map(id, net.conv1.parameters()) )
base_params = filter ( lambda p: id(p) not in ignored_params, net.parameters() )
optimizer = torch.optim.SGD( [
{'params':base_params},
{'params':net.conv1.parameters(), 'lr':1e-3}
], lr = 1e-2, momentum=0.9 )
В {} вы можете указать метод обновления для некоторых параметров. Детали обновления, которые не установлены, могут быть указаны единообразно после ()