pytorch + visdom 处理简单分类问题


本文摘自php中文网,作者不言,侵删。

这篇文章主要介绍了关于pytorch + visdom 处理简单分类问题,有着一定的参考价值,现在分享给大家,有需要的朋友可以参考一下

环境

系统 : win 10
显卡:gtx965m
cpu :i7-6700HQ
python 3.61
pytorch 0.3

包引用

1

2

3

4

5

6

7

import torch

from torch.autograd import Variable

import torch.nn.functional as F

import numpy as np

import visdom

import time

from torch import nn,optim

数据准备

1

2

3

4

5

6

7

8

9

10

11

12

13

use_gpu = True

ones = np.ones((500,2))

x1 = torch.normal(6*torch.from_numpy(ones),2)

y1 = torch.zeros(500)

x2 = torch.normal(6*torch.from_numpy(ones*[-1,1]),2)

y2 = y1 +1

x3 = torch.normal(-6*torch.from_numpy(ones),2)

y3 = y1 +2

x4 = torch.normal(6*torch.from_numpy(ones*[1,-1]),2)

y4 = y1 +3

 

x = torch.cat((x1, x2, x3 ,x4), 0).float()

y = torch.cat((y1, y2, y3, y4), ).long()

可视化如下看一下:

visdom可视化准备

先建立需要观察的windows

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

viz = visdom.Visdom()

colors = np.random.randint(0,255,(4,3)) #颜色随机

#线图用来观察loss 和 accuracy

line = viz.line(X=np.arange(1,10,1), Y=np.arange(1,10,1))

#散点图用来观察分类变化

scatter = viz.scatter(

  X=x,

  Y=y+1,

  opts=dict(

    markercolor = colors,

    marksize = 5,

    legend=["0","1","2","3"]),)

#text 窗口用来显示loss 、accuracy 、时间

text = viz.text("FOR TEST")

#散点图做对比

viz.scatter(

  X=x,

  Y=y+1,

  opts=dict(

    markercolor = colors,

    marksize = 5,

    legend=["0","1","2","3"]

  ),

)

效果如下:

逻辑回归处理

输入2,输出4

1

2

3

logstic = nn.Sequential(

  nn.Linear(2,4)

)

gpu还是cpu选择:

1

2

3

4

5

6

7

8

9

10

if use_gpu:

  gpu_status = torch.cuda.is_available()

  if gpu_status:

    logstic = logstic.cuda()

    # net = net.cuda()

    print("###############使用gpu##############")

  else : print("###############使用cpu##############")

else:

  gpu_status = False

  print("###############使用cpu##############")

优化器和loss函数:

1

2

loss_f = nn.CrossEntropyLoss()

optimizer_l = optim.SGD(logstic.parameters(), lr=0.001)

训练2000次:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

start_time = time.time()

time_point, loss_point, accuracy_point = [], [], []

for t in range(2000):

  if gpu_status:

    train_x = Variable(x).cuda()

    train_y = Variable(y).cuda()

  else:

    train_x = Variable(x)

    train_y = Variable(y)

  # out = net(train_x)

  out_l = logstic(train_x)

  loss = loss_f(out_l,train_y)

  optimizer_l.zero_grad()

  loss.backward()

  optimizer_l.step()

训练过成观察及可视化:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

if t % 10 == 0:

  prediction = torch.max(F.softmax(out_l, 1), 1)[1]

  pred_y = prediction.data

  accuracy = sum(pred_y ==train_y.data)/float(2000.0)

  loss_point.append(loss.data[0])

  accuracy_point.append(accuracy)

  time_point.append(time.time()-start_time)

  print("[{}/{}] | accuracy : {:.3f} | loss : {:.3f} | time : {:.2f} ".format(t + 1, 2000, accuracy, loss.data[0],

                                  time.time() - start_time))

  viz.line(X=np.column_stack((np.array(time_point),np.array(time_point))),

       Y=np.column_stack((np.array(loss_point),np.array(accuracy_point))),

       win=line,

       opts=dict(legend=["loss", "accuracy"]))

   #这里的数据如果用gpu跑会出错,要把数据换成cpu的数据 .cpu()即可

  viz.scatter(X=train_x.cpu().data, Y=pred_y.cpu()+1, win=scatter,name="add",

        opts=dict(markercolor=colors,legend=["0", "1", "2", "3"]))

  viz.text("<h3 align='center' style='color:blue'>accuracy : {}</h3><br><h3 align='center' style='color:pink'>"

       "loss : {:.4f}</h3><br><h3 align ='center' style='color:green'>time : {:.1f}</h3>"

       .format(accuracy,loss.data[0],time.time()-start_time),win =text)

我们先用cpu运行一次,结果如下:

然后用gpu运行一下,结果如下:

发现cpu的速度比gpu快很多,但是我听说机器学习应该是gpu更快啊,百度了一下,知乎上的答案是:

我的理解就是gpu在处理图片识别大量矩阵运算等方面运算能力远高于cpu,在处理一些输入和输出都很少的,还是cpu更具优势。

添加神经层:

1

2

3

4

5

net = nn.Sequential(

  nn.Linear(2, 10),

  nn.ReLU(),  #激活函数

  nn.Linear(10, 4)

)

添加一层10单元神经层,看看效果是否会有所提升:

使用cpu:


使用gpu:

比较观察,似乎并没有什么区别,看来处理简单分类问题(输入,输出少)的问题,神经层和gpu不会对机器学习加持。

相关推荐:

PyTorch上搭建简单神经网络实现回归和分类的示例

详解PyTorch批训练及优化器比较

以上就是pytorch + visdom 处理简单分类问题的详细内容,更多文章请关注木庄网络博客!!

相关阅读 >>

Python中any()和all()使用方法的简单介绍

Python中的elif是什么意思

Python3断言是什么

Python3将Python代码打包成exe文件的方法

Python怎么把值输入字典

Python基础学习之算数运算符、比较运算符

Python中的模块string.py

Python如何输入列表

Python内建函数是什么

Python基本运算符号有哪些

更多相关阅读请进入《Python》频道 >>




打赏

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,您说多少就多少

打开支付宝扫一扫,即可进行扫码打赏哦

分享从这里开始,精彩与您同在

评论

管理员已关闭评论功能...