Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz Extracting ./data\cifar-10-python.tar.gz to ./data Files already downloaded and verified
如果你是在 Windows 系统下运行上述代码,并且出现报错信息 BrokenPipeError,可以尝试将 torch.utils.data.DataLoader() 中的 num_workers 设置为 0。
defforward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
# 打印预测标签的结果 print('预测值:', ' '.join('%5s' % classes[predicted[j]] for j inrange(4)))
输出结果:
1
预测值: cat truck car ship
接下来看一下在全部测试集上的表现:
1 2 3 4 5 6 7 8 9 10 11 12 13
correct = 0 total = 0 net = Net() net.load_state_dict(torch.load(PATH)) with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('10000 个测试数据的准确率为:%d%%' % (correct / total * 100))
class_correct = list(0.for i inrange(10)) class_total = list(0.for i inrange(10)) net = Net() net.load_state_dict(torch.load(PATH)) with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() for i inrange(len(labels)): label = labels[i] class_correct[label] += c[i].item() class_total[label] += 1
for i inrange(10): print('%5s 的准确率为:%2d%%' % ( classes[i], class_correct[i] / class_total[i] * 100if class_total[i] else0))
输出结果:
1 2 3 4 5 6 7 8 9 10
plane 的准确率为:43% car 的准确率为:68% bird 的准确率为:42% cat 的准确率为:30% deer 的准确率为:32% dog 的准确率为:58% frog 的准确率为:63% horse 的准确率为:69% ship 的准确率为:67% truck 的准确率为:68%