庖丁解牛式的学习,才能真正的一劳永逸!
这是CVHub公众号的第五篇原创文章,也是《学术小白也能看懂的学术进阶专栏》(计算机视觉方向)的第五篇文章!
前言
欢迎大家来到CVHub学习。本文主要介绍的是使用Pytorch搭建分类网络,利用最简单易懂的代码带领大家体会搭神经网络如同搭乐高积木一般的乐趣。
PyTorch介绍
PyTorch是torch的python版本,是由Facebook开源的神经网络框架。与Tensorflow的静态计算图不同,pytorch的计算图是动态的,可以根据计算需要实时改变计算图。PyTorch的开发重点是支持快速的实验。能简单快速地把idea转换为代码形式进行实验,是做好研究的关键。
数据集
我们使用的是最常见且具有代表性的数据集,CIFAR-10,这个数据集的计算量大概是MNIST的10倍左右,CPU也勉强能跑。CIFAR-10是由60000张32×32彩色图片组合的10分类任务,其中训练图片共有50000张,测试图片共有10000张,每个类有6000张图片。CIFAR-10数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar.html
ResNet分类网络介绍ResNet:由Kaiming He于2015年提出,在ImageNet比赛classification任务上获得第一名。ResNet提出了残差块(Residual Block),高效地解决了在深层网络中可能会出现的梯度弥散/爆炸的问题。
Plain block vs Residual block
代码解读
本次实验的详细代码在我的GitHub仓库中可以下载。Github地址(点击文末左下角阅读原文可直接挑转):https://github.com/CVHuber/Classification-getting-startedtrain.py
样本训练流程:deftrain(model, device, train_loader, optimizer):# 设置为训练模式,会更新BN和Dropout参数model.train()epoch_loss = 0# 训练forbatch_idx, (data, target) inenumerate(train_loader):# 将训练样本和掩模加载到GPUdata, target = data.to(device), target.to(device)# 将数据喂入网络,获得一个预测结果output = model(data)# 通过交叉熵函数计算预测结果和标签之间的lossloss = F.cross_entropy(output, target)# loss累加epoch_loss += loss.item()# 梯度清零optimizer.zero_grad()# 通过loss求出梯度loss.backward()# 使用Adam进行梯度回传optimizer.step()print(‘loss=%.3f’% (100.* epoch_loss / len(train_loader.dataset)))
resnet_model.py样本编码流程:defforward(self, x):# 常规卷积序列:Conv->BN->ReLU->Maxpooling# 提取感受野x = self.conv1(x)# 参数归一化x = self.bn1(x)# 激活函数,非线性转换x = self.relu(x)# 图片下采样x = self.maxpool(x)# 四个残差卷积块x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 平均池化x = self.avgpool(x)# 修改Tensor维度x = x.view(x.size(0), -1)# 全连接映射x = self.fc(x)returnx
eval.pyaccuracy计算流程:# 计算topk准确率defaccuracy(output, target, topk=(1,)):maxk = max(topk)batch_size = target.size(0)# 选出k个概率最高值_, pred = output.topk(maxk, 1, True, True)# tensor转置pred = pred.t()# 计算正确的个数correct = pred.eq(target.view(1, -1).expand_as(pred))res = []# 计算topk值fork intopk:correct_k = correct[:k].view(-1).float().sum(0)res.append(correct_k.mul_(100.0/ batch_size))returnres
结语本文仅展示了核心代码的注释说明,其余部分介绍以及完整代码工程文件可参看我的Github。
Github地址(点击文末左下角阅读原文可直接挑转):https://github.com/CVHuber/Classification-getting-started
所有代码均写有详细注释。交流渠道微信群(长按加入):