您的位置:首頁 > 軟件教程 > 教程 > PyTorch的安裝與使用

PyTorch的安裝與使用

來源:好特整理 | 時(shí)間:2024-05-08 18:58:28 | 閱讀:93 |  標(biāo)簽: T C   | 分享到:

本文介紹了熱門AI框架PyTorch的conda安裝方案,與簡單的自動微分示例。并順帶講解了一下PyTorch開源Github倉庫中的兩個(gè)Issue內(nèi)容,分別是自動微分的關(guān)鍵詞參數(shù)輸入問題與自動微分參數(shù)數(shù)量不匹配時(shí)的參數(shù)返回問題,并包含了這兩個(gè)Issue的解決方案。

技術(shù)背景

PyTorch是一個(gè)非常常用的AI框架,主要?dú)w功于其簡單易用的特點(diǎn),深受廣大科研人員的喜愛。在前面的一篇 文章 中我們介紹過制作PyTorch的Singularity鏡像的方法,這里我們單獨(dú)抽出PyTorch的安裝和使用,再簡單的聊一聊。

安裝Torch

常規(guī)的安裝方案可以使用源碼安裝、pip安裝、conda安裝和容器安裝等,這里我們首選推薦的是conda安裝的方法。關(guān)于conda,其實(shí)沒必要安裝完整版本的anaconda,裝一個(gè)miniconda就可以了。假定我們已經(jīng)安裝好了conda,那么首先要創(chuàng)建一個(gè)專用的pytorch虛擬環(huán)境:

$ conda create -n pytorch python=3.9
Retrieving notices: ...working... done
Collecting package metadata (current_repodata.json): done
Solving environment: done


==> WARNING: A newer version of conda exists. <==
  current version: 23.1.0
  latest version: 24.4.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.4.0



## Package Plan ##

  environment location: /home/dechin/anaconda3/envs/pytorch

  added / updated specs:
    - python=3.9


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2024.3.11  |       h06a4308_0         127 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    libffi-3.4.4               |       h6a678d5_1         141 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    openssl-3.0.13             |       h7f8727e_1         5.2 MB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    pip-23.3.1                 |   py39h06a4308_0         2.6 MB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    python-3.9.19              |       h955ad1f_1        25.1 MB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    setuptools-69.5.1          |   py39h06a4308_0        1003 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    sqlite-3.45.3              |       h5eee18b_0         1.2 MB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    tk-8.6.14                  |       h39e8969_0         3.4 MB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    tzdata-2024a               |       h04d1e81_0         116 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    wheel-0.43.0               |   py39h06a4308_0         109 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    xz-5.4.6                   |       h5eee18b_1         643 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    zlib-1.2.13                |       h5eee18b_1         111 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    ------------------------------------------------------------
                                           Total:        39.8 MB

The following NEW packages will be INSTALLED:

  _libgcc_mutex      anaconda/pkgs/main/linux-64::_libgcc_mutex-0.1-main 
  _openmp_mutex      anaconda/pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu 
  ca-certificates    anaconda/pkgs/main/linux-64::ca-certificates-2024.3.11-h06a4308_0 
  ld_impl_linux-64   anaconda/pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1 
  libffi             anaconda/pkgs/main/linux-64::libffi-3.4.4-h6a678d5_1 
  libgcc-ng          anaconda/pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1 
  libgomp            anaconda/pkgs/main/linux-64::libgomp-11.2.0-h1234567_1 
  libstdcxx-ng       anaconda/pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1 
  ncurses            anaconda/pkgs/main/linux-64::ncurses-6.4-h6a678d5_0 
  openssl            anaconda/pkgs/main/linux-64::openssl-3.0.13-h7f8727e_1 
  pip                anaconda/pkgs/main/linux-64::pip-23.3.1-py39h06a4308_0 
  python             anaconda/pkgs/main/linux-64::python-3.9.19-h955ad1f_1 
  readline           anaconda/pkgs/main/linux-64::readline-8.2-h5eee18b_0 
  setuptools         anaconda/pkgs/main/linux-64::setuptools-69.5.1-py39h06a4308_0 
  sqlite             anaconda/pkgs/main/linux-64::sqlite-3.45.3-h5eee18b_0 
  tk                 anaconda/pkgs/main/linux-64::tk-8.6.14-h39e8969_0 
  tzdata             anaconda/pkgs/main/noarch::tzdata-2024a-h04d1e81_0 
  wheel              anaconda/pkgs/main/linux-64::wheel-0.43.0-py39h06a4308_0 
  xz                 anaconda/pkgs/main/linux-64::xz-5.4.6-h5eee18b_1 
  zlib               anaconda/pkgs/main/linux-64::zlib-1.2.13-h5eee18b_1 


Proceed ([y]/n)? y


