WONGCW 網誌
  • 首頁
  • 論壇
  • 微博
  • 雲端筆記
  • 域名申請
  • 網速測試
  • 壁紙下載
  • 免費圖床
  • 匯率兌換
  • More
    • QR Code Generator
    • URL Shortener
    • WordPress Plugins And Themes
  • 免責聲明
  • Search Icon

WONGCW 網誌

記錄生活經驗與點滴

TensorFlow – 线性回归

TensorFlow – 线性回归

2018-09-24 Comments 0 Comment

TensorFlow 实现线性回归模型代码

任务时间:时间未知

前期准备

TensorFlow 相关 API 可以到在实验 TensorFlow – 相关 API 中学习。

模型构建

示例代码:

现在您可以在 /home/ubuntu 目录下创建源文件 linear_regression_model.py,内容可参考:

示例代码:/home/ubuntu/linear_regression_model.py
#!/usr/bin/python
# -*- coding: utf-8 -*
import tensorflow as tf
import numpy as np

class linearRegressionModel:

  def __init__(self,x_dimen):
    self.x_dimen = x_dimen
    self._index_in_epoch = 0
    self.constructModel()
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())

  #权重初始化
  def weight_variable(self,shape):
    initial = tf.truncated_normal(shape,stddev = 0.1)
    return tf.Variable(initial)

  #偏置项初始化
  def bias_variable(self,shape):
    initial = tf.constant(0.1,shape = shape)
    return tf.Variable(initial)

  #每次选取100个样本,如果选完,重新打乱
  def next_batch(self,batch_size):
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_datas:
        perm = np.arange(self._num_datas)
        np.random.shuffle(perm)
        self._datas = self._datas[perm]
        self._labels = self._labels[perm]
        start = 0
        self._index_in_epoch = batch_size
        assert batch_size <= self._num_datas
    end = self._index_in_epoch
    return self._datas[start:end],self._labels[start:end]

  def constructModel(self):
    self.x = tf.placeholder(tf.float32, [None,self.x_dimen])
    self.y = tf.placeholder(tf.float32,[None,1])
    self.w = self.weight_variable([self.x_dimen,1])
    self.b = self.bias_variable([1])
    self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

    mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
    l2 = tf.reduce_mean(tf.square(self.w))
    self.loss = mse + 0.15*l2
    self.train_step = tf.train.AdamOptimizer(0.1).minimize(self.loss)

  def train(self,x_train,y_train,x_test,y_test):
    self._datas = x_train
    self._labels = y_train
    self._num_datas = x_train.shape[0]
    for i in range(5000):
        batch = self.next_batch(100)
        self.sess.run(self.train_step,feed_dict={self.x:batch[0],self.y:batch[1]})
        if i%10 == 0:
            train_loss = self.sess.run(self.loss,feed_dict={self.x:batch[0],self.y:batch[1]})
            print('step %d,test_loss %f' % (i,train_loss))

  def predict_batch(self,arr,batch_size):
    for i in range(0,len(arr),batch_size):
        yield arr[i:i + batch_size]

  def predict(self, x_predict):
    pred_list = []
    for x_test_batch in self.predict_batch(x_predict,100):
      pred = self.sess.run(self.y_prec, {self.x:x_test_batch})
      pred_list.append(pred)
    return np.vstack(pred_list)

训练模型并和 sklearn 库线性回归模型对比

示例代码:

现在您可以在 /home/ubuntu 目录下创建源文件 run.py,内容可参考:

示例代码:/home/ubuntu/run.py
#!/usr/bin/python
# -*- coding: utf-8 -*

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from linear_regression_model import linearRegressionModel as lrm

if __name__ == '__main__':
    x, y = make_regression(7000)
    x_train,x_test,y_train, y_test = train_test_split(x, y, test_size=0.5)
    y_lrm_train = y_train.reshape(-1, 1)
    y_lrm_test = y_test.reshape(-1, 1)

    linear = lrm(x.shape[1])
    linear.train(x_train, y_lrm_train,x_test,y_lrm_test)
    y_predict = linear.predict(x_test)
    print("Tensorflow R2: ", r2_score(y_predict.ravel(), y_lrm_test.ravel()))

    lr = LinearRegression()
    y_predict = lr.fit(x_train, y_train).predict(x_test)
    print("Sklearn R2: ", r2_score(y_predict, y_test)) #采用r2_score评分函数

然后执行:

cd /home/ubuntu;
python run.py

执行结果:

step 2410,test_loss 26.531937
step 2420,test_loss 26.542793
step 2430,test_loss 26.533974
step 2440,test_loss 26.530540
step 2450,test_loss 26.551474
step 2460,test_loss 26.541542
step 2470,test_loss 26.560783
step 2480,test_loss 26.538080
step 2490,test_loss 26.535666
('Tensorflow R2: ', 0.99999612588302389)
('Sklearn R2: ', 1.0)

完成

任务时间:时间未知

恭喜,您已完成本实验内容

您可进行更多 TensorFlow 的系列教程:

  • TensorFlow – 相关 API
  • TensorFlow – 基于 CNN 数字识别

关于 TensorFlow 的更多资料可参考 TensorFlow 官网 。

點閱: 16

