Pytorch入门之mnist分类实例


本文摘自php中文网,作者不言,侵删。

这篇文章主要为大家详细介绍了Pytorch入门之mnist分类实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

#!/usr/bin/env python

# -*- coding: utf-8 -*-

__author__ = 'denny'

__time__ = '2017-9-9 9:03'

 

import torch

import torchvision

from torch.autograd import Variable

import torch.utils.data.dataloader as Data

 

train_data = torchvision.datasets.MNIST(

 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True

)

test_data = torchvision.datasets.MNIST(

 './mnist', train=False, transform=torchvision.transforms.ToTensor()

)

print("train_data:", train_data.train_data.size())

print("train_labels:", train_data.train_labels.size())

print("test_data:", test_data.test_data.size())

 

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

test_loader = Data.DataLoader(dataset=test_data, batch_size=64)

 

 

class Net(torch.nn.Module):

 def __init__(self):

 super(Net, self).__init__()

 self.conv1 = torch.nn.Sequential(

  torch.nn.Conv2d(1, 32, 3, 1, 1),

  torch.nn.ReLU(),

  torch.nn.MaxPool2d(2))

 self.conv2 = torch.nn.Sequential(

  torch.nn.Conv2d(32, 64, 3, 1, 1),

  torch.nn.ReLU(),

  torch.nn.MaxPool2d(2)

 )

 self.conv3 = torch.nn.Sequential(

  torch.nn.Conv2d(64, 64, 3, 1, 1),

  torch.nn.ReLU(),

  torch.nn.MaxPool2d(2)

 )

 self.dense = torch.nn.Sequential(

  torch.nn.Linear(64 * 3 * 3, 128),

  torch.nn.ReLU(),

  torch.nn.Linear(128, 10)

 )

 

 def forward(self, x):

 conv1_out = self.conv1(x)

 conv2_out = self.conv2(conv1_out)

 conv3_out = self.conv3(conv2_out)

 res = conv3_out.view(conv3_out.size(0), -1)

 out = self.dense(res)

 return out

 

 

model = Net()

print(model)

 

optimizer = torch.optim.Adam(model.parameters())

loss_func = torch.nn.CrossEntropyLoss()

 

for epoch in range(10):

 print('epoch {}'.format(epoch + 1))

 # training-----------------------------

 train_loss = 0.

 train_acc = 0.

 for batch_x, batch_y in train_loader:

 batch_x, batch_y = Variable(batch_x), Variable(batch_y)

 out = model(batch_x)

 loss = loss_func(out, batch_y)

 train_loss += loss.data[0]

 pred = torch.max(out, 1)[1]

 train_correct = (pred == batch_y).sum()

 train_acc += train_correct.data[0]

 optimizer.zero_grad()

 loss.backward()

 optimizer.step()

 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(

 train_data)), train_acc / (len(train_data))))

 

 # evaluation--------------------------------

 model.eval()

 eval_loss = 0.

 eval_acc = 0.

 for batch_x, batch_y in test_loader:

 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)

 out = model(batch_x)

 loss = loss_func(out, batch_y)

 eval_loss += loss.data[0]

 pred = torch.max(out, 1)[1]

 num_correct = (pred == batch_y).sum()

 eval_acc += num_correct.data[0]

 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(

 test_data)), eval_acc / (len(test_data))))

相关推荐:

python如何读取二进制mnist实例详解

一篇不错的Python入门教程_python

以上就是Pytorch入门之mnist分类实例的详细内容,更多文章请关注木庄网络博客!!

相关阅读 >>

讲解Python 中删除文件的几种方法

Python语言的保留字

实例介绍Python随机数使用方法,推导以及字符串,双色球

Python中的split是什么

Python发展至今有哪些版本,各版本有什么区别?

Python中split函数如何使用

Python输出2到100之间的素数

Python中的合法变量名有什么规则

如何保存Python代码

词向量嵌入的实例详解

更多相关阅读请进入《Python》频道 >>




打赏

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,您说多少就多少

打开支付宝扫一扫,即可进行扫码打赏哦

分享从这里开始,精彩与您同在

评论

管理员已关闭评论功能...