Downloading and Extracting Packages
                                                                                                                                                 
Preparing transaction: done                                                                                                                      
Verifying transaction: done                                                                                                                      
Executing transaction: done                                                                                                                      
#                                                                                                                                                
# To activate this environment, use                                                                                                              
#                                                                                                                                                
#     $ conda activate pytorch                                                                                                                   
#                                                                                                                                                
# To deactivate an active environment, use                                                                                                       
#                                                                                                                                                
#     $ conda deactivate                                                                                                                         

這里我們是基于Python3.9版本創(chuàng)建了一個(gè)Python虛擬環(huán)境。相比于容器和虛擬機(jī)來說,虛擬環(huán)境結(jié)構(gòu)更加簡單,非常適用于本地的Python軟件管理。當(dāng)然,如果是在服務(wù)器上面運(yùn)行,那還是推薦容器的方案多一些。有了基礎(chǔ)的Python環(huán)境之后,可以去 PyTorch官網(wǎng) 找找適用于自己本地環(huán)境的conda安裝命令:

然后把這條命令復(fù)制到自己本地進(jìn)行安裝。建議在安裝的時(shí)候加上 -y 的配置,就省的加載一半還需要你自己手動去配置一個(gè)輸入一個(gè)y來決定是否繼續(xù)下一步安裝。因?yàn)檫@個(gè)安裝的過程可能也會比較耗時(shí),尤其網(wǎng)絡(luò)對于一部分國內(nèi)的IP可能并不是那么的友好。

$ conda install -y pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia                                                                            
Solving environment: done                                                                                           
                                                                                                                    
## Package Plan ##

  environment location: /home/dechin/anaconda3/envs/pytorch

  added / updated specs:
    - pytorch
    - pytorch-cuda=11.8
    - torchaudio
    - torchvision


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    charset-normalizer-2.0.4   |     pyhd3eb1b0_0          35 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    jinja2-3.1.3               |   py39h06a4308_0         269 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    libdeflate-1.17            |       h5eee18b_1          64 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    libnpp-11.8.0.86           |                0       147.8 MB  nvidia
    libunistring-0.9.10        |       h27cfd23_0         536 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    typing_extensions-4.9.0    |   py39h06a4308_1          54 KB  https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
    ------------------------------------------------------------
                                           Total:       148.8 MB

The following NEW packages will be INSTALLED:

  blas               anaconda/pkgs/main/linux-64::blas-1.0-mkl 
  bzip2              anaconda/pkgs/main/linux-64::bzip2-1.0.8-h5eee18b_6 
  certifi            anaconda/pkgs/main/linux-64::certifi-2024.2.2-py39h06a4308_0 
  charset-normalizer anaconda/pkgs/main/noarch::charset-normalizer-2.0.4-pyhd3eb1b0_0 
  cuda-cudart        nvidia/linux-64::cuda-cudart-11.8.89-0 
  cuda-cupti         nvidia/linux-64::cuda-cupti-11.8.87-0 
  ...
  pytorch            pytorch/linux-64::pytorch-2.3.0-py3.9_cuda11.8_cudnn8.7.0_0 
  pytorch-cuda       pytorch/linux-64::pytorch-cuda-11.8-h7e8668a_5 
  pytorch-mutex      pytorch/noarch::pytorch-mutex-1.0-cuda 
  zstd               anaconda/pkgs/main/linux-64::zstd-1.5.5-hc292b87_2 

Downloading and Extracting Packages
                                                                                                                    
Preparing transaction: done                                                                                         
Verifying transaction: done                                                                                         
Executing transaction: done                         

安裝完成后可以通過如下指令,在bash命令行里面檢查一下是否安裝成功了PyTorch的CUDA版本:

$ python3 -c "import torch;print(torch.cuda.is_available())"
True

如果輸出為 True 則表明安裝成功。另外順便一提,如果在 conda 安裝的過程中出現(xiàn)如下的報(bào)錯(cuò):

CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://conda.anaconda.org/nvidia/linux-64/libnpp-11.8.0.86-0.tar.bz2>                                                                                                              
Elapsed: -                                                                                                          
                                                                                                                    
An HTTP error occurred when trying to retrieve this URL.                                                            
HTTP errors are often intermittent, and a simple retry will get you on your way.                                    
                                                                                                                    
CancelledError()                                                                                                    
CancelledError()                                                                                                    
CancelledError()                                                                                                    
CancelledError()

一般情況下就是由網(wǎng)絡(luò)問題導(dǎo)致的,但也并不是完全無法鏈接,我們同樣的命令行多輸入幾次就可以了,直到安裝完成為止。

PyTorch自動微分

