博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Pytorch实战(3)----分类
阅读量:7122 次
发布时间:2019-06-28

本文共 2784 字,大约阅读时间需要 9 分钟。

一、分类任务:

将以下两类分开。

创建数据代码:

# make fake datan_data = torch.ones(100, 2)x0 = torch.normal(2*n_data, 1)      # class0 x data (tensor), shape=(100, 2)y0 = torch.zeros(100)               # class0 y data (tensor), shape=(100, 1)x1 = torch.normal(-2*n_data, 1)     # class1 x data (tensor), shape=(100, 2)y1 = torch.ones(100)                # class1 y data (tensor), shape=(100, 1)x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # shape (200, 2) FloatTensor = 32-bit floatingy = torch.cat((y0, y1), ).type(torch.LongTensor)    # shape (200,) LongTensor = 64-bit integer# torch can only train on Variable, so convert them to Variablex, y = Variable(x), Variable(y)plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')plt.show()

 

二、步骤

  1. 导入包

  2. 创建模型

  3. 设置优化器和损失函数

  4. 训练模型

三、代码:

导入包:

import torchfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as plt%matplotlib inlinetorch.manual_seed(1)    # reproducible

创建模型:

class Net(torch.nn.Module):    def __init__(self, n_feature, n_hidden, n_output):        super(Net, self).__init__()        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer    def forward(self, x):        x = F.relu(self.hidden(x))      # activation function for hidden layer        x = self.out(x)        return x

设置优化器和损失函数

#输入的x为2维张量,输出有两类net = Net(n_feature=2, n_hidden=10, n_output=2)     # define the networkprint(net)  # net architecture# Loss and Optimizer# Softmax is internally computed.# Set parameters to be updated.optimizer = torch.optim.SGD(net.parameters(), lr=0.02)loss_func = torch.nn.CrossEntropyLoss()  # the target label is NOT an one-hotted

 

训练模型并画图展示

plt.ion()   # something about plottingplt.show()for t in range(100):    out = net(x)                 # input x and predict based on x    loss = loss_func(out, y)     # must be (1. nn output, 2. target), the target label is NOT one-hotted    optimizer.zero_grad()   # clear gradients for next train    loss.backward()         # backpropagation, compute gradients    optimizer.step()        # apply gradients        if t % 10 == 0 or t in [3, 6]:        # plot and show learning process        plt.cla()        _, prediction = torch.max(F.softmax(out), 1)  #这里是得到softmax之后最大概率的y预测值。        pred_y = prediction.data.numpy().squeeze()        print(pred_y)        target_y = y.data.numpy()        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')        accuracy = sum(pred_y == target_y)/200.        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={
'size': 20, 'color': 'red'}) plt.show()# plt.pause(0.1)plt.ioff()

结果展示:

 

 

 

转载于:https://www.cnblogs.com/Lee-yl/p/10139165.html

你可能感兴趣的文章
服务部署如何做到高可用?这份“三级跳”秘籍送给你\n
查看>>
独家解读 | 滴滴机器学习平台架构演进之路
查看>>
KubeEdge向左,K3S向右
查看>>
微软正式发布 Azure IoT Central
查看>>
Build 2018大会:.NET概述和路线图
查看>>
七牛李倩:⼯程效率如何为研发赋能
查看>>
从“被动挖光缆”到“主动剪网线”,蚂蚁金服异地多活的微服务体系
查看>>
PhpStorm2016.3激活
查看>>
Docker4Dev #7 新瓶装老酒 – 使用 Windows Container运行ASP.NET MVC 2 + SQLExpress 应用
查看>>
使用vue.js构建一个知乎日报
查看>>
Microsoft Flow发布GA版本
查看>>
Python 赋值的一般含义是引用
查看>>
magento2 在香港用paypal
查看>>
Yii系列(1)打造虚拟开发环境及Yii的安装配置
查看>>
img/background/iconfont---谁最适合你?
查看>>
我的iOS程序生涯的起点
查看>>
程序员的工匠精神
查看>>
【underscore.js 源码解读】for ... in 存在的浏览器兼容问题你造吗
查看>>
Sass 与 Compass 实战经验总结
查看>>
微信公众号开发小记——3.接入三方登录
查看>>