MXNet——symbol
參考資料:有基礎(Pytorch/TensorFlow基礎)mxnet+gluon快速入門
symbol
symbol 是一個重要的概念,可以理解為符號,就像我們平時使用的代數符號 x,y,z 一樣。一個簡單的類比,一個函數 \(f(x) = x^{2}\),符號 x 就是 symbol,而具體 x 的值就是 ndarray,關于 symbol 的是 mxnet.sym,具體可參照官方API文檔
基本操作
- 使用 mxnet.sym.Variable() 傳入名稱可建立一個 symbol
- 使用 mxnet.viz.plot_network(symbol=) 傳入 symbol 可以繪制運算圖
帶入 ndarray
使用 mxnet.sym.bind() 方法可以獲得一個帶入操作數的對象,再使用 forward() 方法可運算出數值
x = c.bind(ctx=mx.cpu(),args={"a": mx.nd.ones(5),"b":mx.nd.ones(5)}) result = x.forward() print(result) [ [2. 2. 2. 2. 2.] <NDArray 5 @cpu(0)>]mxnet 的數據載入
深度學習中數據的載入方式非常重要,mxnet 提供了 mxnet.io 的一系列 dataiter 用于處理數據載入,詳細可參照官方API文檔。同時,動態圖接口gluon 也提供了 mxnet.gluon.data 系列的 dataiter 用于數據載入,詳細可參照官方API文檔
mxnet.io 數據載入
mxnet.io的數據載入核心是 mxnet.io.DataIter 類及其派生類,例如 ndarray 的 iter:NDArrayIter
- 參數 data:傳入一個(名稱-數據)的數據 dict
- 參數 label:傳入一個(名稱-標簽)的標簽 dict
- 參數 batch_size:傳入 batch 大小
gluon.data 數據載入
gluon 的數據 API 幾乎與 pytorch 相同,均是 Dataset+DataLoader 的方式:
- Dataset:存儲數據,使用時需要繼承該基類并重載 __len__(self) 和 __getitem__(self,idx) 方法
- DataLoader:將 Dataset 變成能產生 batch 的可迭代對象
網絡搭建
mxnet 網絡搭建
mxnet 網絡搭建類似于 TensorFlow,使用 symbol 搭建出網絡,再用一個 module 封裝
data = mx.sym.Variable('data') # layer1 conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=32,name="conv1") relu1 = mx.sym.Activation(data=conv1,act_type="relu",name="relu1") pool1 = mx.sym.Pooling(data=relu1,pool_type="max",kernel=(2,2),stride=(2,2),name="pool1") # layer2 conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=64,name="conv2") relu2 = mx.sym.Activation(data=conv2,act_type="relu",name="relu2") pool2 = mx.sym.Pooling(data=relu2,pool_type="max",kernel=(2,2),stride=(2,2),name="pool2") # layer3 fc1 = mx.symbol.FullyConnected(data=mx.sym.flatten(pool2), num_hidden=256,name="fc1") relu3 = mx.sym.Activation(data=fc1, act_type="relu",name="relu3") # layer4 fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10,name="fc2") out = mx.sym.SoftmaxOutput(data=fc2, label=mx.sym.Variable("label"),name='softmax') mxnet_model = mx.mod.Module(symbol=out,label_names=["label"],context=mx.gpu()) mx.viz.plot_network(symbol=out)福利:剛剛發現一個解決路徑錯誤的方法:只需要將 *\Anaconda3\Library\bin\graphviz 添加到 Path 環境變量之下即可 (安裝后記得重啟,環境變量修改才可以生效,調用庫,即可成功)!
轉載于:https://www.cnblogs.com/q735613050/p/9315504.html
總結
以上是生活随笔為你收集整理的MXNet——symbol的全部內容,希望文章能夠幫你解決所遇到的問題。