<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Deep Learning | ZC's website</title><link>https://www.zcfun.top/tag/deep-learning/</link><atom:link href="https://www.zcfun.top/tag/deep-learning/index.xml" rel="self" type="application/rss+xml"/><description>Deep Learning</description><generator>Wowchemy (https://wowchemy.com)</generator><language>en-us</language><lastBuildDate>Wed, 27 Apr 2016 00:00:00 +0000</lastBuildDate><image><url>https://www.zcfun.top/media/icon_hu0b7a4cb9992c9ac0e91bd28ffd38dd00_9727_512x512_fill_lanczos_center_3.png</url><title>Deep Learning</title><link>https://www.zcfun.top/tag/deep-learning/</link></image><item><title>GNN (Graph Neural Network)</title><link>https://www.zcfun.top/project/ai/python02/</link><pubDate>Wed, 27 Apr 2016 00:00:00 +0000</pubDate><guid>https://www.zcfun.top/project/ai/python02/</guid><description>&lt;p>The algorithm uses a classic convolutional neural network (CNN): each animation clip is split into three frames, and the frames are then classified.
IDE: PyCharm;
Language: Python;
Tools: PyTorch (CPU).&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.autograd&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Variable&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sampler&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.datasets&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">dset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.transforms&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">T&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">timeit&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">PIL&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Image&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">os&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">scipy.io&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.models.inception&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">inception&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Many datasets ship their annotations in MATLAB .mat format; scipy.io.loadmat and scipy.io.savemat read and write such files from Python.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_mat&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scipy&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">io&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">loadmat&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;./../data//q3_2_data.mat&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">label_mat&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;trLb&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;train_length:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_value&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">label_mat&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;valLb&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;val len:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label_value&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Action dataset.&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">root_dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[],&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2">        Args:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2">            root_dir (str): path to the root directory of the dataset
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2">            labels (list): image labels
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2">            transform (callable, optional): preprocessing function to apply to each image
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2">        &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">root_dir&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">root_dir&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transform&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">length&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">listdir&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>  &lt;span class="c1"># return the number of samples&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">length&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">3&lt;/span>  &lt;span class="c1"># each video clip contains 3 frames&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">idx&lt;/span>&lt;span class="p">):&lt;/span>  &lt;span class="c1"># return a single sample&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">folder&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">imidx&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">folder&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">format&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">folder&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;05d&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">imgname&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">imidx&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;.jpg&amp;#39;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">img_path&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">path&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">join&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">folder&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">imgname&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Image&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">open&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">img_path&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">Label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">idx&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span>&lt;span class="p">:&lt;/span>  &lt;span class="c1"># if a preprocessing function was given, apply it to the image first&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">img_path&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">Label&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">img_path&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">sample&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
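# Illustrative sketch (not part of the original post): the index arithmetic that
# __getitem__ above uses to map a flat sample index to a clip folder and a frame file.
# The helper name locate_frame and its frames_per_clip parameter are hypothetical.
def locate_frame(idx, frames_per_clip=3):
    folder = format(idx // frames_per_clip + 1, '05d')  # clip folders are 1-based, zero-padded to 5 digits
    imgname = str(idx % frames_per_clip + 1) + '.jpg'  # frames inside a clip are named 1.jpg, 2.jpg, 3.jpg
    return folder, imgname


# Indices 0, 1 and 2 fall in clip 00001; index 3 starts clip 00002.
assert locate_frame(0) == ('00001', '1.jpg')
assert locate_frame(2) == ('00001', '3.jpg')
assert locate_frame(3) == ('00002', '1.jpg')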
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./../data//trainClips/&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># torchvision.transforms provides many image preprocessing methods; ToTensor, used here, converts an image with RGB values in 0-255 into a float Tensor (C*H*W layout) with values in 0-1.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">image_dataset&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">                              &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataloader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">break&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./../data//trainClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./..//data//valClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_value&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_val&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./..//data//testClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[],&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_test&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dtype&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">FloatTensor&lt;/span>  &lt;span class="c1"># the floating-point tensor type that PyTorch provides for the CPU&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">print_every&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">100&lt;/span>  &lt;span class="c1"># controls how often the loss is printed, so that it can be monitored throughout training&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">reset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">):&lt;/span>  &lt;span class="c1"># re-initialize the parameters of a module&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="nb">hasattr&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;reset_parameters&amp;#39;&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reset_parameters&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Flatten&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">C&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">H&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">W&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># read the size of each dimension&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># -1 stands for all dimensions not given explicitly, so C*H*W is flattened into one dimension&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># torch.nn.Sequential is a sequential container: modules are added to it in the order they are passed to the constructor.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model_base&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>  &lt;span class="c1"># 3*64*64 -&amp;gt; 8*58*58&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span>  &lt;span class="c1"># 8*58*58 -&amp;gt; 8*29*29&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>  &lt;span class="c1"># 8*29*29 -&amp;gt; 16*23*23&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span>  &lt;span class="c1"># 16*23*23 -&amp;gt; 16*11*11&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">Flatten&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1936&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># 1936 = 16*11*11&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># The .type() method sets the data type the model uses; dtype was set earlier to the CPU float type.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># To train on a GPU, set dtype to the CUDA float type instead.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">fixed_model_base&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># wrap the tensor in a Variable (no longer required since PyTorch 0.4)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">ans&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">fixed_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ans&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()))&lt;/span> &lt;span class="c1"># check the shape of the model output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array_equal&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ans&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()),&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">])))&lt;/span> &lt;span class="c1"># True if the output shape is (32, 10)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dataloader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_epochs&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Starting epoch &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> / &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">epoch&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># evaluate the model being trained on the validation set&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_val&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># .train() puts the model in training mode, so layers such as dropout are active&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataloader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">y_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">long&lt;/span>&lt;span class="p">())&lt;/span> &lt;span class="c1"># get the corresponding labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># forward pass: compute the class scores&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y_var&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># compute the loss&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">t&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">print_every&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;t = &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1">, loss = &lt;/span>&lt;span class="si">%.4f&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">t&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># update the parameters in three steps: clear old gradients, backpropagate, apply the update&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_correct&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_samples&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># .eval() switches the model to evaluation mode, which disables dropout and similar layers&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">y_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">preds&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scores&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># take the highest-scoring class as the prediction&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_correct&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">numpy&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">y_var&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">numpy&lt;/span>&lt;span class="p">())&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_samples&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">preds&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">float&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_correct&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">num_samples&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Got &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> / &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> correct (&lt;/span>&lt;span class="si">%.2f&lt;/span>&lt;span class="s1">)&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">num_correct&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_samples&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">100&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">acc&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">RMSprop&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model_base&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.0001&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">loss_fn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CrossEntropyLoss&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">random&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">manual_seed&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">54321&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cpu&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">apply&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">reset&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_val&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">predict_on_test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_correct&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">num_samples&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">results&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">open&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;results.csv&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;w&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">count&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">write&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Id&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;,&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;Class&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="se">\n&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">preds&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scores&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="p">)):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">write&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">count&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;,&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">())&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="se">\n&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">count&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">close&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">count&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">count&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">predict_on_test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_test&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># Pass in the test set you want to predict on, then open results.csv to see the output.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">count&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div></description></item><item><title>pytorch-Classic CNN-Animation Classification from Video Frames</title><link>https://www.zcfun.top/project/ai/python01/</link><pubDate>Wed, 27 Apr 2016 00:00:00 +0000</pubDate><guid>https://www.zcfun.top/project/ai/python01/</guid><description>&lt;p>The algorithm uses a classic convolutional neural network to classify animation clips, each of which is split into three frames.
IDE: PyCharm
Language: Python
Tool: PyTorch (CPU)&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.autograd&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Variable&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sampler&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.datasets&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">dset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.transforms&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">T&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">timeit&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">PIL&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Image&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">os&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">scipy.io&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.models.inception&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">inception&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Many datasets store their annotations in .mat format; the loadmat and savemat functions in scipy.io let Python read and write .mat data.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_mat&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scipy&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">io&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">loadmat&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;./../data//q3_2_data.mat&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">label_mat&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;trLb&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;train_length:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">label_value&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">label_mat&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;valLb&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;val len:&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label_value&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Action dataset.&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">root_dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[],&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Args:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> root_dir (String): path to the root of the dataset
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> labels(list): image labels
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> transform(callable, optional): transform to apply to each sample
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">root_dir&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">root_dir&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transform&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">length&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">listdir&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span> &lt;span class="c1"># only needs to return the number of samples&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">length&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="c1"># each video clip contains 3 frames&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">idx&lt;/span>&lt;span class="p">):&lt;/span> &lt;span class="c1"># needs to return a single sample&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">folder&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">imidx&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">idx&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">3&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">folder&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">format&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">folder&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;05d&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">imgname&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">imidx&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;.jpg&amp;#39;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">img_path&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">path&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">join&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">folder&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">imgname&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Image&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">open&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">img_path&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">idx&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">][&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="c1"># if preprocessing is requested, pass the image through the transform&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">!=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">img_path&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">Label&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">img_path&lt;/span>&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">sample&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./../data//trainClips/&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># torchvision.transforms defines many image preprocessing methods; the ToTensor used here maps RGB values in 0-255 to a Tensor with values in 0-1.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">sample&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">image_dataset&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataloader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;img_path&amp;#39;&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">break&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./../data//trainClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_train&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./..//data//valClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">label_value&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_val&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataset_test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">ActionDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">root_dir&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;./..//data//testClips//&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="p">[],&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">T&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_dataloader_test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_dataset_test&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">False&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dtype&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">FloatTensor&lt;/span> &lt;span class="c1"># the floating-point type among the CPU tensor types PyTorch supports&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">print_every&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">100&lt;/span> &lt;span class="c1"># controls how often the loss is printed, since the loss must be monitored throughout training&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">reset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">):&lt;/span> &lt;span class="c1"># re-initializes the parameters of a module&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="nb">hasattr&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">m&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;reset_parameters&amp;#39;&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">m&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">reset_parameters&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Flatten&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">C&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">H&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">W&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># read each dimension&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># -1 lets view infer the size, flattening C, H and W into one dimension&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># torch.nn.Sequential is a sequential container; modules are added in the order they are passed to the constructor.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model_base&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="c1"># 3*64*64 -&amp;gt; 8*58*58&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="c1"># 8*58*58 -&amp;gt; 8*29*29&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">8&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="c1"># 8*29*29 -&amp;gt; 16*23*23&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="c1"># 16*23*23 -&amp;gt; 16*11*11&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">Flatten&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inplace&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1936&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># 1936 = 16*11*11&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># The .type() method sets the data type the model uses; here it is the CPU Float type defined earlier.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># To train on a GPU, use the CUDA Float type (torch.cuda.FloatTensor) instead.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">fixed_model_base&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">type&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dtype&lt;/span>&lt;span class="p">))&lt;/span> &lt;span class="c1"># wrap the input as a Variable (since PyTorch 0.4 this simply returns a Tensor)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">ans&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">fixed_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ans&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()))&lt;/span> &lt;span class="c1"># check the shape of the model output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array_equal&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ans&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">()),&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">array&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">])))&lt;/span> &lt;span class="c1"># print the comparison so the result is visible when run as a script&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dataloader&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_epochs&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Starting epoch &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> / &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">epoch&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># evaluate the model being trained on the validation set&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_val&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># .train() puts the model in training mode, so layers such as dropout are active&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dataloader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">y_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">long&lt;/span>&lt;span class="p">())&lt;/span> &lt;span class="c1"># get the corresponding labels&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># forward pass to get the class scores&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">scores&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y_var&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># compute the loss&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">t&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="n">print_every&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;t = &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1">, loss = &lt;/span>&lt;span class="si">%.4f&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">t&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># update the parameters in three steps: clear old gradients, backpropagate, take an optimizer step&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">num_correct&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">num_samples&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># eval() switches the model to evaluation mode, so layers such as dropout are disabled&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">y_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;Label&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">preds&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scores&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># take the highest-scoring label as the prediction&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">num_correct&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">numpy&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">y_var&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">numpy&lt;/span>&lt;span class="p">())&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">num_samples&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">preds&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">float&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_correct&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">num_samples&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Got &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> / &lt;/span>&lt;span class="si">%d&lt;/span>&lt;span class="s1"> correct (&lt;/span>&lt;span class="si">%.2f&lt;/span>&lt;span class="s1">)&amp;#39;&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">num_correct&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_samples&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">100&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">acc&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">RMSprop&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model_base&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.0001&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">loss_fn&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CrossEntropyLoss&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">random&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">manual_seed&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">54321&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cpu&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">apply&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">reset&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">fixed_model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">train&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loss_fn&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">check_accuracy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_val&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">predict_on_test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">num_correct&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">num_samples&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">results&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="nb">open&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;results.csv&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;w&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">count&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">write&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;Id&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;,&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;Class&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="se">\n&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">t&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">sample&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">enumerate&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">loader&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">x_var&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Variable&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">sample&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;image&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">scores&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x_var&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">preds&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">scores&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">for&lt;/span> &lt;span class="n">i&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="p">)):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">write&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">count&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;,&amp;#39;&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="nb">str&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">i&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">())&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="se">\n&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">count&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="mi">1&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="n">results&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">close&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">count&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">count&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">predict_on_test&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">fixed_model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image_dataloader_test&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># pass in the data loader you want to predict on, then open results.csv to inspect the predictions&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">count&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div></description></item></channel></rss>