分享此文:

  • 分享到 Twitter(在新視窗中開啟)
  • 按一下以分享至 Facebook(在新視窗中開啟)
  • 分享到 WhatsApp(在新視窗中開啟)
  • 按一下以分享到 Telegram(在新視窗中開啟)
  • 按一下即可分享至 Skype(在新視窗中開啟)
  • 點這裡列印(在新視窗中開啟)
  • 點這裡寄給朋友(在新視窗中開啟)

相關


TensorFlow

Post navigation

NEXT
TensorFlow – 逻辑回归 (Logistic Regression)
PREVIOUS
TensorFlow – 相关 API

發表迴響 取消回覆

這個網站採用 Akismet 服務減少垃圾留言。進一步瞭解 Akismet 如何處理網站訪客的留言資料。

九月 2018
一二三四五六日
« 八月 十月 »
 12
3456789
10111213141516
17181920212223
24252627282930

分類

  • 網站公告 (2)
  • ESET NOD32 (4)
  • WhatsApp Sticker (32)
  • WINDOWS 10 INSIDER PREVIEW (14)
  • Windows 軟件下載 (888)
  • 系統軟件 (202)
  • 辦公軟件 (44)
  • 圖像處理 (103)
  • 影音軟件 (167)
  • 網絡軟件 (218)
  • 應用軟件 (99)
  • Mac 軟件下載 (110)
  • Mac 教學資訊 (110)
  • 安卓軟件 (29)
  • Nulled Scripts (371)
  • WordPress 相關 (181)
  • 在線工具 (18)
  • VPS主機 (1,359)
  • Linux相關 (187)
  • 網絡資訊 (14,474)
  • 教學資訊 (771)
  • NASA (255)
  • SEO工具 (16)
  • WeChat相關 (72)
  • 開源程序 (30)
  • PHP語言 (55)
  • Plesk面板 (3)
  • TensorFlow (11)
  • 日常生活 (46)
  • 動手賺錢 (206)
  • 其他資料 (244)

彙整

  • 2019 年 二月
  • 2019 年 一月
  • 2018 年 十二月
  • 2018 年 十一月
  • 2018 年 十月
  • 2018 年 九月
  • 2018 年 八月
  • 2018 年 七月
  • 2018 年 六月
  • 2018 年 五月
  • 2018 年 四月
  • 2018 年 三月
  • 2018 年 一月
  • 2017 年 十二月

近期文章

  • 20W功率的小米無線車充現身 2019-02-23
  • 三星折疊屏手機之後人們展開了對折疊iPhone的幻想 2019-02-23
  • Galaxy S10+ 跑分出爐Android 陣營第一落後於iPhone XS 2019-02-23
  • 科學家揭示細胞“門神”抗病毒感染重要調控機制 2019-02-23
  • 中國ARJ21飛機已交付12架運送旅客逾27萬人次 2019-02-23
  • 中國北疆首列“智軌”電車在哈爾濱市試跑 2019-02-23
  • 解讀日本將太空岩石帶回地球壯舉的背後隱藏了哪些挑戰 2019-02-23
  • 頭骨上的“微雕” 中國醫生完成了“不可能完成的任務” 2019-02-23
  • 主播拿錢後給《聖歌》打低分EA:可以,把水印刪了 2019-02-23
  • 傳聞稱微軟折疊屏智能機仍在路上但不會運行Windows 2019-02-23

熱門文章與頁面︰

  • ESET NOD32 LICENSE KEY (UPDATED 2019-02-23)
  • ESET Internet Security / ESET NOD32 AntiVirus v12.0.31.0 x86 / x64多語言中文​​正式版
  • PDF-XChange Editor Plus 7.0.327.0(含:註冊機序列號)
  • WINDOWS SERVER 2019 MSDN 正式版ISO鏡像-簡體中文/繁體中文/英文
  • ESET Smart Security Premium/ESET Internet Security/ESET NOD32 AntiVirus v11.2.63.0 正式版-简体中文/繁体中文/英文
  • ESET Internet Security/ESET NOD32 AntiVirus v12.0.27.0 x86/x64 多语言中文正式版
  • Acronis True Image 2019 v23.4.1.14690+Bootable ISO Win/Mac多語言中文​​註冊版
  • WhatsApp Stickers 貼圖下載技
  • YouTube By Click 2.2.86(含:註冊機序列號)
  • Windows 10 version 1809 (Updated October 2018) RS5 正式版MSDN ISO鏡像-簡體中文/繁體中文/英文(最後更新:2018-11-16)

投遞稿件

歡迎各界人士投遞稿件到admin@wongcw.com

請提供以下資料:

1.你的名字

2.你的電郵

3.分類目錄

4.文章標題

5.文章摘要

6.文章內容

7.文章來源

 

標籤

Adobe Alibaba Alipay Amazon AMD Apple Azure Baidu Chrome Facebook Firefox GitHub Google HTC HUAWAI Intel iOS iPad iPhone Kindle LG Linux Mac MacBook Microsoft Nginx Nod32 Nokia Nvidia Samsung Sony VMware WeChat WhatsApp Sticker Windows YouTube 小米

聯繫我們

E-mail:admin@wongcw.com

QQ群:833641851

隱私權與 Cookie:此網站可使用 Cookie。繼續使用此網站即表示你同意使用 Cookie。
若要瞭解更多資訊,包括如何控制 Cookie,請參閱此處: Cookie 政策
© 2019   All Rights Reserved.
loading 取消
文章未送出─請檢查你的電子郵件地址!
電子郵件地址檢查失敗,請再試一次
抱歉,你的網誌無法透過電子郵件分享