首页 > 其他 > 详细

tensorflow2.0——鸢尾花数据集的一元分类

时间:2020-08-05 00:07:08      阅读:154      评论:0      收藏:0      [点我收藏+]
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import matplotlib as mpl

#   读取数据集
TRIN_URL = http://download.tensorflow.org/data/iris_training.csv            # 数据集下载网址
df_iris = pd.read_csv(./鸢尾花数据集/iris.csv, header=0)                   # 读取本地csv文件数据
# print(df_iris)                                                              #   最后一列是分类编号
#   处理数据集
iris = np.array(df_iris)                                                        # 将数据转换成numpy数组
iris2 = iris[iris[:, -1] < 2]                                                   # 只取分类编号小于2的两类数据
train_x = iris2[:, 0:2]                                                         # 只取特征的前两列
train_x = train_x - np.mean(train_x, axis=0)                                     # 需要将样本的均值变为0
train_1 = np.ones(train_x.shape[0]).reshape(-1, 1)                               # 生成一个与train_x一样行数的全1矩阵
# print(‘train_1:‘,train_1)
train_x = tf.concat((train_x, train_1), axis=1)                                  # 将train_x扩充一列全1
train_x = tf.cast(train_x,tf.float32)
train_y = iris2[:, -1]  # 标签
train_y = train_y.reshape(-1,1)
# print(train_x, train_x.shape)
# print(‘axis = 0:‘,np.mean(train_x,axis=0))                                    #   axis = 0 求一列的平均值
print(train_y, train_y.shape)

#   设置超参
iter = 2000
learn_rate = 0.1
loss_list = []
acc_list = []
#   初始化训练参数

w = tf.Variable(np.random.randn(3,1),dtype=tf.float32)
# w = tf.Variable(np.array([1.,1.,1.]).reshape(-1,1),dtype=tf.float32)

for i in range(iter):
    with tf.GradientTape() as tape:
        y_p = 1 / (1 + tf.exp(-(tf.matmul(train_x, w))))
        loss = tf.reduce_mean(-(train_y * tf.math.log(y_p) + (1 - train_y) * tf.math.log(1 - y_p)))
    dloss_dw = tape.gradient(loss, w)
    w.assign_sub(learn_rate * dloss_dw)
    loss_list.append(loss)
    acc = tf.reduce_mean(tf.cast(tf.equal(tf.round(y_p),train_y),dtype = tf.float32))
    acc_list.append(acc)
    if i % 100 == 0:
        print(第{}次, loss:{},acc:{}.format(i,loss,acc))
        # print(‘y_p:{}\ntrain_y:{}‘.format(y_p, train_y))
        print()


#   预测直线的横纵坐标处理
x1 = train_x[:,0]
print(w[0] * 55)
x2 = -(w[0] * x1 + w[2])/ w[1]
# print(‘x2:‘,x2)

#   画图
plt.rcParams["font.family"] = SimHei                          # 将字体改为中文
plt.rcParams[axes.unicode_minus] = False                      # 设置了中文字体默认后,坐标的"-"号无法显示,设置这个参数就可以避免
plt.subplot(221)
cm_pt = mpl.colors.ListedColormap([red, green])
plt.scatter(x=train_x[:, 0], y=train_x[:, 1], c=train_y, cmap=cm_pt)
plt.plot(x1,x2,label = 预测线)
plt.legend()
plt.subplot(223)
plt.title(损失值)
plt.plot(loss_list)
plt.subplot(224)
plt.title(准确率)
plt.plot(acc_list)
plt.show()

技术分享图片

 

tensorflow2.0——鸢尾花数据集的一元分类

原文:https://www.cnblogs.com/cxhzy/p/13435689.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!