CIFAR-10 model structure
data:image/s3,"s3://crabby-images/96183/961839ee70348401aaa021e2dcce85620f2cbd04" alt="在这里插入图片描述"
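通过已知参数（输入/输出的高、宽，dilation=1，kernel_size=5）推断 stride 和 padding 的大小。PyTorch `Conv2d` 的输出尺寸公式为 $H_{out} = \left\lfloor \frac{H_{in} + 2\cdot\text{padding} - \text{dilation}\cdot(\text{kernel\_size}-1) - 1}{\text{stride}} + 1 \right\rfloor$（宽同理）。例如第一层卷积的输入和输出都是 32×32，取 stride=1，则 $32 = 32 + 2\cdot\text{padding} - 4$，解得 padding=2。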
data:image/s3,"s3://crabby-images/7f949/7f949e0d95319f9cae0e840d7c42f0b4f948856a" alt="在这里插入图片描述"
网络
import torch
from torch import nn
class Tudui(nn.Module):
    """Small CIFAR-10 style CNN: three conv + max-pool stages, then two linear layers.

    Every 5x5 convolution uses stride 1 and padding 2, so it keeps the spatial
    size, and every 2x2 max-pool halves it. For a 3x32x32 input the feature
    map after the last pool is 64x4x4, which flattens to the 1024 features
    that ``linear1`` expects. The final layer produces 10 outputs per image.
    """

    def __init__(self):
        super().__init__()
        # Keep this assignment order: it fixes the order in which submodules
        # are printed and in which their parameters are created/enumerated.
        self.conv1 = nn.Conv2d(3, 32, 5, stride=1, padding=2)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.maxpool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        # 1024 = 64 channels * 4 * 4, valid for 32x32 inputs.
        self.linear1 = nn.Linear(1024, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        """Apply every stage in order and return the final linear layer's output."""
        stages = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.conv3, self.maxpool3,
            self.flatten,
            self.linear1, self.linear2,
        )
        for stage in stages:
            x = stage(x)
        return x
# Build the network and print its layer-by-layer summary.
tudui = Tudui()
print(tudui)
data:image/s3,"s3://crabby-images/55af7/55af79458497c5af7cfd68d7d95c6f6452e16b91" alt="在这里插入图片描述"
对网络进行检验
# Sanity check: push a dummy batch of 64 all-ones images (3 channels, 32x32)
# through the network and print the shape of the result.
input = torch.ones(64, 3, 32, 32)
output = tudui(input)
print(output.shape)
data:image/s3,"s3://crabby-images/5ae17/5ae1734e233d10c2055ea2c531c466123a282121" alt="在这里插入图片描述"
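如果不知道线性层的输入特征数是多少，可以先注释掉线性层，运行一次前向传播，查看 Flatten 之后的输出形状（这里是 1024）。也可以直接计算：三次 2×2 最大池化把 32×32 缩小为 4×4，最后一层卷积输出 64 个通道，因此展平后的特征数为 $64 \times 4 \times 4 = 1024$。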
data:image/s3,"s3://crabby-images/3a6ba/3a6ba01f91813253c758a4901098ae5f44317fec" alt="在这里插入图片描述"
使用nn.Sequential
import torch
from torch import nn
class Tudui(nn.Module):
    """CIFAR-10 style CNN whose layers all live in one ``nn.Sequential``.

    ``model1`` holds three (5x5 conv, 2x2 max-pool) stages, a flatten, and two
    linear layers (1024 -> 64 -> 10). The 1024 input features assume a
    3x32x32 image: the pools reduce 32x32 to 4x4 and the last conv emits
    64 channels.
    """

    def __init__(self):
        super().__init__()
        # Order matters: Sequential names its children "0".."8", and those
        # indices appear in parameter names (e.g. "model1.7.weight").
        layers = [
            nn.Conv2d(3, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        ]
        self.model1 = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through ``model1`` and return its output."""
        return self.model1(x)
# Build the Sequential-based model and print its structure, then push a dummy
# batch of 64 all-ones 3x32x32 images through it and print the output shape.
tudui = Tudui()
print(tudui)
input = torch.ones(64, 3, 32, 32)
output = tudui(input)
print(output.shape)
data:image/s3,"s3://crabby-images/71f9b/71f9b0ac88a094df5ec8e359d54c1512a1e939d6" alt="在这里插入图片描述"
可视化模型结构
# Record the model graph for TensorBoard (view with: tensorboard --logdir=logs_seq).
# SummaryWriter must be imported explicitly; without this line the snippet
# fails with NameError.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs_seq')
try:
    # add_graph traces `tudui` on the example batch `input` to build the graph.
    writer.add_graph(tudui, input)
finally:
    # Always flush and close the event file, even if tracing fails.
    writer.close()
data:image/s3,"s3://crabby-images/4e12e/4e12eccfc5bf3443ac08582aaa6fa41ebe13fb25" alt="在这里插入图片描述"