關(guān)于自動微分的原理,讀者可以參考一下之前的這篇 手搓自動微分 的文章,PyTorch大概就是使用的這個(gè)自動微分的原理。在PyTorch框架下,我們可以通過backward函數(shù)來自定義反向傳播函數(shù),這一點(diǎn)跟MindSpore框架有所不同,MindSpore框架下自定義反向傳播函數(shù)使用的是bprop函數(shù),MindSpore自定義反向傳播相關(guān)內(nèi)容可以參考下這篇 文章 。如下所示是一個(gè)Torch的用例:

# 忽略告警信息
import warnings
warnings.filterwarnings("ignore")

import torch

# 自定義可微分的類型
class Gradient(torch.autograd.Function):
    # 前向傳播
    @staticmethod
    def forward(ctx, x, w=None):
        # 保存一個(gè)參數(shù)到計(jì)算圖中
        ctx.save_for_backward(w)
        return x
    # 反向傳播
    @staticmethod
    def backward(ctx, g):
        w,  = ctx.saved_tensors
        if w is None:
            return g
        else:
            return g * w, None

# 非加權(quán)自動微分測試
x = torch.autograd.Variable(torch.tensor(3.14), requires_grad=True)
g = torch.autograd.Variable(torch.tensor(3.15))
gradient = Gradient()
# 前向傳播
y = gradient.apply(x)
print (y)
# 反向傳播
y.backward(g)
# 打印梯度
print (x.grad)
# 加權(quán)自動微分測試
x = torch.autograd.Variable(torch.tensor(3.14), requires_grad=True)
g = torch.autograd.Variable(torch.tensor(3.15))
w = torch.autograd.Variable(torch.tensor(2.0))
z = gradient.apply(x, w)
print (z)
z.backward(g)
print (x.grad)

輸出結(jié)果為:

tensor(3.1400, grad_fn=)
tensor(3.1500)
tensor(3.1400, grad_fn=)
tensor(6.3000)

這樣一來,就把需要輸入到反向傳播函數(shù)中的加權(quán)值傳了進(jìn)去。因?yàn)樵谡5腷ackward函數(shù)中,相關(guān)的輸入類型都是規(guī)定好的,不能隨便加輸入,所以要從前向傳播中傳遞給計(jì)算圖。在這個(gè)案例中,順便介紹下PyTorch開源倉庫中的兩個(gè)Issue。第一個(gè)問題是, PyTorch的前向傳播函數(shù)中,如果從外部傳入一個(gè)關(guān)鍵字參數(shù),會報(bào)錯(cuò)

關(guān)于這個(gè)問題,官方做了如下解釋:

大體意思就是,如果使用關(guān)鍵字類型的參數(shù)輸入,會給參數(shù)校驗(yàn)和結(jié)果返回帶來一些困難。同時(shí)給出了一個(gè)臨時(shí)的解決方案:

其實(shí)也就是我們這個(gè)案例中所采用的方案,套一個(gè)條件語句就可以了。另外一條Issue是, 如果涉及到多個(gè)輸入,那么在反向傳播函數(shù)中也要給到多個(gè)輸出

不過在這個(gè)Issue中,提Issue的人本身也給出了一個(gè)方案,就是直接在返回結(jié)果中給一個(gè)None值。

總結(jié)概要

本文介紹了熱門AI框架PyTorch的conda安裝方案,與簡單的自動微分示例。并順帶講解了一下PyTorch開源Github倉庫中的兩個(gè)Issue內(nèi)容,分別是自動微分的關(guān)鍵詞參數(shù)輸入問題與自動微分參數(shù)數(shù)量不匹配時(shí)的參數(shù)返回問題,并包含了這兩個(gè)Issue的解決方案。

版權(quán)聲明

本文首發(fā)鏈接為: https://www.cnblogs.com/dechinphy/p/torch.html

作者ID:DechinPhy

更多原著文章: https://www.cnblogs.com/dechinphy/

請博主喝咖啡: https://www.cnblogs.com/dechinphy/gallery/image/379634.html

參考鏈接

  1. https://pytorch.org/get-started/locally/
  2. https://www.cnblogs.com/dechinphy/p/pytorch.html
  3. https://github.com/pytorch/pytorch/issues/16940
  4. https://github.com/Lightning-AI/pytorch-lightning/issues/6624
  5. https://blog.csdn.net/winycg/article/details/104410525
小編推薦閱讀

好特網(wǎng)發(fā)布此文僅為傳遞信息,不代表好特網(wǎng)認(rèn)同期限觀點(diǎn)或證實(shí)其描述。

相關(guān)視頻攻略

更多

掃二維碼進(jìn)入好特網(wǎng)手機(jī)版本!

掃二維碼進(jìn)入好特網(wǎng)微信公眾號!

本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請發(fā)郵件[email protected]

湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2025 haote.com 好特網(wǎng)